CrystalRaindropsFall committed on
Commit
8205a2e
·
verified ·
1 Parent(s): b41c1f0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +417 -0
app.py ADDED
@@ -0,0 +1,417 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
4
+ from peft import PeftModel
5
+ import re
6
+ import json
7
+ from pathlib import Path
8
+
9
# ==================== CONFIGURATION ====================

# Base models
# Maps UI display name -> HuggingFace Hub model id.
BASE_MODELS = {
    "PHI-2 (2.7B)": "microsoft/phi-2",
    "SmolLM2 (135M)": "HuggingFaceTB/SmolLM2-135M",
}

# Adapter configurations - update with your HuggingFace username
# Format: "username/repo-name" or local path
# Per base model: display name -> LoRA adapter location (None = base model only).
ADAPTERS = {
    "PHI-2 (2.7B)": {
        "No Fine-tuning (Base Model)": None,
        "Baseline Fine-tuned": "CrystalRaindropsFall/phi2-gsm8k-baseline",
        "Curriculum: Answer Length": "CrystalRaindropsFall/phi2-gsm8k-curriculum-answer-length",
        "Curriculum: Complexity Score": "CrystalRaindropsFall/phi2-gsm8k-curriculum-complexity",
    },
    "SmolLM2 (135M)": {
        "No Fine-tuning (Base Model)": None,
        "Baseline Fine-tuned": "CrystalRaindropsFall/smolLM2-gsm8k-baseline",
        "Curriculum: Answer Length": "CrystalRaindropsFall/smolLM2-gsm8k-curriculum-answer-length",
        "Curriculum: Complexity Score": "CrystalRaindropsFall/smolLM2-gsm8k-curriculum-complexity",
    },
}

# Sample math problems
# GSM8K-style word problems shown in the sample-problem dropdown.
SAMPLE_PROBLEMS = [
    "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?",
    "A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?",
    "Josh decides to try flipping a house. He buys a house for $80,000 and then puts in $50,000 in repairs. This increased the value of the house by 150%. How much profit did he make?",
    "James decides to run 3 sprints 3 times a week. He runs 60 meters each sprint. How many total meters does he run a week?",
    "A store sells pencils for $0.50 each and notebooks for $3.00 each. If Sarah buys 6 pencils and 4 notebooks, how much does she spend in total?",
    "Mike has 45 apples. He gives 1/3 of them to his friend and then buys 12 more apples. How many apples does Mike have now?",
    "A train travels 120 miles in 2 hours. At the same speed, how far will it travel in 5 hours?",
]
44
+
45
+ # ==================== MODEL LOADING ====================
46
+
47
+
48
class ModelCache:
    """Cache the most recently loaded model/tokenizer/pipeline.

    Only one model is kept in memory at a time; switching to a different
    base/adapter combination frees the previous model before loading.
    """

    def __init__(self):
        # Identifiers of the currently cached base model and adapter.
        self.current_base = None
        self.current_adapter = None
        # Loaded artifacts (None until the first load_model call).
        self.model = None
        self.tokenizer = None
        self.pipe = None

    def load_model(self, base_model_name, adapter_path=None):
        """Load ``base_model_name`` with an optional PEFT adapter.

        Returns a text-generation pipeline. If the requested combination
        is already loaded, the cached pipeline is returned without
        reloading. On adapter-load failure the base model is used alone.
        """
        cache_key = f"{base_model_name}_{adapter_path}"
        current_key = f"{self.current_base}_{self.current_adapter}"

        # Return cached pipeline if the same combination is requested.
        if cache_key == current_key and self.pipe is not None:
            return self.pipe

        # Free the previously loaded model before loading a new one.
        if self.model is not None:
            del self.model
            del self.tokenizer
            del self.pipe
            self.model = self.tokenizer = self.pipe = None
            # Only touch the CUDA allocator when CUDA is available;
            # empty_cache() can raise on CPU-only torch builds.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        print(f"Loading {base_model_name}...")

        # Load tokenizer; ensure a pad token exists (left padding for
        # decoder-only generation).
        tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id
        tokenizer.padding_side = "left"

        # Load base model in fp16 with automatic device placement.
        model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            device_map="auto",
            torch_dtype=torch.float16,
        )

        # Load adapter if specified. PeftModel.from_pretrained handles both
        # local paths and Hub repos, so a single guarded call replaces the
        # previous two branches (the local-path branch was unguarded and
        # would crash the app on a corrupt adapter).
        if adapter_path:
            print(f"Loading adapter from {adapter_path}...")
            try:
                model = PeftModel.from_pretrained(model, adapter_path)
            except Exception as e:
                print(f"Warning: Could not load adapter from {adapter_path}: {e}")
                print("Using base model only")

        # Create the generation pipeline.
        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=512,
            do_sample=False,  # Deterministic for math
            pad_token_id=tokenizer.pad_token_id,
        )

        # Remember what is currently loaded.
        self.current_base = base_model_name
        self.current_adapter = adapter_path
        self.model = model
        self.tokenizer = tokenizer
        self.pipe = pipe

        return pipe
