mindchain committed on
Commit
d261b4f
Β·
verified Β·
1 Parent(s): 95008ad

Upload train_arithmetic_v2.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_arithmetic_v2.py +296 -0
train_arithmetic_v2.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ GRPO + RLVR Training for Simple Arithmetic - v2
4
+ Task: 2-digit addition and subtraction
5
+ Base Model: Qwen/Qwen3-0.6B-Base
6
+
7
+ Improvements:
8
+ - Better reward function with debugging
9
+ - Force EOS token in generation
10
+ - Per-step evaluation
11
+ - Clear tracking metrics
12
+ """
13
+
14
+ import os
15
+ import re
16
+ import random
17
+ import torch
18
+ from datasets import Dataset
19
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
20
+ from trl import GRPOConfig, GRPOTrainer
21
+
22
+ # ============================================================================
23
+ # CONFIG
24
+ # ============================================================================
25
+
26
+ BASE_MODEL = "Qwen/Qwen3-0.6B-Base"
27
+ OUTPUT_MODEL = "mindchain/qwen3-0.6b-arithmetic-v2"
28
+ MAX_STEPS = 20
29
+ NUM_SAMPLES = 500
30
+ EVAL_SAMPLES = 20
31
+ EVAL_EVERY = 5 # Evaluate every N steps
32
+
33
+ # ============================================================================
34
+ # DATA GENERATION
35
+ # ============================================================================
36
+
37
def generate_arithmetic_samples(n_samples):
    """Build a list of random 2-digit addition/subtraction problems.

    Each entry is a dict with 'prompt' (the question text ending in
    "Answer:"), 'answer', and 'ground_truth' — the latter two hold the
    same string so GRPO can find the label under either column name.
    """
    dataset = []
    while len(dataset) < n_samples:
        operator = random.choice(['+', '-'])
        if operator == '+':
            lhs = random.randint(10, 99)
            rhs = random.randint(10, 99)
            result = lhs + rhs
        else:
            # Subtraction stays positive: rhs is drawn strictly below lhs.
            lhs = random.randint(20, 99)
            rhs = random.randint(10, lhs - 1)
            result = lhs - rhs
        problem = f"{lhs} {operator} {rhs} = ?"
        dataset.append({
            'prompt': f"Solve: {problem}\nAnswer:",
            'answer': str(result),
            'ground_truth': str(result),  # Also provide ground_truth for GRPO
        })
    return dataset
61
+
62
+ # ============================================================================
63
+ # REWARD FUNCTION (with debugging)
64
+ # ============================================================================
65
+
66
def reward_func(completions, prompts=None, **kwargs):
    """Binary exact-match reward for arithmetic completions.

    The ground truth is taken from the first non-None kwargs column among
    'answer', 'ground_truth', 'solution', 'label'. A completion earns 1.0
    when the last number it contains equals the truth string, else 0.0.
    The first two samples of every batch are printed for debugging; if no
    ground-truth column is found, a warning is printed and all rewards
    are 0.0.
    """
    # Try multiple column names for ground truth.
    answers = next(
        (kwargs[key] for key in ('answer', 'ground_truth', 'solution', 'label')
         if kwargs.get(key) is not None),
        None,
    )

    if answers is None:
        print("⚠️ WARNING: No ground truth found in kwargs!")
        print(f" Available keys: {list(kwargs.keys())}")
        return [0.0] * len(completions)

    rewards = []
    n_debug = min(2, len(completions))  # Debug first 2 samples

    for idx, (completion, truth) in enumerate(zip(completions, answers)):
        # Conversational format: join message contents into one string.
        if isinstance(completion, list):
            pieces = [m.get('content', '') if isinstance(m, dict) else str(m)
                      for m in completion]
            text = " ".join(pieces)
        else:
            text = str(completion)

        # The prediction is the last number appearing in the text.
        matches = re.findall(r'-?\d+\.?\d*', text)
        predicted = matches[-1].strip() if matches else ""

        # Exact string match against the ground truth.
        correct = predicted == str(truth).strip()
        rewards.append(1.0 if correct else 0.0)

        if idx < n_debug:
            status = "✅" if correct else "❌"
            print(f" [{idx+1}] {status} Truth={truth} | Pred={predicted} | Text={text[:80]}...")

    return rewards
109
+
110
+ # ============================================================================
111
+ # EVALUATION
112
+ # ============================================================================
113
+
114
def evaluate_model(model, tokenizer, n_samples=EVAL_SAMPLES, step=0):
    """Greedy-decode a fresh batch of arithmetic problems and report accuracy.

    Switches the model to eval mode for the duration, prints one line per
    sample plus a summary, restores train mode, and returns the accuracy
    as a percentage in [0, 100].
    """
    print(f"\n{'='*70}")
    print(f"📊 EVALUATION @ Step {step}")
    print(f"{'='*70}")

    num_correct = 0
    eval_set = generate_arithmetic_samples(n_samples)

    model.eval()
    with torch.no_grad():
        for idx, item in enumerate(eval_set):
            encoded = tokenizer(item['prompt'], return_tensors='pt')

            # Move input tensors onto the model's device when one is exposed.
            if hasattr(model, 'device') and model.device is not None:
                encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}

            generated = model.generate(
                **encoded,
                max_new_tokens=30,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

            # Decode only the newly generated tail when the prompt length is known.
            prompt_ids = encoded.get('input_ids')
            if prompt_ids is not None and hasattr(prompt_ids, 'shape'):
                response = tokenizer.decode(generated[0][prompt_ids.shape[1]:], skip_special_tokens=True)
            else:
                response = tokenizer.decode(generated[0], skip_special_tokens=True)

            # Extract the predicted answer: last number in the response.
            found = re.findall(r'-?\d+\.?\d*', response)
            predicted = found[-1].strip() if found else ""
            truth = item['answer'].strip()

            hit = predicted == truth
            num_correct += int(hit)

            status = "✅" if hit else "❌"
            print(f"[{idx+1}] {status} {truth} | Pred: {predicted} | {response[:40]}...")

    accuracy = num_correct / n_samples * 100
    print(f"\n📊 Accuracy: {accuracy:.1f}% ({num_correct}/{n_samples})")
    print(f"{'='*70}\n")

    model.train()
    return accuracy
163
+
164
+ # ============================================================================
165
+ # CALLBACK FOR PER-STEP EVAL
166
+ # ============================================================================
167
+
168
+ from transformers import TrainerCallback
169
+
170
class EvalCallback(TrainerCallback):
    """Trainer callback that runs a held-out evaluation every N steps."""

    def __init__(self, model, tokenizer, eval_every=EVAL_EVERY):
        self.model = model
        self.tokenizer = tokenizer
        self.eval_every = eval_every
        # History of (global_step, accuracy-percentage) pairs; read by main().
        self.accuracies = []

    def on_step_end(self, args, state, control, **kwargs):
        current = state.global_step
        # Skip step 0 and any step that is not a multiple of eval_every.
        if current <= 0 or current % self.eval_every != 0:
            return

        score = evaluate_model(self.model, self.tokenizer, step=current)
        self.accuracies.append((current, score))

        # Replay the whole accuracy history so progress stays visible inline.
        print(f"\n📈 Progress Summary:")
        for logged_step, logged_acc in self.accuracies:
            print(f" Step {logged_step}: {logged_acc:.1f}%")
        print()
187
+
188
+ # ============================================================================
189
+ # MAIN TRAINING
190
+ # ============================================================================
191
+
192
def main():
    """Run the end-to-end GRPO pipeline: load, baseline-eval, train, final-eval, push.

    Uses the module-level constants (BASE_MODEL, OUTPUT_MODEL, MAX_STEPS,
    NUM_SAMPLES, EVAL_EVERY) and pushes the trained model and tokenizer to
    the Hugging Face Hub at the end.
    """
    # Run-configuration banner.
    print("="*70)
    print("🔢 GRPO + RLVR Arithmetic Training - v2")
    print("="*70)
    print(f"Base Model: {BASE_MODEL}")
    print(f"Output: {OUTPUT_MODEL}")
    print(f"Steps: {MAX_STEPS}")
    print(f"Eval every: {EVAL_EVERY} steps")
    print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")
    print("="*70 + "\n")

    # Load model and tokenizer
    print("📦 Loading model and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

    # Ensure pad token is set (falls back to EOS when the tokenizer has none).
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        print(f" Set pad_token to eos_token: {tokenizer.eos_token}")

    # fp16 weights on GPU, fp32 on CPU.
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )

    # Resize embeddings if needed
    model.resize_token_embeddings(len(tokenizer))

    # Baseline accuracy before any training step (compared against final below).
    initial_acc = evaluate_model(model, tokenizer, step=0)

    # Generate training data
    print("📊 Generating training data...")
    train_samples = generate_arithmetic_samples(NUM_SAMPLES)
    train_dataset = Dataset.from_list(train_samples)
    print(f"✅ {len(train_dataset)} training samples\n")

    # GRPO Config — CPU runs disable gradient checkpointing and 8-bit optimizer.
    is_cpu = not torch.cuda.is_available()
    training_args = GRPOConfig(
        output_dir="./outputs",
        max_steps=MAX_STEPS,
        per_device_train_batch_size=2,
        num_generations=2,
        learning_rate=2e-4,
        beta=0.0,  # No KL penalty for arithmetic
        bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
        fp16=False,
        gradient_checkpointing=not is_cpu,
        optim="adamw_torch" if is_cpu else "adamw_8bit",
        logging_steps=1,
        save_steps=MAX_STEPS,
        push_to_hub=False,
        report_to="none",
        # Force EOS in generation
        generation_config=GenerationConfig(
            max_new_tokens=30,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        ),
    )

    # Eval callback (runs evaluate_model every EVAL_EVERY steps during training).
    eval_callback = EvalCallback(model, tokenizer, eval_every=EVAL_EVERY)

    print("🚀 Starting GRPO Training...")
    print(f"Initial accuracy: {initial_acc:.1f}%\n")

    # Train
    trainer = GRPOTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        reward_funcs=[reward_func],
        callbacks=[eval_callback],
    )

    trainer.train()

    # Final evaluation
    final_acc = evaluate_model(model, tokenizer, step=MAX_STEPS)

    # Summary: before/after accuracy plus the per-step history the callback kept.
    print("\n" + "="*70)
    print("📊 FINAL RESULTS")
    print("="*70)
    print(f"Initial Accuracy: {initial_acc:.1f}%")
    print(f"Final Accuracy: {final_acc:.1f}%")
    print(f"Improvement: {final_acc - initial_acc:+.1f}%")
    print()
    print("📈 Training Progress:")
    for step, acc in eval_callback.accuracies:
        print(f" Step {step}: {acc:.1f}%")
    print("="*70)

    # Save to Hub (explicit push; GRPOConfig.push_to_hub is False above).
    print(f"\n📦 Pushing to Hub: {OUTPUT_MODEL}")
    trainer.model.push_to_hub(OUTPUT_MODEL)
    tokenizer.push_to_hub(OUTPUT_MODEL)
    print(f"✅ Model pushed to: https://huggingface.co/{OUTPUT_MODEL}")
295
# Entry point: run the full training pipeline when executed as a script.
if __name__ == "__main__":
    main()