mindchain committed on
Commit
6cbcf96
Β·
verified Β·
1 Parent(s): 786e916

Upload train_arithmetic_v5_ultimate.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_arithmetic_v5_ultimate.py +320 -0
train_arithmetic_v5_ultimate.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ GRPO + RLVR Training - v5 (Ultimate CPU Optimized + Quantized)
4
+ Optimized for HF Spaces CPU with 4-bit quantization
5
+
6
+ Features:
7
+ - 4-bit Quantization (BitsAndBytes) - faster inference
8
+ - LoRA Adapters (QLoRA) - efficient training
9
+ - Intel Extension for PyTorch (IPEX) - CPU optimization
10
+ - torch.compile() JIT compilation
11
+ - BetterTransformer (optimized attention)
12
+ - LaTeX-aware answer extraction
13
+ - All optimizations combined!
14
+ """
15
+
16
+ import os
17
+ import re
18
+ import random
19
+ import torch
20
+ from datasets import Dataset
21
+ from transformers import (
22
+ AutoModelForCausalLM,
23
+ AutoTokenizer,
24
+ BitsAndBytesConfig,
25
+ )
26
+ from peft import LoraConfig, get_peft_model, TaskType
27
+ from trl import GRPOConfig, GRPOTrainer
28
+
29
# ============================================================================
# OPTIMIZATION FLAGS
# ============================================================================

# Feature toggles. The try/except probes below flip USE_IPEX and
# USE_BETTER_TRANSFORMER to True only when the optional dependency
# actually imports in this environment.
USE_IPEX = False
USE_COMPILE = hasattr(torch, 'compile')  # torch.compile is available from PyTorch 2.0+
USE_BETTER_TRANSFORMER = False
USE_QUANTIZATION = True # Enable 4-bit quantization

# Probe for Intel Extension for PyTorch (CPU kernel optimizations).
try:
    import intel_extension_for_pytorch as ipex
    USE_IPEX = True
    print("βœ… IPEX available")
except ImportError:
    print("⚠️ IPEX not available")

# Probe for optimum's BetterTransformer (optimized attention path).
try:
    from optimum.bettertransformer import BetterTransformer
    USE_BETTER_TRANSFORMER = True
    print("βœ… BetterTransformer available")
except ImportError:
    print("⚠️ BetterTransformer not available")
51
+
52
# ============================================================================
# CONFIG
# ============================================================================

# Hugging Face Hub repo ids: base checkpoint to fine-tune and the
# destination repo for the trained LoRA adapters.
BASE_MODEL = "Qwen/Qwen3-0.6B-Base"
OUTPUT_MODEL = "mindchain/qwen3-0.6b-arithmetic-v5-quantized"
# Training scale — deliberately small for a CPU-only HF Spaces run.
MAX_STEPS = 50
NUM_SAMPLES = 500
BATCH_SIZE = 4 # Larger batch with quantization
NUM_GENERATIONS = 4 # More generations

# LoRA Config
LORA_R = 16  # adapter rank
LORA_ALPHA = 32  # adapter scaling factor
LORA_DROPOUT = 0.05

# Quantization Config
USE_4BIT = True # Use 4-bit quantization
71
+ # ============================================================================
72
+ # DATA GENERATION
73
+ # ============================================================================
74
+
75
def generate_arithmetic_samples(n_samples):
    """Build a list of two-digit addition/subtraction training samples.

    Each sample is a dict with:
      'prompt' — "Solve: A op B = ?\nAnswer:"
      'answer' — the correct result as a string.
    Subtraction operands are chosen so the result is always positive.
    """
    dataset = []
    for _ in range(n_samples):
        operator = random.choice(['+', '-'])
        if operator == '+':
            lhs = random.randint(10, 99)
            rhs = random.randint(10, 99)
            result = lhs + rhs
        else:
            # Pick rhs strictly below lhs so the difference stays positive.
            lhs = random.randint(20, 99)
            rhs = random.randint(10, lhs - 1)
            result = lhs - rhs
        question = f"{lhs} {operator} {rhs} = ?"
        dataset.append({
            'prompt': f"Solve: {question}\nAnswer:",
            'answer': str(result),
        })
    return dataset
98
+
99
+ # ============================================================================
100
+ # REWARD FUNCTION (LaTeX-aware)
101
+ # ============================================================================
102
+
103
def extract_answer(text):
    """
    Extract the final numeric answer from model output.

    Priority:
    1. Last number inside the last $$...$$ LaTeX block
    2. Number after an "Answer:" pattern (case-insensitive)
    3. Last standalone number in the text (fallback)

    Returns "" when no number is found.
    """
    # Bug fix: the old pattern r'-?\d+\.?\d*' captured a bare trailing dot
    # (e.g. "Answer: 5." -> "5."), which then failed the exact string
    # comparison against integer ground truths in reward_func. Decimals
    # now require at least one digit after the point.
    number_pattern = r'-?\d+(?:\.\d+)?'

    # 1) Prefer the last $$...$$ block (LaTeX-formatted answers).
    latex_blocks = re.findall(r'\$\$(.*?)\$\$', text, re.DOTALL)
    if latex_blocks:
        numbers = re.findall(number_pattern, latex_blocks[-1])
        if numbers:
            return numbers[-1]

    # 2) Then an explicit "Answer: <number>" marker.
    answer_match = re.search(r'Answer:\s*(' + number_pattern + r')', text, re.IGNORECASE)
    if answer_match:
        return answer_match.group(1)

    # 3) Fallback: last number anywhere in the text.
    numbers = re.findall(number_pattern, text)
    if numbers:
        return numbers[-1]

    return ""
130
+
131
+
132
def reward_func(completions, prompts=None, **kwargs):
    """Binary reward per completion: 1.0 if the extracted numeric answer
    matches the ground truth passed via kwargs, else 0.0."""
    # The trainer may forward the label column under any of these names.
    answers = None
    for candidate in ('answer', 'ground_truth', 'solution', 'label'):
        value = kwargs.get(candidate)
        if value is not None:
            answers = value
            break

    if answers is None:
        # No ground truth available — zero reward for every completion.
        return [0.0] * len(completions)

    rewards = []
    for idx, (completion, truth) in enumerate(zip(completions, answers)):
        # Chat-style completions arrive as a list of message dicts; flatten
        # their 'content' fields into a single string.
        if isinstance(completion, list):
            parts = []
            for message in completion:
                if isinstance(message, dict):
                    parts.append(message.get('content', ''))
                else:
                    parts.append(str(message))
            text = " ".join(parts)
        else:
            text = str(completion)

        predicted = extract_answer(text)
        is_correct = predicted == str(truth).strip()
        rewards.append(1.0 if is_correct else 0.0)

        # Log only the first two comparisons to keep training output short.
        if idx < 2:
            status = "βœ…" if is_correct else "❌"
            print(f" [{idx+1}] {status} Truth={truth} | Pred={predicted}")

    return rewards
159
+
160
+ # ============================================================================
161
+ # MAIN TRAINING
162
+ # ============================================================================
163
+
164
def main():
    """Run the full GRPO training pipeline: load a (quantized) base model,
    attach LoRA adapters, apply the available CPU optimizations, train on
    generated arithmetic data, then save and push the adapters to the Hub.
    """
    # Banner with the key run parameters.
    print("="*70)
    print("πŸš€ GRPO + RLVR v5 - Ultimate CPU Optimized + Quantized")
    print("="*70)
    print(f"Base Model: {BASE_MODEL}")
    print(f"Output: {OUTPUT_MODEL}")
    print(f"Steps: {MAX_STEPS}")
    print("="*70)

    # Print optimization status
    print("\nπŸ“Š Optimizations:")
    print(f" 4-bit Quantization: {'βœ…' if USE_4BIT else '❌'}")
    print(f" LoRA Adapters: βœ… (R={LORA_R})")
    print(f" IPEX: {'βœ…' if USE_IPEX else '❌'}")
    print(f" torch.compile: {'βœ…' if USE_COMPILE else '❌'}")
    print(f" BetterTransformer: {'βœ…' if USE_BETTER_TRANSFORMER else '❌'}")
    print("="*70 + "\n")

    # CPU optimization: use every available core for intra-op parallelism.
    torch.set_num_threads(os.cpu_count() or 4)
    print(f"πŸ“Š CPU Threads: {torch.get_num_threads()}\n")

    # Load tokenizer; causal LMs often ship without a pad token, so reuse EOS.
    print("πŸ“¦ Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Quantization config: try NF4 4-bit first, fall back to FP32 on failure.
    if USE_4BIT:
        print("\nπŸ“¦ Loading model with 4-bit quantization...")
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float32, # CPU uses float32
            bnb_4bit_use_double_quant=True,
        )

        try:
            model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL,
                quantization_config=quantization_config,
                device_map="auto",
            )
            print(" Model loaded in 4-bit!")
        except Exception as e:
            # bitsandbytes 4-bit may be unsupported on this CPU/platform —
            # degrade gracefully to a full-precision load.
            print(f" ⚠️ 4-bit failed: {e}")
            print(" Falling back to FP32...")
            model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL,
                torch_dtype=torch.float32,
            )
    else:
        print("\nπŸ“¦ Loading model in FP32...")
        model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.float32,
        )

    # Add LoRA adapters on the attention projections so only a small set of
    # parameters is trained (QLoRA when combined with the 4-bit base).
    print("\nπŸ”§ Adding LoRA adapters...")
    lora_config = LoraConfig(
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=LORA_DROPOUT,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # Apply IPEX (best-effort: failure leaves the model unchanged).
    if USE_IPEX:
        print("\nπŸ”§ Applying IPEX...")
        try:
            # Note: IPEX with PEFT models may need special handling
            model = ipex.optimize(model, dtype=torch.float32)
            print(" IPEX applied!")
        except Exception as e:
            print(f" ⚠️ IPEX failed: {e}")

    # Apply BetterTransformer (best-effort, same pattern as IPEX above).
    if USE_BETTER_TRANSFORMER:
        print("\nπŸ”§ Applying BetterTransformer...")
        try:
            model = BetterTransformer.transform(model)
            print(" BetterTransformer applied!")
        except Exception as e:
            print(f" ⚠️ BetterTransformer failed: {e}")

    # Generate training data (synthetic arithmetic prompts + answers).
    print("\nπŸ“Š Generating training data...")
    train_samples = generate_arithmetic_samples(NUM_SAMPLES)
    train_dataset = Dataset.from_list(train_samples)
    print(f"βœ… {len(train_dataset)} training samples\n")

    # GRPO Config — CPU-friendly: no mixed precision, no pinned memory,
    # no dataloader workers; beta=0.0 disables the KL penalty term.
    training_args = GRPOConfig(
        output_dir="./outputs",
        max_steps=MAX_STEPS,
        per_device_train_batch_size=BATCH_SIZE,
        num_generations=NUM_GENERATIONS,
        learning_rate=2e-4,
        beta=0.0,
        bf16=False,
        fp16=False,
        gradient_checkpointing=False,
        optim="adamw_torch",
        logging_steps=1,
        save_steps=MAX_STEPS,
        push_to_hub=False,
        report_to="none",
        dataloader_num_workers=0,
        dataloader_pin_memory=False,
    )

    print("πŸš€ Starting GRPO Training...")
    print("="*70 + "\n")

    # Create trainer; reward_func scores each sampled completion.
    trainer = GRPOTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        reward_funcs=[reward_func],
    )

    # Apply torch.compile after the trainer wraps the model (best-effort).
    if USE_COMPILE:
        print("πŸ”§ Applying torch.compile()...")
        try:
            trainer.model = torch.compile(trainer.model)
            print(" torch.compile() applied!\n")
        except Exception as e:
            print(f" ⚠️ torch.compile() failed: {e}\n")

    # Train
    trainer.train()

    print("\n" + "="*70)
    print("βœ… Training complete!")
    print("="*70)

    # Save LoRA adapters (only the adapter weights, not the base model).
    print(f"\nπŸ“¦ Saving LoRA adapters to: {OUTPUT_MODEL}")
    model.save_pretrained(OUTPUT_MODEL)
    tokenizer.save_pretrained(OUTPUT_MODEL)

    # Push to Hub — requires valid HF credentials in the environment.
    print(f"\nπŸ“¦ Pushing to Hub: {OUTPUT_MODEL}")
    model.push_to_hub(OUTPUT_MODEL)
    tokenizer.push_to_hub(OUTPUT_MODEL)
    print(f"βœ… Model pushed to: https://huggingface.co/{OUTPUT_MODEL}")

if __name__ == "__main__":
    main()