# project1 / app.py
# Neo
# all
# c3f9f6f
import gradio as gr
# --- FIX 1: Added 'pipeline' and 'DataCollatorForLanguageModeling' to imports ---
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, pipeline, DataCollatorForLanguageModeling
from datasets import load_dataset
# Training data: roast exchanges pulled from the Hugging Face Hub. preprocess()
# reads its "User" and "Roasting Bot" columns.
ds = load_dataset("kaifkhaan/roast")

# Base checkpoint; the same tokenizer/model objects are fine-tuned and then
# served further down in this script.
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

# Padding fix for GPT-2 compatibility: reuse the EOS token as the padding
# token, and tell the model config which id that is. (This assignment was
# previously duplicated; once is enough.)
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id
def preprocess(batch):
    """Tokenize a batch of roast exchanges for causal-LM fine-tuning.

    Args:
        batch: Mapping with parallel "User" and "Roasting Bot" lists, as
            supplied by ``Dataset.map(..., batched=True)``.

    Returns:
        The tokenizer output, padded/truncated to 128 tokens per example,
        with an added "labels" entry. Labels equal the input ids on real
        tokens and -100 on padding positions.
    """
    # One "<user> -> <roast>" string per row, terminated with an explicit EOS
    # so the model still sees a real end-of-reply token once padding is
    # excluded from the loss below. (Examples longer than max_length are
    # truncated and may lose that trailing EOS.)
    texts = [
        f"{prompt} -> {response}{tokenizer.eos_token}"
        for prompt, response in zip(batch["User"], batch["Roasting Bot"])
    ]

    # Tokenize the entire list of texts at once.
    encoded = tokenizer(
        texts,
        truncation=True,
        max_length=128,
        padding="max_length",
        return_attention_mask=True,
    )

    # The pad token is the EOS token in this script, so padding cannot be
    # recognised by token id; the attention mask is the reliable marker.
    # -100 is the label value the loss ignores, which keeps the model from
    # being trained to predict long runs of padding.
    encoded["labels"] = [
        [token_id if is_real else -100 for token_id, is_real in zip(ids, mask)]
        for ids, mask in zip(encoded["input_ids"], encoded["attention_mask"])
    ]
    return encoded
# Tokenize the dataset up front; preprocess() receives batches of rows.
tokenized_ds = ds.map(preprocess, batched=True)

# Fine-tuning configuration. report_to="none" disables wandb/tensorboard
# reporting when those integrations are not configured.
training_args = TrainingArguments(
    output_dir="./roastbot",
    logging_dir="./logs",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    save_steps=500,
    report_to="none",
)

# No custom data collator is passed: every example already carries
# fixed-length inputs and its own labels from preprocess().
trainer = Trainer(model=model, args=training_args, train_dataset=tokenized_ds["train"])

print("Starting training... 🏋️")
trainer.train()
print("Training complete! ✅")

# Text-generation wrapper around the same model object that was just trained.
roast_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
def roast_me(text):
    """Generate a roast in reply to the user's message.

    Args:
        text: The message typed into the Gradio text box.

    Returns:
        The model's continuation of "<text> ->", with surrounding
        whitespace stripped.
    """
    # Same "<user text> ->" framing used for the training examples.
    prompt = f"{text} ->"

    outputs = roast_pipeline(
        prompt,
        # Budget the reply itself. max_length also counts the prompt's
        # tokens, so a long message would leave little or no room to answer.
        max_new_tokens=50,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        # Ask for the continuation only, so a "->" inside the user's own
        # message cannot confuse the extraction of the reply.
        return_full_text=False,
    )
    return outputs[0]["generated_text"].strip()
# Minimal text-in / text-out web UI around roast_me. The description string
# was previously cut off mid-sentence ("...converse with you in a ").
demo = gr.Interface(
    fn=roast_me,
    inputs="text",
    outputs="text",
    title="The Very Good Bot",
    description="The bot will converse with you in a roasting style.",
)
demo.launch()