"""Fine-tune distilgpt2 on the `kaifkhaan/roast` dataset and serve it via Gradio.

Running this script downloads the dataset and base model, fine-tunes the model
on "<user message> -> <roast>" pairs, and then launches a Gradio text interface
that generates a roast for whatever the user types.
"""

import gradio as gr
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    pipeline,
)

# Label value that the causal-LM loss skips (PyTorch cross-entropy ignore_index).
IGNORE_INDEX = -100

ds = load_dataset("kaifkhaan/roast")
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

# GPT-2 ships without a pad token, so the EOS token doubles as padding.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id


def preprocess(batch):
    """Tokenize a batch of (user message, roast) pairs for causal-LM training.

    Args:
        batch: A batched dataset slice with the columns "User" and
            "Roasting Bot", each a list of strings of equal length.

    Returns:
        The tokenizer output (``input_ids``, ``attention_mask``) padded or
        truncated to 128 tokens, plus a ``labels`` entry that mirrors
        ``input_ids`` with every padding position replaced by ``IGNORE_INDEX``.
    """
    # An explicit EOS after each roast teaches the model where a reply ends;
    # without it generation only stops when the token budget runs out.
    texts = [
        f"{prompt} -> {response}{tokenizer.eos_token}"
        for prompt, response in zip(batch["User"], batch["Roasting Bot"])
    ]

    encoded = tokenizer(
        texts,
        truncation=True,
        max_length=128,
        padding="max_length",
    )

    # Pad and EOS share one token id, so padding is identified through the
    # attention mask rather than by id. Masking it keeps the loss from being
    # dominated by padding while the real trailing EOS stays a target.
    encoded["labels"] = [
        [
            token_id if attended else IGNORE_INDEX
            for token_id, attended in zip(ids, mask)
        ]
        for ids, mask in zip(encoded["input_ids"], encoded["attention_mask"])
    ]
    return encoded


# Tokenize every split in batches.
tokenized_ds = ds.map(preprocess, batched=True)

training_args = TrainingArguments(
    output_dir="./roastbot",
    per_device_train_batch_size=8,
    num_train_epochs=3,
    logging_dir="./logs",
    save_steps=500,
    report_to="none",  # no wandb/tensorboard reporting
)

# The default collator is used on purpose: labels are already built in
# preprocess(). NOTE(review): DataCollatorForLanguageModeling(mlm=False) would
# rebuild the labels from the pad token id, which here would presumably also
# mask the genuine EOS because pad == EOS -- confirm before switching to it.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds["train"],
)

print("Starting training... 🏋️")
trainer.train()
print("Training complete! ✅")

# Trainer leaves the model in training mode; switch to eval so dropout is
# disabled while generating.
model.eval()

roast_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)


def roast_me(text):
    """Generate a roast in reply to the user's message.

    Args:
        text: The message typed by the user. ``None`` or blank input is
            treated as an empty message.

    Returns:
        The generated reply with surrounding whitespace removed.
    """
    message = (text or "").strip()
    prompt = f"{message} ->"

    outputs = roast_pipeline(
        prompt,
        # Budget applies to the reply only; max_length would also count the
        # prompt tokens and starve generation for long messages.
        max_new_tokens=50,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        # Return just the continuation, so the reply does not have to be cut
        # out of the prompt by splitting on "->" (which breaks whenever the
        # message or the roast itself contains that sequence).
        return_full_text=False,
    )
    return outputs[0]["generated_text"].strip()


# NOTE(review): the description string below ends mid-sentence; the intended
# wording is not recoverable from this file, so it is left unchanged.
demo = gr.Interface(
    fn=roast_me,
    inputs="text",
    outputs="text",
    title="The Very Good Bot",
    description="The bot will converse with you in a ",
)
demo.launch()