File size: 7,687 Bytes
306c5f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95008ad
306c5f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0168a3e
306c5f0
 
0168a3e
306c5f0
 
 
0168a3e
 
 
 
 
 
306c5f0
 
 
 
 
 
0168a3e
 
 
 
 
 
 
306c5f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0168a3e
306c5f0
 
 
95008ad
 
306c5f0
 
 
 
0168a3e
 
306c5f0
 
 
 
 
 
 
 
 
 
 
 
 
 
61cc0c7
306c5f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
#!/usr/bin/env python3
"""
GRPO + RLVR Training for Simple Arithmetic
Task: 2-digit addition and subtraction
Base Model: Qwen/Qwen3-0.6B-Base
"""

import os
import re
import random
import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer

# ============================================================================
# CONFIG
# ============================================================================

# Hub id passed to AutoTokenizer / AutoModelForCausalLM.from_pretrained in main().
BASE_MODEL = "Qwen/Qwen3-0.6B-Base"
# Hub repo id that main() pushes the trained model and tokenizer to.
OUTPUT_MODEL = "mindchain/qwen3-0.6b-arithmetic"
# Used in main() both as GRPOConfig.max_steps and as save_steps (one save at the end).
MAX_STEPS = 20  # Reduced for CPU testing
# Number of problems generated for the training dataset.
NUM_SAMPLES = 500  # Training samples
# Number of problems generated for the pre-training baseline check.
EVAL_SAMPLES = 20   # For baseline test

# ============================================================================
# DATA GENERATION
# ============================================================================

def generate_arithmetic_samples(n_samples):
    """Build ``n_samples`` random two-digit addition/subtraction problems.

    Each returned item is a dict with a ``'prompt'`` (instruction plus the
    problem text) and an ``'answer'`` (the expected result as a string).
    Additions draw both operands from 10..99; subtractions draw the left
    operand from 20..99 and the right one from 10..left-1, so the result is
    always positive.
    """

    def build_sample():
        # Draw order (operation, left operand, right operand) is kept fixed so
        # a seeded ``random`` yields a reproducible sequence of problems.
        symbol = random.choice(['+', '-'])

        if symbol == '+':
            left = random.randint(10, 99)
            right = random.randint(10, 99)
            result = left + right
        else:
            left = random.randint(20, 99)
            right = random.randint(10, left - 1)  # strictly smaller than left
            result = left - right

        problem = f"{left} {symbol} {right} = ?"
        return {
            'prompt': f"Solve this arithmetic problem. Give only the answer as a number.\n\n{problem}",
            'answer': str(result),
        }

    return [build_sample() for _ in range(n_samples)]

# ============================================================================
# REWARD FUNCTION
# ============================================================================

def reward_func(completions, prompts, **kwargs):
    """
    Reward function for arithmetic.

    Extracts the last number that appears in each completion and compares it,
    as a string, with the corresponding ground-truth answer.

    Args:
        completions: One entry per generation. Either plain text, or a list of
            messages (dicts with a ``'content'`` key, or arbitrary objects that
            are stringified) for the conversational format.
        prompts: Accepted for interface compatibility; not used here.
        **kwargs: Ground truth is read from ``answer`` or, if that is absent,
            from ``ground_truth``; one value per completion.

    Returns:
        A list of floats: 1.0 for an exact match, 0.0 otherwise. When no
        ground truth is supplied, every completion gets 0.0.
    """
    answers = kwargs.get('answer', kwargs.get('ground_truth', None))
    if answers is None:
        return [0.0] * len(completions)

    # The fractional part is only consumed when digits follow the dot, so
    # sentence punctuation ("The answer is 42.") is not swallowed into the
    # number and a correct answer is not scored as wrong.
    number_pattern = re.compile(r'-?\d+(?:\.\d+)?')

    rewards = []
    for completion, truth in zip(completions, answers):
        # Handle list format (conversational)
        if isinstance(completion, list):
            # ``content`` may be missing or None; treat both as empty text
            # instead of letting str.join fail on a non-string.
            text = " ".join(
                str(m.get('content') or '') if isinstance(m, dict) else str(m)
                for m in completion
            )
        else:
            text = str(completion)

        # Extract the last number
        numbers = number_pattern.findall(text)
        predicted = numbers[-1] if numbers else ""

        # Exact match reward
        rewards.append(1.0 if predicted == str(truth).strip() else 0.0)

    return rewards

# ============================================================================
# BASELINE TEST
# ============================================================================

