# File size: 5,983 bytes (snapshot 9d0903f)
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import Dataset
import gradio as gr

# Training corpus: each (riddle, answer) pair is rendered into the fixed
# two-line layout "Riddle: <statement>\nAnswer: <integer>" that the model is
# fine-tuned on and that generation later relies on.
data = [
    {"text": f"Riddle: {riddle}\nAnswer: {answer}"}
    for riddle, answer in (
        ("What number becomes zero when you subtract 15 from half of it?", 30),
        ("I am a number that when doubled and then reduced by 20 gives 40.", 30),
        ("If you add 10 to a number and then subtract 5, you get 25.", 20),
        ("I am 15 less than twice my value.", 15),
        ("A number when halved and then increased by 10 becomes 25.", 30),
        ("When you multiply a number by 3 and subtract 9, the result is 18.", 9),
        ("If a number is decreased by 8 and then doubled, you get 14.", 15),
        ("A number when tripled and then increased by 5 equals 20.", 5),
        ("When you add 7 to half of a number, you get 19.", 24),
        ("A number is increased by 9 and then halved to get 15.", 21),
        ("When you subtract 4 from a number and then multiply by 3, the result is 33.", 15),
        ("A number reduced by 6 equals one-third of itself.", 9),
        ("When you double a number and add 10, you get 30.", 10),
        ("A number, when 5 is subtracted and then multiplied by 2, gives 20.", 15),
        ("If a number is multiplied by 4 and then decreased by 8, the result is 24.", 8),
        ("A number, when divided by 2 and then increased by 7, equals 17.", 20),
        # (x - 3)^2 = 49 also admits x = -4; the corpus keeps the positive root.
        ("When you subtract 3 from a number and then square the result, you get 49.", 10),
        ("If 12 is added to a number, the result is three times the number.", 6),
        ("A number increased by 50% equals 27.", 18),
        ("If a number is halved and then 4 is subtracted, the result is 8.", 24),
        ("A number, when 2 is added, becomes twice the original number.", 2),
        ("When you triple a number and subtract 7, the result is 14.", 7),
        ("A number, when reduced by 2 and then divided by 4, gives 5.", 22),
        ("When you add 8 to a number and then multiply by 2, you get 40.", 12),
        ("A number, when doubled, is 16 more than the number itself.", 16),
        ("A number that is increased by 3 and then multiplied by 2 equals 26.", 10),
        ("A number when reduced by 4 and then doubled equals 12.", 10),
        ("If you subtract 2 from a number and then double it, you get 14.", 9),
        ("A number when tripled and decreased by 5 equals 16.", 7),
        ("If you add 5 to a number and then double the result, you get 30.", 10),
    )
]

# GPT-2's pretrained tokenizer has no dedicated padding token; reuse the
# end-of-text token so batches can be padded.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# In-memory Hugging Face dataset with a single "text" column built from `data`.
dataset = Dataset.from_list(data)

def tokenize_function(example, max_length=128):
    """Tokenize a batch of riddle texts for causal-LM fine-tuning.

    Used with ``dataset.map(..., batched=True)``, so ``example["text"]`` is a
    list of strings. Sequences are truncated to ``max_length`` tokens but are
    NOT padded here: ``DataCollatorForLanguageModeling`` pads every batch to
    its longest member and masks pad positions out of the labels, so the loss
    matches fixed-length padding while the short riddles (a few dozen tokens)
    are no longer inflated to ``max_length`` tokens each.

    Args:
        example: Batch dict with a ``"text"`` list.
        max_length: Truncation limit in tokens.

    Returns:
        The tokenizer's output (``input_ids`` and ``attention_mask`` lists).
    """
    return tokenizer(example["text"], truncation=True, max_length=max_length)

tokenized_dataset = dataset.map(tokenize_function, batched=True)
# Only token ids / attention masks are needed by the collator; drop raw text.
tokenized_dataset = tokenized_dataset.remove_columns(["text"])
tokenized_dataset.set_format("torch")

model = GPT2LMHeadModel.from_pretrained("gpt2")
# The pad token reuses the existing EOS token, so no new token was added and
# this resize is a no-op today; it is kept so the embedding matrix stays in
# sync if special tokens are ever added to the tokenizer.
model.resize_token_embeddings(len(tokenizer))

