mindchain committed on
Commit
695d68f
·
verified ·
1 Parent(s): 440ca06

Upload train_arithmetic_v9_no_lora.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_arithmetic_v9_no_lora.py +223 -0
train_arithmetic_v9_no_lora.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ GRPO + RLVR Training Script v9 - NO LoRA
4
+
5
+ Test to isolate if LoRA is causing the stuck issue:
6
+ - 4-bit Quantization: YES
7
+ - LoRA: NO (testing without)
8
+ - Just basic GRPO training
9
+ """
10
+
11
+ import os
12
+ import random
13
+ import re
14
+ import torch
15
+ from transformers import (
16
+ AutoModelForCausalLM,
17
+ AutoTokenizer,
18
+ BitsAndBytesConfig,
19
+ )
20
+ from trl import GRPOConfig, GRPOTrainer
21
+ from datasets import Dataset
22
+
23
# ============================================================================
# CONFIG
# ============================================================================

BASE_MODEL = "Qwen/Qwen3-0.6B-Base"  # HF Hub id of the base model to fine-tune
OUTPUT_MODEL = "mindchain/qwen3-0.6b-arithmetic-v9"  # save dir / Hub repo id for the result
MAX_STEPS = 50  # hard cap on training steps passed to GRPOConfig
NUM_SAMPLES = 500  # number of synthetic arithmetic prompts to generate
BATCH_SIZE = 2  # Smaller batch without LoRA
NUM_GENERATIONS = 2  # Fewer generations
33
+
34
+ # ============================================================================
35
+ # DATA GENERATION
36
+ # ============================================================================
37
+
38
def generate_arithmetic_samples(n_samples):
    """Build a list of simple addition/subtraction problems.

    Each entry is a dict with a "prompt" string ("Calculate: a op b = ")
    and the expected "answer" as a decimal string.
    """

    def _one_sample():
        # Pick the operator first, then draw operands; the subtraction
        # ranges guarantee a non-negative result.
        operator = random.choice(['+', '-'])
        if operator == '+':
            lhs = random.randint(1, 50)
            rhs = random.randint(1, 50)
            result = lhs + rhs
        else:
            lhs = random.randint(10, 100)
            rhs = random.randint(1, lhs)  # rhs <= lhs, so result >= 0
            result = lhs - rhs
        return {
            "prompt": f"Calculate: {lhs} {operator} {rhs} = ",
            "answer": str(result),
        }

    return [_one_sample() for _ in range(n_samples)]
58
+
59
+ # ============================================================================
60
+ # REWARD FUNCTION
61
+ # ============================================================================
62
+
63
def extract_number(text):
    """Extract the most likely numeric answer from generated text.

    Tries, in priority order:
      1. a number wrapped in a LaTeX $$...$$ block,
      2. a number following "Answer:" (case-insensitive),
      3. the last number appearing anywhere in the text.

    Returns the number as a string, or None if no number is found.
    """
    # Priority 1: Numbers in $$...$$ blocks (LaTeX)
    latex_match = re.search(r'\$\$(\d+(?:\.\d+)?)\$\$', text)
    if latex_match:
        return latex_match.group(1)

    # Priority 2: Numbers after "Answer:"
    answer_match = re.search(r'Answer:\s*(\d+(?:\.\d+)?)', text, re.IGNORECASE)
    if answer_match:
        return answer_match.group(1)

    # Priority 3: Last number in text
    numbers = re.findall(r'\d+(?:\.\d+)?', text)
    if numbers:
        return numbers[-1]

    return None

def reward_func(completions, prompts, **kwargs):
    """Binary RLVR reward: 1.0 when the number extracted from a completion
    equals the ground truth, else 0.0.

    Args:
        completions: list of generated texts, or lists of message dicts
            for conversational format.
        prompts: list of prompts (unused, kept for the trainer's signature).
        **kwargs: dataset columns; the label is looked up under
            'ground_truth', 'answer', or 'solution' (in that order).

    Returns:
        list[float] of per-completion rewards, same length as completions.
    """
    ground_truth = kwargs.get('ground_truth', kwargs.get('answer', kwargs.get('solution', None)))
    if ground_truth is None:
        # No labels available — return neutral rewards rather than crashing.
        return [0.0] * len(completions)

    rewards = []
    for completion, truth in zip(completions, ground_truth):
        # Conversational completions arrive as a list of message dicts;
        # flatten them into a single text blob before extraction.
        if isinstance(completion, list):
            text = " ".join([m.get('content', '') if isinstance(m, dict) else str(m) for m in completion])
        else:
            text = str(completion)

        predicted = extract_number(text)

        # FIX: compare numerically when possible so equivalent forms such as
        # "25.0" vs "25" count as correct; fall back to string equality for
        # anything that doesn't parse as a float.
        correct = False
        if predicted is not None:
            try:
                correct = float(predicted) == float(truth)
            except (TypeError, ValueError):
                correct = str(predicted) == str(truth)
        rewards.append(1.0 if correct else 0.0)

    return rewards
107
+
108
+ # ============================================================================
109
+ # MAIN
110
+ # ============================================================================
111
+
112
def _print_banner():
    """Show the run configuration before any heavy work starts."""
    print("=" * 70)
    print("🚀 GRPO + RLVR v9 - NO LoRA (Testing)")
    print("=" * 70)
    print(f"Base Model: {BASE_MODEL}")
    print(f"Output: {OUTPUT_MODEL}")
    print(f"Steps: {MAX_STEPS}")
    print("=" * 70)

    # Check CPU threads
    print(f"\n📊 CPU Threads: {os.cpu_count()}")

    # Show configuration
    print("\n📊 Configuration:")
    print(" 4-bit Quantization: ✅")
    print(" LoRA Adapters: ❌ (DISABLED FOR TESTING)")
    print(f" Batch Size: {BATCH_SIZE}")
    print(f" Generations: {NUM_GENERATIONS}")
    print("=" * 70)


def _load_tokenizer():
    """Load the tokenizer; reuse EOS as pad so batched generation works."""
    print("\n📦 Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    tokenizer.pad_token = tokenizer.eos_token
    print("✅ Tokenizer loaded!")
    return tokenizer


def _load_model():
    """Load the base model with NF4 4-bit quantization (no LoRA in v9)."""
    print("\n📦 Loading model with 4-bit quantization...")
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=quantization_config,
        device_map="auto",
        trust_remote_code=True,
    )
    print("✅ Model loaded!")
    return model


def _build_dataset():
    """Generate synthetic arithmetic problems and wrap them in a Dataset."""
    print(f"\n📊 Generating {NUM_SAMPLES} training samples...")
    samples = generate_arithmetic_samples(NUM_SAMPLES)
    dataset = Dataset.from_list([
        {
            "prompt": s["prompt"],
            "ground_truth": s["answer"],
        }
        for s in samples
    ])
    print("✅ Training data generated!")
    return dataset


def _build_trainer(model, tokenizer, dataset):
    """Configure GRPO (CPU-friendly settings) and construct the trainer."""
    training_args = GRPOConfig(
        output_dir="./results",
        num_train_epochs=1,
        max_steps=MAX_STEPS,
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=2,
        num_generations=NUM_GENERATIONS,
        learning_rate=5e-5,
        bf16=False,  # CPU doesn't support BF16
        fp16=False,  # 4-bit quantization is enough
        gradient_checkpointing=True,
        optim="paged_adamw_8bit",
        logging_steps=1,
        save_steps=25,
        save_total_limit=2,
        report_to="none",
        # Keep the "ground_truth" column so reward_func receives it via kwargs.
        remove_unused_columns=False,
    )
    print("✅ GRPO config created!")

    print("\n📦 Creating GRPO trainer...")
    trainer = GRPOTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
        reward_funcs=[reward_func],
    )
    print("✅ GRPO trainer created!")
    return trainer


def _save_and_push(model, tokenizer):
    """Persist the trained model locally and push it to the HF Hub."""
    print("\n📦 Saving model...")
    model.save_pretrained(OUTPUT_MODEL)
    tokenizer.save_pretrained(OUTPUT_MODEL)

    print(f"\n📦 Pushing to Hub: {OUTPUT_MODEL}")
    # NOTE(review): push fails if HF_TOKEN is unset — presumably this runs in
    # an environment where it is always exported; confirm.
    hub_token = os.environ.get("HF_TOKEN")
    model.push_to_hub(OUTPUT_MODEL, token=hub_token)
    tokenizer.push_to_hub(OUTPUT_MODEL, token=hub_token)


def main():
    """End-to-end GRPO run: load model/tokenizer, build data, train, save, push."""
    _print_banner()

    tokenizer = _load_tokenizer()
    model = _load_model()

    # NO LoRA - skip this step entirely (v9 isolates whether LoRA causes the hang)
    print("\n📦 Skipping LoRA (v9 test)...")
    print("✅ No LoRA adapters to add!")

    dataset = _build_dataset()
    trainer = _build_trainer(model, tokenizer, dataset)

    print("\n🚀 Starting GRPO Training...")
    trainer.train()

    _save_and_push(model, tokenizer)

    print("\n✅ Training complete!")
    print(f"Output: {OUTPUT_MODEL}")
221
+
222
# Script entry point: run the full training pipeline when executed directly.
if __name__ == "__main__":
    main()