import gradio as gr
# pipeline drives inference below; DataCollatorForLanguageModeling backs the
# commented-out collator alternative further down
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, pipeline, DataCollatorForLanguageModeling
from datasets import load_dataset

ds = load_dataset("kaifkhaan/roast")
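# Columns used below: "User" (the prompt) and "Roasting Bot" (the reply)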

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

# 🩹 GPT-2 ships without a pad token; reuse EOS so batched padding works
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id

def preprocess(batch):
    # Create a list of formatted strings for the entire batch
    texts = [f"{prompt} -> {response}" for prompt, response in zip(batch["User"], batch["Roasting Bot"])]
    
    # Tokenize the entire list of texts at once
    encoded = tokenizer(
        texts,
        truncation=True,
        max_length=128,
        padding="max_length"
    )
    
    # Use input_ids as labels (causal LM); mask pad positions with -100 so
    # the loss skips them (pad == EOS here, and no EOS is appended to texts)
    encoded["labels"] = [
        [tok if tok != tokenizer.pad_token_id else -100 for tok in ids]
        for ids in encoded["input_ids"]
    ]
    return encoded


# Map the preprocessing function to the dataset
tokenized_ds = ds.map(preprocess, batched=True)

# data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
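# (left commented out: preprocess already builds labels, so the Trainer's default collator suffices)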

# Define training arguments
training_args = TrainingArguments(
    output_dir="./roastbot",
    per_device_train_batch_size=8,
    num_train_epochs=3,
    logging_dir="./logs",
    save_steps=500,
    report_to="none"  # disable wandb/tensorboard logging when neither is configured
)

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds["train"],
    # data_collator=data_collator
)

print("Starting training... 🏋️")
trainer.train()
print("Training complete! ✅")

roast_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer
)

def roast_me(text):
    prompt = f"{text} ->"
    # max_new_tokens bounds only the generated continuation, so long inputs
    # cannot collide with the length budget the way max_length would
    roast = roast_pipeline(prompt, max_new_tokens=50, do_sample=True, pad_token_id=tokenizer.eos_token_id)[0]["generated_text"]
    # Keep only the text after the "->" separator
    return roast.split("->")[-1].strip()
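
# Quick smoke test before launching the UI (sampled output varies per run):
# print(roast_me("Tell me about yourself"))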

gr.Interface(
    fn=roast_me,
    inputs="text",
    outputs="text",
    title="The Very Good Bot",
    description="The bot will converse with you in a "
).launch()