mindchain committed on
Commit
306c5f0
·
verified ·
1 Parent(s): a80cc87

Upload train_arithmetic.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_arithmetic.py +214 -0
train_arithmetic.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ GRPO + RLVR Training for Simple Arithmetic
4
+ Task: 2-digit addition and subtraction
5
+ Base Model: Qwen/Qwen3-0.6B-Base
6
+ """
7
+
8
+ import os
9
+ import re
10
+ import random
11
+ import torch
12
+ from datasets import Dataset
13
+ from transformers import AutoModelForCausalLM, AutoTokenizer
14
+ from trl import GRPOConfig, GRPOTrainer
15
+
16
# ============================================================================
# CONFIG
# ============================================================================

BASE_MODEL = "Qwen/Qwen3-0.6B-Base"  # Hub id of the pretrained base model
OUTPUT_MODEL = "mindchain/qwen3-0.6b-arithmetic"  # Hub repo the trained model is pushed to
MAX_STEPS = 50  # Number of GRPO optimizer steps
NUM_SAMPLES = 500  # Training samples
EVAL_SAMPLES = 20  # For baseline test
25
+
26
+ # ============================================================================
27
+ # DATA GENERATION
28
+ # ============================================================================
29
+
30
def generate_arithmetic_samples(n_samples):
    """Build a list of random 2-digit addition/subtraction problems.

    Each item is a dict with 'prompt' (instruction plus the problem text)
    and 'answer' (the ground-truth result as a string).
    """
    prompt_template = "Solve this arithmetic problem. Give only the answer as a number.\n\n{}"
    dataset = []
    while len(dataset) < n_samples:
        operation = random.choice(['+', '-'])
        if operation == '+':
            lhs = random.randint(10, 99)
            rhs = random.randint(10, 99)
            result = lhs + rhs
        else:
            lhs = random.randint(20, 99)
            rhs = random.randint(10, lhs - 1)  # Ensure positive result
            result = lhs - rhs
        dataset.append({
            'prompt': prompt_template.format(f"{lhs} {operation} {rhs} = ?"),
            'answer': str(result),
        })
    return dataset
54
+
55
+ # ============================================================================
56
+ # REWARD FUNCTION
57
+ # ============================================================================
58
+
59
def reward_func(completions, prompts, **kwargs):
    """
    RLVR reward for arithmetic: 1.0 when the last number in the completion
    exactly matches the ground truth, else 0.0.

    Args:
        completions: list of generated texts, or (conversational format)
            lists of {'role', 'content'} message dicts.
        prompts: the corresponding prompts (unused; kept for the TRL
            reward-function signature).
        **kwargs: extra dataset columns; the ground truth is read from
            'answer' (fallback 'ground_truth').

    Returns:
        list[float] of per-completion rewards, same length as completions.
    """
    answers = kwargs.get('answer', kwargs.get('ground_truth', None))
    if answers is None:
        # No ground-truth column was passed through — neutral reward.
        return [0.0] * len(completions)

    rewards = []
    for completion, truth in zip(completions, answers):
        # Handle list format (conversational)
        if isinstance(completion, list):
            text = " ".join([m.get('content', '') if isinstance(m, dict) else str(m) for m in completion])
        else:
            text = str(completion)

        # Extract the last number. The decimal part is only consumed when
        # digits follow the dot, so "The answer is 52." yields "52" — the
        # old pattern r'-?\d+\.?\d*' kept the trailing period ("52.") and
        # broke the exact-match comparison below.
        numbers = re.findall(r'-?\d+(?:\.\d+)?', text)
        predicted = numbers[-1] if numbers else ""

        # Exact string match against the ground truth.
        rewards.append(1.0 if predicted == str(truth).strip() else 0.0)

    return rewards
90
+
91
+ # ============================================================================
92
+ # BASELINE TEST
93
+ # ============================================================================
94
+
95
def test_base_model(model, tokenizer, n_samples=20):
    """Measure greedy-decoding arithmetic accuracy before training.

    Args:
        model: a causal LM with .generate() and .device.
        tokenizer: its tokenizer.
        n_samples: number of freshly generated problems to evaluate.

    Returns:
        Accuracy as a percentage in [0, 100].
    """
    print("\n" + "="*70)
    print("πŸ“Š TESTING BASE MODEL PERFORMANCE")
    print("="*70)

    test_samples = generate_arithmetic_samples(n_samples)
    correct = 0

    model.eval()
    with torch.no_grad():
        for i, sample in enumerate(test_samples):
            inputs = tokenizer(sample['prompt'], return_tensors='pt').to(model.device)
            # Greedy decoding. Passing temperature together with
            # do_sample=False only triggers a transformers warning, so the
            # redundant temperature argument is omitted.
            outputs = model.generate(
                **inputs,
                max_new_tokens=20,
                do_sample=False,
            )
            response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)

            # Extract answer: last number in the response. Decimal digits are
            # only consumed after the dot, so a sentence-ending period is not
            # glued onto the number (the old pattern kept it and failed the
            # exact match).
            numbers = re.findall(r'-?\d+(?:\.\d+)?', response)
            predicted = numbers[-1] if numbers else ""
            truth = sample['answer'].strip()

            is_correct = predicted == truth
            if is_correct:
                correct += 1

            # The problem text is everything after the blank line in the
            # prompt (the old code printed only the second operand, e.g.
            # "23 = 68" instead of "45 + 23 = 68").
            problem = sample['prompt'].split('\n\n')[-1].replace(' = ?', '')
            status = "βœ…" if is_correct else "❌"
            print(f"[{i+1}] {status} {problem} = {truth} | Predicted: {predicted} | Response: {response[:50]}...")

    # Guard against n_samples == 0 to avoid ZeroDivisionError.
    accuracy = correct / n_samples * 100 if n_samples else 0.0
    print(f"\nπŸ“Š Base Model Accuracy: {accuracy:.1f}% ({correct}/{n_samples})")

    if accuracy > 90:
        print("⚠️ WARNING: Base model already performs well! Task may be too easy.")
    elif accuracy < 50:
        print("βœ… Good! Base model performs poorly. Room for improvement!")

    print("="*70 + "\n")

    return accuracy
139
+
140
+ # ============================================================================
141
+ # MAIN TRAINING
142
+ # ============================================================================
143
+
144
def main():
    """Run the full pipeline: load model, baseline-test, GRPO-train, push to Hub."""
    print("="*70)
    print("πŸ”’ GRPO + RLVR Arithmetic Training")
    print("="*70)
    print(f"Base Model: {BASE_MODEL}")
    print(f"Output: {OUTPUT_MODEL}")
    print(f"Steps: {MAX_STEPS}")
    print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")
    print("="*70 + "\n")

    # Load model and tokenizer
    print("πŸ“¦ Loading model and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto" if torch.cuda.is_available() else None
    )

    # Baseline accuracy before any RL training.
    baseline_accuracy = test_base_model(model, tokenizer, n_samples=EVAL_SAMPLES)

    # Generate training data
    print("πŸ“Š Generating training data...")
    train_samples = generate_arithmetic_samples(NUM_SAMPLES)
    train_dataset = Dataset.from_list(train_samples)
    print(f"βœ… {len(train_dataset)} training samples\n")

    # GRPO Config
    training_args = GRPOConfig(
        output_dir="./outputs",
        max_steps=MAX_STEPS,
        per_device_train_batch_size=4,
        num_generations=4,  # completions sampled per prompt for the group baseline
        learning_rate=2e-4,
        beta=0.0,  # No KL penalty for this task
        bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
        fp16=False,
        gradient_checkpointing=True,
        optim="adamw_8bit",
        logging_steps=1,
        save_steps=MAX_STEPS,  # Save at end
        push_to_hub=False,  # We'll push manually
        report_to="none",
    )

    print("πŸš€ Starting GRPO Training...")
    print(f"Baseline accuracy: {baseline_accuracy:.1f}%\n")

    # Train. NOTE: the TRL parameter is `reward_funcs` (callable or list of
    # callables) — the original `reward_func=` keyword is not part of the
    # GRPOTrainer signature and raises TypeError.
    trainer = GRPOTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        reward_funcs=reward_func,
    )

    trainer.train()

    print("\nβœ… Training complete!")

    # Save to Hub
    print(f"\nπŸ“¦ Pushing to Hub: {OUTPUT_MODEL}")
    trainer.model.push_to_hub(OUTPUT_MODEL)
    tokenizer.push_to_hub(OUTPUT_MODEL)

    print(f"βœ… Model pushed to: https://huggingface.co/{OUTPUT_MODEL}")
    print("="*70)
212
+
213
# Script entry point: run training only when executed directly.
if __name__ == "__main__":
    main()