|
|
import gradio as gr |
|
|
|
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, pipeline, DataCollatorForLanguageModeling |
|
|
from datasets import load_dataset |
|
|
|
|
|
# Roast dataset; its "User" and "Roasting Bot" columns are consumed by
# preprocess() when building training examples.
ds = load_dataset("kaifkhaan/roast")

# Base tokenizer and causal-LM model that get fine-tuned below.
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

# Reuse the EOS token as the padding token (preprocess pads every example to a
# fixed length, which requires a pad token) and record its id on the model
# config so the model agrees with the tokenizer. Setting it once is enough.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id
|
|
|
|
|
def preprocess(batch):
    """Turn a batch of (User, Roasting Bot) pairs into causal-LM training examples.

    Each pair is rendered as ``"<user text> -> <roast><eos>"``, the same
    ``" ->"`` separator that ``roast_me`` appends to its prompt. The text is
    then tokenized to a fixed length of 128 tokens (truncated or padded).

    Args:
        batch: A batched ``datasets`` slice containing the ``"User"`` and
            ``"Roasting Bot"`` columns as parallel lists.

    Returns:
        The tokenizer's encoding (``input_ids``, ``attention_mask``) plus a
        ``labels`` entry. ``labels`` mirrors ``input_ids`` except that padding
        positions are set to -100, so the loss ignores them.
    """
    # Appending EOS teaches the model where a roast ends, so generation can
    # stop on its own instead of always running to the length limit.
    texts = [
        f"{prompt} -> {response}{tokenizer.eos_token}"
        for prompt, response in zip(batch["User"], batch["Roasting Bot"])
    ]

    encoded = tokenizer(
        texts,
        truncation=True,
        max_length=128,
        padding="max_length",
    )

    # The pad token is the EOS token, so padding cannot be told apart by id.
    # Use the attention mask instead. -100 is the ignore index of the
    # cross-entropy loss in Hugging Face causal-LM models, so padding no longer
    # trains the model to emit EOS over and over. New inner lists are built,
    # which means input_ids is never aliased or mutated.
    encoded["labels"] = [
        [token if keep else -100 for token, keep in zip(ids, mask)]
        for ids, mask in zip(encoded["input_ids"], encoded["attention_mask"])
    ]
    return encoded
|
|
|
|
|
|
|
|
|
|
|
# Tokenize every split; preprocess receives whole batches of rows at a time.
tokenized_ds = ds.map(function=preprocess, batched=True)

# Fine-tuning hyper-parameters: checkpoints are written under ./roastbot every
# 500 steps and no external experiment tracker is reported to.
training_args = TrainingArguments(
    output_dir="./roastbot",
    report_to="none",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    save_steps=500,
    logging_dir="./logs",
)

# Fine-tune on the "train" split only.
trainer = Trainer(
    args=training_args,
    model=model,
    train_dataset=tokenized_ds["train"],
)

print("Starting training... 🏋️")
trainer.train()
print("Training complete! ✅")
|
|
|
|
|
# trainer.train() may leave the model in training mode (dropout active), and
# the pipeline does not switch modes itself. Put the model in eval mode so
# inference is not perturbed by dropout.
model.eval()

# Text-generation pipeline wrapping the in-memory fine-tuned model.
roast_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)
|
|
|
|
|
def roast_me(text, max_new_tokens=50):
    """Generate a roast for ``text`` using the fine-tuned model.

    The input is formatted as ``"<text> ->"``, matching the
    ``"<user> -> <roast>"`` layout used for training. The model's sampled
    continuation is returned.

    Args:
        text: The user's message.
        max_new_tokens: Upper bound on the number of generated tokens. The
            prompt is not counted, so long inputs still get a reply.

    Returns:
        The generated roast, with surrounding whitespace stripped.
    """
    prompt = f"{text} ->"

    # max_length would count the prompt tokens too, so a long message could
    # leave no budget for a reply; max_new_tokens avoids that.
    # return_full_text=False yields only the continuation, so a roast that
    # itself contains "->" is not cut short by splitting on the separator.
    outputs = roast_pipeline(
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        return_full_text=False,
    )
    return outputs[0]["generated_text"].strip()
|
|
|
|
|
# Gradio UI: one text box in, the generated roast out. launch() starts the app.
gr.Interface(
    fn=roast_me,
    inputs="text",
    outputs="text",
    title="The Very Good Bot",
    # The previous description stopped mid-sentence ("...in a ").
    description="The bot will converse with you in a roasting tone. Type something and get roasted!",
).launch()