# mlm=False -> causal LM objective: labels are the input ids, pads masked out.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir="./gpt2-math-riddle",
    overwrite_output_dir=True,
    num_train_epochs=15,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,          # effective batch of 4 per device
    learning_rate=3e-5,
    weight_decay=0.01,
    # 30 examples / effective batch 4 is ~7-8 optimizer steps per epoch, i.e.
    # roughly 105-120 steps in total on one device. A fixed warmup of 100 steps
    # would spend almost the whole run below the peak learning rate, so warm
    # up over the first 10% of training instead.
    warmup_ratio=0.1,
    save_steps=500,                         # longer than this run: no mid-run checkpoints
    save_total_limit=2,
    logging_steps=10,                       # ~100+ steps total; log often enough to see the loss
    prediction_loss_only=True,
    report_to=[],                           # disable wandb and other logging integrations
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)

trainer.train()

model.eval()

# Record the pad id in the model config (generation also passes it explicitly).
model.config.pad_token_id = tokenizer.eos_token_id

# Persist the final fine-tuned weights and tokenizer at the top level of
# output_dir so they can be reloaded with from_pretrained(output_dir) without
# retraining; the step-based checkpointing above never fires mid-run.
trainer.save_model()
tokenizer.save_pretrained(training_args.output_dir)

# Gradio UI for testing
def trim_to_first_answer(text, marker="\nAnswer:"):
    """Cut a generated sample so it ends right after its first answer.

    Everything before the first ``marker`` is kept verbatim. After the marker,
    only the rest of that line is kept, and it is further cut at the first
    period that is not followed by a digit (so decimal answers such as "2.5"
    survive while a trailing sentence is dropped). Trailing whitespace is
    removed. Text that does not contain ``marker`` is returned unchanged.

    Args:
        text: Decoded model output.
        marker: Separator between riddle and answer, as used in the corpus.

    Returns:
        The trimmed text.
    """
    head, found, tail = text.partition(marker)
    if not found:
        return text
    answer = tail.split("\n", 1)[0]
    for i, ch in enumerate(answer):
        # Stop at a sentence-ending '.', but not at a decimal point.
        if ch == "." and not answer[i + 1:i + 2].isdigit():
            answer = answer[:i]
            break
    return head + marker + answer.rstrip()


def generate_riddle(prompt, max_new_tokens=32, num_return_sequences=5):
    """Sample riddle continuations of *prompt* from the fine-tuned model.

    Args:
        prompt: Seed text, e.g. "Riddle:" or a full riddle statement.
        max_new_tokens: Tokens generated after the prompt. This is counted
            separately from the prompt, so long prompts still leave room for a
            riddle and answer.
        num_return_sequences: Number of independent samples to draw.

    Returns:
        A list of decoded samples (prompt included), each passed through
        ``trim_to_first_answer``. Returns an empty list for a blank prompt,
        since generation needs at least one input token.
    """
    if not (prompt and prompt.strip()):
        return []

    # Pass the attention mask explicitly: pad == eos here, so it cannot be
    # reliably inferred from the ids.
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **encoded,
        max_new_tokens=max_new_tokens,
        do_sample=True,            # sampling rather than greedy decoding
        top_k=50,
        top_p=0.92,
        temperature=0.5,           # fairly conservative sampling
        repetition_penalty=1.2,
        no_repeat_ngram_size=3,
        num_return_sequences=num_return_sequences,
        pad_token_id=tokenizer.eos_token_id,
    )
    return [
        trim_to_first_answer(tokenizer.decode(output, skip_special_tokens=True))
        for output in outputs
    ]

def riddles_as_text(prompt):
    """Gradio adapter: show every sampled riddle in one text box.

    ``generate_riddle`` returns a list of strings, while a ``"text"`` output
    component displays a single string; handing it the list directly would
    show the list's Python repr (brackets, quotes, escaped newlines).
    """
    return "\n\n".join(generate_riddle(prompt))


iface = gr.Interface(
    fn=riddles_as_text,
    inputs="text",
    outputs="text",
    title="Math Riddle Generator",
    description="Enter a prompt to generate a riddle.",
)
iface.launch()