def test_base_model(model, tokenizer, n_samples=20):
    """Test base model performance before training.

    Generates ``n_samples`` problems with ``generate_arithmetic_samples``,
    decodes up to 20 new tokens per prompt without sampling, takes the last
    number in each response as the prediction, and prints one line per
    problem followed by a summary.

    Args:
        model: Object exposing ``eval()`` and ``generate(**inputs, ...)``. If
            it has a non-None ``device`` attribute, the tokenized inputs are
            moved to that device first.
        tokenizer: Callable that tokenizes a prompt (called with
            ``return_tensors='pt'``) and exposes ``decode``.
        n_samples: Number of problems to evaluate.

    Returns:
        Accuracy as a percentage (0-100). Returns 0.0 when no samples were
        evaluated instead of dividing by zero.
    """
    # The fractional part is only consumed when digits follow the dot, so a
    # sentence-ending period ("... is 42.") is not captured as part of the
    # number and does not turn a correct answer into a mismatch.
    number_pattern = re.compile(r'-?\d+(?:\.\d+)?')

    print("\n" + "="*70)
    print("📊 TESTING BASE MODEL PERFORMANCE")
    print("="*70)

    test_samples = generate_arithmetic_samples(n_samples)
    total = len(test_samples)
    correct = 0

    model.eval()
    with torch.no_grad():
        for i, sample in enumerate(test_samples, start=1):
            inputs = tokenizer(sample['prompt'], return_tensors='pt')

            # Handle device placement
            device = getattr(model, 'device', None)
            if device is not None:
                inputs = {k: v.to(device) for k, v in inputs.items()}

            outputs = model.generate(
                **inputs,
                max_new_tokens=20,
                do_sample=False,
                temperature=1.0
            )

            # Decode only the newly generated tokens when the prompt length is
            # known; otherwise fall back to decoding the whole sequence.
            input_ids = inputs.get('input_ids')
            if input_ids is not None and hasattr(input_ids, 'shape'):
                response = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
            else:
                response = tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract answer
            numbers = number_pattern.findall(response)
            predicted = numbers[-1] if numbers else ""
            truth = sample['answer'].strip()

            is_correct = predicted == truth
            if is_correct:
                correct += 1

            # The problem is the last line of the prompt; drop the trailing
            # "= ?" so the log shows the full expression (e.g. "12 + 34").
            problem = sample['prompt'].rsplit('\n', 1)[-1].split('= ?')[0].strip()
            status = "✅" if is_correct else "❌"
            print(f"[{i}] {status} {problem} = {truth} | Predicted: {predicted} | Response: {response[:50]}...")

    if total == 0:
        # Nothing was evaluated, so there is no accuracy to judge.
        print("\n📊 Base Model Accuracy: n/a (no samples evaluated)")
        print("="*70 + "\n")
        return 0.0

    accuracy = correct / total * 100
    print(f"\n📊 Base Model Accuracy: {accuracy:.1f}% ({correct}/{total})")

    if accuracy > 90:
        print("⚠️  WARNING: Base model already performs well! Task may be too easy.")
    elif accuracy < 50:
        print("✅ Good! Base model performs poorly. Room for improvement!")

    print("="*70 + "\n")

    return accuracy

# ============================================================================
# MAIN TRAINING
# ============================================================================

def main():
    """Run baseline evaluation, GRPO training, and the Hub upload.

    Steps, in order:
      1. Load the tokenizer and model named by ``BASE_MODEL``.
      2. Measure baseline accuracy with ``test_base_model``.
      3. Generate ``NUM_SAMPLES`` training problems and wrap them in a Dataset.
      4. Train with ``GRPOTrainer`` for ``MAX_STEPS`` steps using ``reward_func``.
      5. Push the trained model and tokenizer to ``OUTPUT_MODEL``.
    """
    use_cuda = torch.cuda.is_available()
    use_bf16 = use_cuda and torch.cuda.is_bf16_supported()
    device = 'cuda' if use_cuda else 'cpu'

    print("="*70)
    print("🔒 GRPO + RLVR Arithmetic Training")
    print("="*70)
    print(f"Base Model: {BASE_MODEL}")
    print(f"Output: {OUTPUT_MODEL}")
    print(f"Steps: {MAX_STEPS}")
    print(f"Device: {device}")
    print("="*70 + "\n")

    # Load model and tokenizer
    print("📦 Loading model and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    # Keep the weight dtype in line with the training precision below: bf16
    # where the GPU supports it, otherwise full precision. Loading fp16 weights
    # and then training with fp16=False would update half-precision weights
    # without any loss scaling.
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.bfloat16 if use_bf16 else torch.float32,
    )
    # from_pretrained leaves the model on CPU here; move it so the baseline
    # generation below runs on the GPU when one is available.
    model.to(device)

    # Test base model first
    baseline_accuracy = test_base_model(model, tokenizer, n_samples=EVAL_SAMPLES)

    # Generate training data
    print("📊 Generating training data...")
    train_samples = generate_arithmetic_samples(NUM_SAMPLES)
    train_dataset = Dataset.from_list(train_samples)
    print(f"✅ {len(train_dataset)} training samples\n")

    # GRPO Config
    is_cpu = not use_cuda
    training_args = GRPOConfig(
        output_dir="./outputs",
        max_steps=MAX_STEPS,
        per_device_train_batch_size=2,  # Reduced for CPU
        num_generations=2,  # Reduced for CPU (faster)
        learning_rate=2e-4,
        beta=0.0,  # No KL penalty for this task
        bf16=use_bf16,
        fp16=False,
        gradient_checkpointing=not is_cpu,  # Disable on CPU
        # NOTE(review): "adamw_8bit" presumably needs bitsandbytes installed on
        # the GPU host — confirm before running there.
        optim="adamw_torch" if is_cpu else "adamw_8bit",  # Use standard optimizer on CPU
        logging_steps=1,
        save_steps=MAX_STEPS,  # Save at end
        push_to_hub=False,  # We'll push manually
        report_to="none",
    )

    print("🚀 Starting GRPO Training...")
    print(f"Baseline accuracy: {baseline_accuracy:.1f}%\n")

    # Train
    trainer = GRPOTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        reward_funcs=[reward_func],  # Note: plural 'reward_funcs' as list
    )

    trainer.train()

    print("\n✅ Training complete!")

    # Save to Hub
    print(f"\n📦 Pushing to Hub: {OUTPUT_MODEL}")
    trainer.model.push_to_hub(OUTPUT_MODEL)
    tokenizer.push_to_hub(OUTPUT_MODEL)

    print(f"✅ Model pushed to: https://huggingface.co/{OUTPUT_MODEL}")
    print("="*70)

# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()