122
+
123
+
124
# Global cache
# Module-level singleton shared by all Gradio callbacks.
model_cache = ModelCache()
126
+
127
+ # ==================== HELPER FUNCTIONS ====================
128
+
129
+
130
def extract_answer(text):
    """Extract the final numerical answer from generated text.

    Prefers the GSM8K-style "#### <number>" marker and falls back to the
    last number appearing anywhere in the text. Thousands separators
    (e.g. "1,000") are accepted and stripped — the previous pattern
    stopped at the comma and returned only "1". Returns the string
    "No answer found" when no number is present.
    """
    # Look for #### format (GSM8K style); allow comma-grouped digits.
    match = re.search(r"####\s*(-?[\d,]+\.?\d*)", text)
    if match:
        return match.group(1).replace(",", "").rstrip(".")

    # Fallback: use the last number that appears in the text.
    numbers = re.findall(r"-?\d[\d,]*\.?\d*", text)
    if numbers:
        return numbers[-1].replace(",", "").rstrip(".")

    return "No answer found"
143
+
144
+
145
def format_solution(generated_text, question):
    """Split raw model output into (solution text, final answer).

    Strips the echoed prompt from the generation and pulls the final
    numeric answer out via extract_answer().
    """
    # The model echoes the prompt; drop it so only the reasoning remains.
    echoed_prompt = f"Question: {question}\nAnswer:"
    solution_body = generated_text.replace(echoed_prompt, "").strip()
    return solution_body, extract_answer(generated_text)
154
+
155
+
156
+ # ==================== GRADIO INTERFACE ====================
157
+
158
+
159
def solve_math_problem(base_model, adapter_choice, question, max_tokens, temperature):
    """Solve a math word problem with the selected model and adapter.

    Parameters mirror the Gradio inputs: base-model display name, adapter
    display name, question text, max new tokens, and sampling temperature
    (0 = greedy decoding). Returns a Markdown-formatted solution string, or
    an error message on failure — the broad except is deliberate, since
    this is the UI boundary.
    """
    try:
        # Resolve display names to model/adapter paths.
        base_model_path = BASE_MODELS[base_model]
        adapter_path = ADAPTERS[base_model].get(adapter_choice)

        # Load (or fetch the cached) pipeline.
        pipe = model_cache.load_model(base_model_path, adapter_path)

        # Format prompt in the style used during fine-tuning.
        prompt = f"Question: {question}\nAnswer:"

        # Build generation kwargs. Only pass `temperature` when actually
        # sampling: the previous code passed temperature=None explicitly,
        # which triggers warnings/errors in recent transformers releases.
        gen_kwargs = {
            "max_new_tokens": max_tokens,
            "do_sample": temperature > 0,
        }
        if temperature > 0:
            gen_kwargs["temperature"] = temperature

        outputs = pipe(prompt, **gen_kwargs)

        generated_text = outputs[0]["generated_text"]

        # Separate the reasoning steps from the final numeric answer.
        solution, final_answer = format_solution(generated_text, question)

        # Markdown rendered by the output component.
        output = f"""### Solution Steps:
{solution}

### Final Answer: **{final_answer}**
"""
        return output

    except Exception as e:
        return f"โŒ Error: {str(e)}\n\nPlease check that the model and adapter are correctly loaded."
