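"""Fine-tune GPT-2 on a tiny emoji math dataset and serve it through Gradio.

Each training example pairs an emoji equation (e.g. "Q: 🍎 + 🍎 + 🍎 = 12")
with the value of a single emoji ("A: 4").
"""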
import logging
import re

import gradio as gr
from datasets import Dataset
from transformers import (
    DataCollatorForLanguageModeling,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    Trainer,
    TrainingArguments,
)

# Create a list of 30 emoji math problems with their solutions.
# Format: "Q: [emoji math equation]\nA: [solution]"
data = [
    "Q: 🍎 + 🍎 + 🍎 = 12\nA: 4",
    "Q: 🎲 + 🎲 = 12\nA: 6",
    "Q: πŸš— + πŸš— + πŸš— + πŸš— = 20\nA: 5",
    "Q: 🍌 + 🍌 + 🍌 + 🍌 + 🍌 = 15\nA: 3",
    "Q: πŸ“ + πŸ“ = 8\nA: 4",
    "Q: πŸ• + πŸ• + πŸ• = 18\nA: 6",
    "Q: 🍩 + 🍩 + 🍩 + 🍩 = 20\nA: 5",
    "Q: 🌟 + 🌟 + 🌟 = 9\nA: 3",
    "Q: 🎈 + 🎈 = 14\nA: 7",
    "Q: πŸŽ‚ + πŸŽ‚ + πŸŽ‚ = 15\nA: 5",
    "Q: πŸͺ + πŸͺ + πŸͺ + πŸͺ = 16\nA: 4",
    "Q: 🍭 + 🍭 + 🍭 = 15\nA: 5",
    "Q: 🧁 + 🧁 = 10\nA: 5",
    "Q: πŸ₯‘ + πŸ₯‘ + πŸ₯‘ = 12\nA: 4",
    "Q: πŸ‡ + πŸ‡ = 10\nA: 5",
    "Q: πŸ’ + πŸ’ + πŸ’ = 15\nA: 5",
    "Q: 🍍 + 🍍 = 14\nA: 7",
    "Q: πŸ‰ + πŸ‰ + πŸ‰ + πŸ‰ = 20\nA: 5",
    "Q: πŸ₯­ + πŸ₯­ = 16\nA: 8",
    "Q: 🍈 + 🍈 + 🍈 = 9\nA: 3",
    "Q: πŸ‘ + πŸ‘ + πŸ‘ + πŸ‘ = 20\nA: 5",
    "Q: 🍏 + 🍏 = 10\nA: 5",
    "Q: πŸ‹ + πŸ‹ + πŸ‹ = 12\nA: 4",
    "Q: 🍊 + 🍊 = 10\nA: 5",
    "Q: πŸ₯ + πŸ₯ + πŸ₯ = 15\nA: 5",
    "Q: 🍐 + 🍐 = 8\nA: 4",
    "Q: πŸ† + πŸ† + πŸ† + πŸ† = 16\nA: 4",
    "Q: πŸ₯• + πŸ₯• = 10\nA: 5",
    "Q: 🌽 + 🌽 + 🌽 = 9\nA: 3",
    "Q: πŸ₯” + πŸ₯” + πŸ₯” + πŸ₯” = 20\nA: 5"
]
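
# Optional consistency check (a small guard against typos, assuming the
# single-emoji "Q: ... = total\nA: value" format above): every answer must
# equal the total divided by the number of emoji terms.
for ex in data:
    q, a = ex.split("\nA: ")
    terms, total = q[len("Q: "):].split(" = ")
    n_terms = terms.count("+") + 1
    assert int(total) == n_terms * int(a), f"inconsistent example: {ex}"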

# Wrap the raw strings in a Hugging Face Dataset so Trainer can consume them.
dataset = Dataset.from_dict({"text": data})

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
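# GPT-2 ships without a dedicated pad token, so reuse EOS for padding.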
tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(example):
    return tokenizer(example["text"], truncation=True, max_length=128, padding="max_length")

tokenized_dataset = dataset.map(tokenize_function, batched=True)
tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
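# Only input_ids and attention_mask are needed here; the collator below
# builds the labels from input_ids.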

# Load GPT-2 model
model = GPT2LMHeadModel.from_pretrained("gpt2")
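# Keep the model config's pad id consistent with the tokenizer choice above.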
model.config.pad_token_id = tokenizer.eos_token_id

# Collator for causal language modeling (mlm=False): it clones input_ids into
# labels and masks pad positions so the loss ignores padding.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Define training arguments
training_args = TrainingArguments(
    output_dir="./emoji_math_model",
    overwrite_output_dir=True,
    num_train_epochs=8,
    per_device_train_batch_size=4,
    save_steps=50,
    save_total_limit=2,
    logging_steps=10,
    learning_rate=1e-5
)
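# With 30 examples and a batch size of 4 on a single device, one epoch is
# 8 steps, so 8 epochs run ~64 optimizer steps; save_steps=50 therefore
# writes a single mid-run checkpoint.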

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)

# Start training
trainer.train()

# Inference setup: silence routine transformers warnings and disable dropout.
logging.getLogger("transformers").setLevel(logging.ERROR)
model.eval()

def generate_single_answer(prompt, max_new_tokens=10):
    # Match the training format: append the answer cue if the user omitted it.
    if "A:" not in prompt:
        prompt = prompt.rstrip() + "\nA:"
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs["attention_mask"].to(model.device)

    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=False,
        repetition_penalty=2.0
    )

    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

    if prompt in generated_text:
        generated_text = generated_text.split(prompt, 1)[1]

    match = re.search(r'\b(\d+)\b', generated_text)
    if match:
        answer = match.group(1)
    else:
        answer = generated_text.strip()

    return answer
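
# Quick sanity check on a training example; after fine-tuning on this tiny
# dataset the model is expected (though not guaranteed) to answer "4".
print(generate_single_answer("Q: 🍎 + 🍎 + 🍎 = 12"))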

# Gradio UI
def gradio_interface(prompt):
    return generate_single_answer(prompt)

iface = gr.Interface(
    fn=gradio_interface,
    inputs="text",
    outputs="text",
    title="Emoji Math Solver",
    description="Enter an emoji math problem (e.g. Q: 🍎 + 🍎 = 8) to get the value of one emoji.",
)
iface.launch()