mindchain commited on
Commit
8b46b16
·
verified ·
1 Parent(s): 5977b18

Upload train_arithmetic_v7_clean.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_arithmetic_v7_clean.py +236 -0
train_arithmetic_v7_clean.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
GRPO + RLVR Training Script v7 - Clean & Simple

Just the essentials:
- 4-bit Quantization (BitsAndBytes)
- LoRA Adapters (QLoRA)
- Standard PyTorch training

No IPEX, no OpenVINO, no torch.compile - just reliable training.
"""

import os
import random
import re
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, TaskType
from trl import GRPOConfig, GRPOTrainer
from datasets import Dataset

# ============================================================================
# CONFIG
# ============================================================================

BASE_MODEL = "Qwen/Qwen3-0.6B-Base"    # HF Hub id of the base model to fine-tune
OUTPUT_MODEL = "mindchain/qwen3-0.6b-arithmetic-v7"  # Hub repo the adapters are saved/pushed to
MAX_STEPS = 50          # hard cap on training steps (see GRPOConfig.max_steps)
NUM_SAMPLES = 500       # number of synthetic arithmetic problems generated
BATCH_SIZE = 4          # per-device train batch size
NUM_GENERATIONS = 4     # completions sampled per prompt by GRPO

# LoRA Config
LORA_R = 16             # adapter rank
LORA_ALPHA = 32         # LoRA scaling factor
LORA_DROPOUT = 0.05     # dropout applied inside the adapter layers

# ============================================================================
# DATA GENERATION
# ============================================================================
def generate_arithmetic_samples(n_samples):
    """Build *n_samples* addition/subtraction problems.

    Each sample is a dict with a "prompt" like ``"Calculate: 12 + 7 = "``
    and an "answer" holding the result as a string. Subtraction operands
    are chosen so the result is never negative.
    """
    def _one_problem():
        op = random.choice(['+', '-'])
        if op == '+':
            lhs = random.randint(1, 50)
            rhs = random.randint(1, 50)
            result = lhs + rhs
        else:
            lhs = random.randint(10, 100)
            rhs = random.randint(1, lhs)  # rhs <= lhs keeps the answer >= 0
            result = lhs - rhs
        return {
            "prompt": f"Calculate: {lhs} {op} {rhs} = ",
            "answer": str(result),
        }

    return [_one_problem() for _ in range(n_samples)]
66
+
67
+ # ============================================================================
68
+ # REWARD FUNCTION
69
+ # ============================================================================
70
+
71
def extract_number(text):
    """Return the answer number found in *text* as a string, or None.

    Preference order:
      1. a number inside a LaTeX ``$$...$$`` block,
      2. a number following "Answer:" (case-insensitive),
      3. the last bare number anywhere in the text.
    """
    preferred_patterns = (
        (r'\$\$(\d+(?:\.\d+)?)\$\$', 0),
        (r'Answer:\s*(\d+(?:\.\d+)?)', re.IGNORECASE),
    )
    for pattern, flags in preferred_patterns:
        match = re.search(pattern, text, flags)
        if match:
            return match.group(1)

    # Fallback: last number appearing anywhere in the text.
    all_numbers = re.findall(r'\d+(?:\.\d+)?', text)
    return all_numbers[-1] if all_numbers else None
89
+
90
def reward_func(completions, prompts, **kwargs):
    """Binary reward for arithmetic completions.

    Looks up the ground truth under the 'ground_truth', 'answer', or
    'solution' keyword (first one present). Returns 1.0 per completion
    whose extracted number matches the truth exactly (string compare),
    else 0.0. If no truth column is available, every reward is 0.0.
    """
    truths = kwargs.get('ground_truth', kwargs.get('answer', kwargs.get('solution', None)))
    if truths is None:
        return [0.0] * len(completions)

    def _as_text(completion):
        # Conversational completions arrive as a list of message dicts.
        if isinstance(completion, list):
            return " ".join(
                part.get('content', '') if isinstance(part, dict) else str(part)
                for part in completion
            )
        return str(completion)

    scores = []
    for completion, expected in zip(completions, truths):
        predicted = extract_number(_as_text(completion))
        matched = predicted is not None and str(predicted) == str(expected)
        scores.append(1.0 if matched else 0.0)
    return scores
115
+
116
+ # ============================================================================
117
+ # MAIN
118
+ # ============================================================================
119
+
120
def main():
    """Run the full pipeline: load the 4-bit quantized base model, attach
    LoRA adapters, GRPO-train on generated arithmetic data, then save the
    adapters locally and push them to the Hugging Face Hub."""
    print("=" * 70)
    print("πŸš€ GRPO + RLVR v7 - Clean & Simple")
    print("=" * 70)
    print(f"Base Model: {BASE_MODEL}")
    print(f"Output: {OUTPUT_MODEL}")
    print(f"Steps: {MAX_STEPS}")
    print("=" * 70)

    # Check CPU threads (informational only; nothing is pinned here)
    print(f"\nπŸ“Š CPU Threads: {os.cpu_count()}")

    # Show which optimizations are (deliberately not) enabled
    print("\nπŸ“Š Optimizations:")
    print(f" 4-bit Quantization: βœ…")
    print(f" LoRA Adapters: βœ… (R={LORA_R})")
    print(f" IPEX: ❌ (skipped for stability)")
    print(f" OpenVINO: ❌ (skipped for stability)")
    print(f" torch.compile: ❌ (skipped for stability)")
    print("=" * 70)

    # Load tokenizer; reuse EOS as pad since the base model defines no pad token
    print("\nπŸ“¦ Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    tokenizer.pad_token = tokenizer.eos_token

    # Load model with NF4 double-quantized 4-bit weights (QLoRA setup)
    print("\nπŸ“¦ Loading model with 4-bit quantization...")
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )

    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=quantization_config,
        device_map="auto",
        trust_remote_code=True,
    )

    # Add LoRA adapters on all attention and MLP projections
    print("\nπŸ“¦ Adding LoRA adapters...")
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        bias="none",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # Generate synthetic training data
    print(f"\nπŸ“Š Generating {NUM_SAMPLES} training samples...")
    samples = generate_arithmetic_samples(NUM_SAMPLES)

    # Create dataset; "ground_truth" is the column reward_func reads from kwargs
    dataset = Dataset.from_list([
        {
            "prompt": s["prompt"],
            "ground_truth": s["answer"],
        }
        for s in samples
    ])

    # GRPO Config
    training_args = GRPOConfig(
        output_dir="./results",
        num_train_epochs=1,
        max_steps=MAX_STEPS,  # max_steps takes precedence over num_train_epochs
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=2,
        num_generations=NUM_GENERATIONS,
        learning_rate=5e-5,
        bf16=False,  # CPU doesn't support BF16
        fp16=False,  # 4-bit quantization is enough
        gradient_checkpointing=True,
        optim="paged_adamw_8bit",
        logging_steps=1,
        save_steps=25,
        save_total_limit=2,
        report_to="none",
        remove_unused_columns=False,  # keep "ground_truth" so it reaches the reward fn
    )

    # Create trainer
    print("\nπŸ“¦ Creating GRPO trainer...")
    trainer = GRPOTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
        reward_funcs=[reward_func],
    )

    # Train
    print("\nπŸš€ Starting GRPO Training...")
    trainer.train()

    # Save LoRA adapters (adapter weights only, not the merged model)
    print("\nπŸ“¦ Saving LoRA adapters...")
    model.save_pretrained(OUTPUT_MODEL)
    tokenizer.save_pretrained(OUTPUT_MODEL)

    # Push to Hub; requires HF_TOKEN in the environment (None otherwise)
    print(f"\nπŸ“¦ Pushing to Hub: {OUTPUT_MODEL}")
    model.push_to_hub(OUTPUT_MODEL, token=os.environ.get("HF_TOKEN"))
    tokenizer.push_to_hub(OUTPUT_MODEL, token=os.environ.get("HF_TOKEN"))

    print("\nβœ… Training complete!")
    print(f"Output: {OUTPUT_MODEL}")

if __name__ == "__main__":
    main()