195
+
196
+
197
def update_adapter_choices(base_model):
    """Refresh the adapter dropdown to match the chosen base model."""
    available = list(ADAPTERS[base_model])
    return gr.Dropdown(choices=available, value=available[0])
201
+
202
+
203
def load_sample_problem(sample_idx):
    """Return the sample problem at ``sample_idx``.

    Falls back to the first sample when the index is None or out of range.
    Negative indices are rejected too — previously they silently wrapped
    around to the end of the list.
    """
    if sample_idx is None or not 0 <= sample_idx < len(SAMPLE_PROBLEMS):
        return SAMPLE_PROBLEMS[0]
    return SAMPLE_PROBLEMS[sample_idx]
208
+
209
+
210
+ # ==================== BUILD INTERFACE ====================
211
+
212
+
213
def create_demo():
    """Build and return the Gradio Blocks interface.

    Layout: header, question input with sample picker, solve button,
    Markdown output, model/adapter selectors, advanced sliders, and
    collapsible info panels at the bottom.
    """

    with gr.Blocks(
        theme=gr.themes.Soft(), title="Curriculum Design Matters: Math Reasoning Demo"
    ) as demo:
        # Page header.
        gr.Markdown(
            """
# ๐ŸŽ“ Curriculum Design Matters: Training LLMs for Math Reasoning

<div style="font-size: 1.2em; line-height: 1.6;">

Compare how different training strategies affect mathematical reasoning in language models.

**Key Finding:** Not all curricula are equalโ€”wrong curriculum design can hurt performance!

</div>
""",
            elem_classes="header",
        )

        with gr.Row():
            with gr.Column():
                # Free-form problem entry, pre-filled with the first sample.
                question_input = gr.Textbox(
                    lines=5,
                    placeholder="Enter a math word problem here...",
                    label="Enter Your Math Problem",
                    value=SAMPLE_PROBLEMS[0],
                    show_label=True,
                )

                with gr.Accordion("๐Ÿ“š Or Choose a Sample Problem", open=False):
                    # Dropdown of truncated sample previews; the index is
                    # parsed back out of the label in load_sample_fn below.
                    sample_dropdown = gr.Dropdown(
                        choices=[
                            f"Sample {i + 1}: {prob[:50]}..."
                            for i, prob in enumerate(SAMPLE_PROBLEMS)
                        ],
                        value=f"Sample 1: {SAMPLE_PROBLEMS[0][:50]}...",
                        label="Sample Problems",
                        scale=3,
                    )
                    load_sample_btn = gr.Button("๐Ÿ“ฅ Load Selected Sample", size="sm")

                solve_btn = gr.Button("๐Ÿงฎ Solve Problem", variant="primary", size="lg")

                gr.Markdown("### ๐Ÿ’ก Solution")

                # Markdown component that receives the generated solution.
                output_text = gr.Markdown(
                    value="*Solution will appear here after you click 'Solve Problem'...*",
                    label="Generated Solution",
                )

                gr.Markdown("### โš™๏ธ Model Selection")

                base_model = gr.Dropdown(
                    choices=list(BASE_MODELS.keys()),
                    value=list(BASE_MODELS.keys())[0],
                    label="Base Model",
                    info="Choose the foundation model",
                )

                # Adapter choices depend on the base model; initialized for
                # the first base model and refreshed by the change handler.
                adapter_choice = gr.Dropdown(
                    choices=list(ADAPTERS[list(BASE_MODELS.keys())[0]].keys()),
                    value=list(ADAPTERS[list(BASE_MODELS.keys())[0]].keys())[0],
                    label="Fine-tuning Strategy",
                    info="Choose training method",
                )

                with gr.Accordion("๐ŸŽ›๏ธ Advanced Settings", open=False):
                    max_tokens = gr.Slider(
                        minimum=128,
                        maximum=512,
                        value=256,
                        step=32,
                        label="Max New Tokens",
                        info="Maximum length of solution",
                    )

                    temperature = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.0,
                        step=0.1,
                        label="Temperature",
                        info="0 = deterministic, >0 = creative",
                    )

        # ==================== EVENT HANDLERS ====================

        # Update adapters when base model changes
        base_model.change(
            fn=update_adapter_choices, inputs=[base_model], outputs=[adapter_choice]
        )

        # Load sample problem
        def load_sample_fn(sample_name):
            # Parse the 1-based index out of the "Sample N: ..." label.
            idx = int(sample_name.split()[1].split(":")[0]) - 1
            return SAMPLE_PROBLEMS[idx]

        load_sample_btn.click(
            fn=load_sample_fn, inputs=[sample_dropdown], outputs=[question_input]
        )

        # Solve problem
        solve_btn.click(
            fn=solve_math_problem,
            inputs=[
                base_model,
                adapter_choice,
                question_input,
                max_tokens,
                temperature,
            ],
            outputs=[output_text],
        )

        # ==================== BOTTOM INFO ====================

        gr.Markdown("---")

        with gr.Accordion("๐Ÿ“Š Experimental Results & Key Findings", open=False):
            gr.Markdown("""
### Results Summary

**PHI-2 (2.7B Parameters):**
- Baseline: 60.16% accuracy
- Curriculum (Answer Length): 59.38% (-0.78%) โŒ
- Curriculum (Complexity Score): 62.50% (+2.34%) โœ…

**SmolLM2 (135M Parameters):**
- Baseline: 2.15% accuracy
- Curriculum (Answer Length): 2.73% (+0.58%)
- Curriculum (Complexity Score): 2.93% (+0.78%)

### Key Insights

1. **Curriculum design is critical** - Wrong curriculum hurts performance
2. **Complexity matters more than length** - Steps ร— operations beats simple answer length
3. **Model size affects benefits** - Larger models benefit more from curriculum learning
4. **Progressive difficulty works** - Easy โ†’ Normal โ†’ Difficult stages improve learning
""")

        with gr.Accordion("๐Ÿ“š Training Methods Explained", open=False):
            gr.Markdown("""
**No Fine-tuning:** Base model without any training on GSM8K

**Baseline Fine-tuned:** Standard fine-tuning on all problems at once
- All difficulty levels mixed together
- 3 epochs on full dataset

**Curriculum: Answer Length:** Progressive training based on solution length
- Stage 1 (Easy): Short solutions (< 100 chars)
- Stage 2 (Normal): Medium solutions (100-200 chars)
- Stage 3 (Difficult): Long solutions (> 200 chars)
- Result: Performance decreased! โŒ

**Curriculum: Complexity Score:** Progressive training based on steps ร— operations
- Stage 1 (Easy): Few steps, simple operations
- Stage 2 (Normal): Moderate complexity
- Stage 3 (Difficult): Many steps, complex operations
- Result: Performance improved! โœ…
""")

        with gr.Accordion("โ„น๏ธ About This Demo", open=False):
            gr.Markdown("""
### Technical Details

**Models:**
- PHI-2: 2.7B parameter model by Microsoft
- SmolLM2: 135M parameter compact model by HuggingFace

**Dataset:** GSM8K (Grade School Math 8K) - 7,473 training and 1,319 test elementary school math word problems

**Training Method:** LoRA (Low-Rank Adaptation) fine-tuning
- Rank: 16, Alpha: 32
- Target modules: q_proj, k_proj, v_proj, o_proj
- 3 epochs per curriculum stage
- Learning rate: 3e-4

**Evaluation:** Exact match accuracy on GSM8K test set

### Links & Resources

๐Ÿ”— [GitHub Repository](#) | [Blog Post](#) | [Paper](#) | [Adapters on HuggingFace](#)

### Note

โš ๏ธ Models are loaded on-demand and cached in memory. First inference may take 30-60 seconds.

Models run on GPU if available, otherwise CPU (slower).
""")

    return demo
406
+
407
+
408
+ # ==================== MAIN ====================
409
+
410
if __name__ == "__main__":
    demo = create_demo()
    demo.launch(
        share=True,  # Creates a public share link
        server_name="0.0.0.0",  # Allow external access
        server_port=7860,
        show_error=True,  # Surface handler exceptions in the UI
    )