LonewolfT141 committed on
Commit
5ca89a9
·
verified ·
1 Parent(s): a2a31ae

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +121 -0
app.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import GPT2LMHeadModel, TrainingArguments, Trainer, DataCollatorForLanguageModeling, GPT2Tokenizer
2
+ import gradio as gr
3
+
4
# Training corpus: 30 emoji math problems paired with their solutions.
# Format: "Q: [emoji math equation]\nA: [solution]"
# Within a problem every emoji stands for the same unknown value, so the
# answer is the stated total divided by the number of emoji on the left side.
# The emoji are stored as real Unicode characters (one code point each).
data = [
    "Q: 🍎 + 🍎 + 🍎 = 12\nA: 4",
    "Q: 🎲 + 🎲 = 12\nA: 6",
    "Q: 🚗 + 🚗 + 🚗 + 🚗 = 20\nA: 5",
    "Q: 🍌 + 🍌 + 🍌 + 🍌 + 🍌 = 15\nA: 3",
    "Q: 🍓 + 🍓 = 8\nA: 4",
    "Q: 🍕 + 🍕 + 🍕 = 18\nA: 6",
    "Q: 🍩 + 🍩 + 🍩 + 🍩 = 20\nA: 5",
    "Q: 🌟 + 🌟 + 🌟 = 9\nA: 3",
    "Q: 🎈 + 🎈 = 14\nA: 7",
    "Q: 🎂 + 🎂 + 🎂 = 15\nA: 5",
    "Q: 🍪 + 🍪 + 🍪 + 🍪 = 16\nA: 4",
    "Q: 🍭 + 🍭 + 🍭 = 15\nA: 5",
    "Q: 🧁 + 🧁 = 10\nA: 5",
    "Q: 🥑 + 🥑 + 🥑 = 12\nA: 4",
    "Q: 🍇 + 🍇 = 10\nA: 5",
    "Q: 🍒 + 🍒 + 🍒 = 15\nA: 5",
    "Q: 🍍 + 🍍 = 14\nA: 7",
    "Q: 🍉 + 🍉 + 🍉 + 🍉 = 20\nA: 5",
    "Q: 🥭 + 🥭 = 16\nA: 8",
    "Q: 🍈 + 🍈 + 🍈 = 9\nA: 3",
    "Q: 🍑 + 🍑 + 🍑 + 🍑 = 20\nA: 5",
    "Q: 🍏 + 🍏 = 10\nA: 5",
    "Q: 🍋 + 🍋 + 🍋 = 12\nA: 4",
    "Q: 🍊 + 🍊 = 10\nA: 5",
    "Q: 🥝 + 🥝 + 🥝 = 15\nA: 5",
    "Q: 🍐 + 🍐 = 8\nA: 4",
    "Q: 🍆 + 🍆 + 🍆 + 🍆 = 16\nA: 4",
    "Q: 🥕 + 🥕 = 10\nA: 5",
    "Q: 🌽 + 🌽 + 🌽 = 9\nA: 3",
    "Q: 🥔 + 🥔 + 🥔 + 🥔 = 20\nA: 5",
]
38
+
39
# Wrap the raw problem strings in a Hugging Face Dataset with a single
# "text" column so they can be mapped and fed to the Trainer.
from datasets import Dataset

dataset = Dataset.from_dict(dict(text=data))

# Pretrained GPT-2 tokenizer; the end-of-sequence token is reused as the
# padding token.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
45
+
46
def tokenize_function(example):
    """Tokenize the "text" entry of a dataset example (or batch of examples).

    Sequences are truncated to at most 128 tokens and padded up to that same
    fixed length using the module-level ``tokenizer``.

    Args:
        example: Mapping with a "text" key holding the string(s) to encode.

    Returns:
        Whatever the tokenizer call returns for the given text.
    """
    texts = example["text"]
    encoded = tokenizer(
        texts,
        truncation=True,
        max_length=128,
        padding="max_length",
    )
    return encoded
48
+
49
# Encode the whole dataset in batches, then expose only the model inputs
# ("input_ids" and "attention_mask") as torch tensors.
tokenized_dataset = dataset.map(tokenize_function, batched=True)
tokenized_dataset.set_format(
    type="torch",
    columns=["input_ids", "attention_mask"],
)

# Start from the pretrained GPT-2 weights and use the end-of-sequence id
# as the padding id, mirroring the tokenizer setup.
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.config.pad_token_id = tokenizer.eos_token_id

# Language-modeling collator with mlm=False. Examples reach it already
# padded to a fixed length by tokenize_function.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)
58
+
59
# Fine-tuning configuration: 8 epochs with batches of 4 per device,
# checkpoints written to ./emoji_math_model every 50 steps (at most 2 kept),
# and a log entry every 10 steps.
training_args = TrainingArguments(
    output_dir="./emoji_math_model",
    overwrite_output_dir=True,
    learning_rate=1e-5,
    num_train_epochs=8,
    per_device_train_batch_size=4,
    logging_steps=10,
    save_steps=50,
    save_total_limit=2,
)

# Bind the model, the tokenized examples and the collator into a Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset,
)

# Run the fine-tuning loop (executes as soon as this script is run).
trainer.train()
81
+
82
import logging

# From this point on, only let ERROR-level (and above) records from the
# "transformers" logger through, which hides its warnings.
logging.getLogger("transformers").setLevel(logging.ERROR)
84
+
85
+ import re
86
+
87
def generate_single_answer(prompt, max_new_tokens=10):
    """Generate a short answer for a prompt with the fine-tuned model.

    The prompt is tokenized with the module-level ``tokenizer``, continued by
    the module-level ``model`` without sampling, and the first standalone
    integer in the newly generated text is returned.

    Args:
        prompt: Text sent to the model.
        max_new_tokens: Upper bound on the number of tokens generated after
            the prompt.

    Returns:
        The first whole number found in the continuation, as a string. When
        the continuation contains no number, the stripped continuation text
        is returned instead. An empty prompt yields an empty string.
    """
    # Nothing to continue from: skip the model call for an empty prompt.
    if not prompt:
        return ""

    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs["attention_mask"].to(model.device)

    # Switch off dropout for inference; trainer.train() earlier in the script
    # may have left the model in training mode.
    model.eval()

    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=False,
        repetition_penalty=2.0,
    )

    # The returned sequence starts with the prompt tokens. Drop them by
    # position instead of searching for the prompt inside the decoded text:
    # if decoding does not reproduce the prompt character-for-character, a
    # substring match fails and the number search below would pick up a
    # number from the prompt itself (e.g. the total after "=").
    prompt_length = input_ids.shape[-1]
    new_tokens = output[0][prompt_length:]
    generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)

    match = re.search(r"\b(\d+)\b", generated_text)
    if match:
        return match.group(1)
    return generated_text.strip()
115
+
116
+ # Gradio UI
117
def gradio_interface(prompt):
    """Gradio callback: return the model's answer for the submitted text.

    Args:
        prompt: Text entered in the interface's input box.

    Returns:
        The string produced by ``generate_single_answer`` for that text.
    """
    # NOTE(review): the text is forwarded unchanged, while the training
    # examples use a "Q: ...\nA: ..." layout - confirm whether the prompt
    # should be wrapped in that layout before generation.
    answer = generate_single_answer(prompt)
    return answer
119
+
120
# Single text box in, single text box out, backed by gradio_interface.
iface = gr.Interface(
    fn=gradio_interface,
    inputs="text",
    outputs="text",
    title="Emoji Math Solver",
    description="Enter an emoji math problem to get the answer.",
)

# Start the web UI.
iface.launch()