File size: 6,664 Bytes
b70b3eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
786e916
b70b3eb
 
786e916
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b70b3eb
 
786e916
b70b3eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
786e916
 
b70b3eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
#!/usr/bin/env python3
"""
GRPO + RLVR Training for Simple Arithmetic - v3 (Minimal)
Task: 2-digit addition and subtraction
Base Model: Qwen/Qwen3-0.6B-Base

Minimal version - no callbacks, no extra features
"""

import os
import re
import random
import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer

# ============================================================================
# CONFIG
# ============================================================================

# Hub id of the base model; main() loads both the model and tokenizer from it.
BASE_MODEL = "Qwen/Qwen3-0.6B-Base"
# Hub repo that main() pushes the trained model and tokenizer to.
OUTPUT_MODEL = "mindchain/qwen3-0.6b-arithmetic-v3"
# Passed to GRPOConfig as max_steps (length of training) and as save_steps,
# so a checkpoint is written when the final step is reached.
MAX_STEPS = 20
# Number of synthetic arithmetic problems in the training dataset.
NUM_SAMPLES = 500

# ============================================================================
# DATA GENERATION
# ============================================================================

def generate_arithmetic_samples(n_samples, rng=None):
    """Generate random 2-digit addition and subtraction problems.

    Args:
        n_samples: Number of problems to generate (a value <= 0 yields an
            empty list).
        rng: Optional ``random.Random`` instance used for every random draw.
            Pass a seeded instance (e.g. ``random.Random(0)``) to get a
            reproducible dataset. When omitted, the module-level ``random``
            functions are used, i.e. the global (unseeded) generator.

    Returns:
        A list of dicts, one per problem, with keys:
            'prompt': ``"Solve: <a> <op> <b> = ?"`` followed by a newline and
                ``"Answer:"``.
            'answer': the exact integer result as a decimal string.

    Addition draws both operands from [10, 99]. Subtraction draws ``a`` from
    [20, 99] and ``b`` from [10, a - 1], so its result is always positive.
    """
    if rng is None:
        rng = random

    samples = []
    for _ in range(n_samples):
        op = rng.choice(['+', '-'])

        if op == '+':
            a = rng.randint(10, 99)
            b = rng.randint(10, 99)
            answer = a + b
        else:
            # a >= 20 guarantees the range [10, a - 1] is non-empty.
            a = rng.randint(20, 99)
            b = rng.randint(10, a - 1)
            answer = a - b

        problem = f"{a} {op} {b} = ?"
        samples.append({
            'prompt': f"Solve: {problem}\nAnswer:",
            'answer': str(answer),
        })

    return samples

# ============================================================================
# REWARD FUNCTION (Improved)
# ============================================================================

# A signed integer with an optional fractional part. The fractional part needs
# at least one digit after the dot, so sentence punctuation such as the "." in
# "The answer is 42." is not swallowed into the number (which would otherwise
# turn a correct "42" into "42." and fail the exact-match reward).
_NUMBER_RE = re.compile(r'-?\d+(?:\.\d+)?')
_LATEX_BLOCK_RE = re.compile(r'\$\$(.*?)\$\$', re.DOTALL)
_ANSWER_RE = re.compile(r'Answer:\s*(-?\d+(?:\.\d+)?)', re.IGNORECASE)


def extract_answer(text):
    """
    Extract the final answer from model output.

    Priority:
    1. Last number inside the last $$...$$ LaTeX block (if that block
       contains any number).
    2. First number directly following an "Answer:" marker
       (case-insensitive).
    3. Last number anywhere in the text (fallback).

    Args:
        text: The completion text to parse.

    Returns:
        The extracted number as a string (e.g. "42", "-7", "3.5"), or "" if
        the text contains no number.
    """
    # NOTE(review): the training prompt already ends with "Answer:", so the
    # completion presumably starts directly with the number. If the model keeps
    # generating further "Solve: ... Answer: N" lines, step 2 picks up that
    # later N instead — confirm whether that is the intended reward behavior.

    # 1. Last LaTeX display block, last number inside it.
    latex_blocks = _LATEX_BLOCK_RE.findall(text)
    if latex_blocks:
        numbers = _NUMBER_RE.findall(latex_blocks[-1])
        if numbers:
            return numbers[-1]

    # 2. Number after the first "Answer:" marker.
    answer_match = _ANSWER_RE.search(text)
    if answer_match:
        return answer_match.group(1)

    # 3. Fallback: last number anywhere in the text.
    numbers = _NUMBER_RE.findall(text)
    if numbers:
        return numbers[-1]

    return ""


def reward_func(completions, prompts=None, **kwargs):
    """
    Binary exact-match reward for the arithmetic task.

    Each completion is reduced to one answer string with ``extract_answer``
    and compared, as a string, to ``str(truth).strip()``. A match scores 1.0
    and anything else scores 0.0.

    Args:
        completions: One entry per generated completion. Each entry is either
            a string, or (conversational format) a list whose items are dicts
            with a ``'content'`` key or other objects converted via ``str``.
        prompts: Accepted for signature compatibility; not used.
        **kwargs: Extra keyword arguments. The first of ``'answer'``,
            ``'ground_truth'``, ``'solution'`` and ``'label'`` that is present
            and not None supplies the ground truths, aligned index-by-index
            with ``completions``. (Presumably these are the dataset columns
            forwarded by the trainer — confirm for the TRL version in use.)

    Returns:
        list[float]: one reward per (completion, ground truth) pair. Pairing
        uses ``zip``, so it stops at the shorter of the two sequences. If no
        ground-truth key is found, a warning is printed and every completion
        gets 0.0.
    """
    # Try multiple column names for ground truth.
    answers = None
    for key in ['answer', 'ground_truth', 'solution', 'label']:
        if key in kwargs and kwargs[key] is not None:
            answers = kwargs[key]
            break

    if answers is None:
        print("⚠️  WARNING: No ground truth found in kwargs!")
        print(f"   Available keys: {list(kwargs.keys())}")
        return [0.0] * len(completions)

    rewards = []

    for i, (completion, truth) in enumerate(zip(completions, answers)):
        # Conversational completions are lists of messages; flatten to text.
        if isinstance(completion, list):
            text = " ".join([m.get('content', '') if isinstance(m, dict) else str(m) for m in completion])
        else:
            text = str(completion)

        predicted = extract_answer(text)

        # Exact string match against the ground truth.
        is_correct = predicted == str(truth).strip()
        rewards.append(1.0 if is_correct else 0.0)

        # Debug-print the first 2 samples of each call.
        if i < 2:
            # Restored from mis-encoded (mojibake) literal "βœ…".
            status = "✅" if is_correct else "❌"
            print(f"   [{i+1}] {status} Truth={truth} | Pred={predicted} | Text={text[:60]}...")

    return rewards

# ============================================================================
# MAIN TRAINING
# ============================================================================

def main():
    """
    Run the end-to-end GRPO training job.

    Loads BASE_MODEL and its tokenizer and builds NUM_SAMPLES synthetic
    arithmetic problems. It then trains with GRPO for MAX_STEPS steps, using
    ``reward_func`` as the only reward, with trainer output written to
    ./outputs. Finally it pushes the trained model and tokenizer to the Hub
    repo OUTPUT_MODEL.
    """
    is_cpu = not torch.cuda.is_available()
    # Use bf16 mixed precision only on GPUs that support it. Otherwise (CPU,
    # or GPUs without bf16 support) train in plain float32.
    use_bf16 = (not is_cpu) and torch.cuda.is_bf16_supported()

    print("="*70)
    print("🔒 GRPO + RLVR Arithmetic Training - v3 (Minimal)")
    print("="*70)
    print(f"Base Model: {BASE_MODEL}")
    print(f"Output: {OUTPUT_MODEL}")
    print(f"Steps: {MAX_STEPS}")
    print(f"Device: {'cpu' if is_cpu else 'cuda'}")
    print(f"Precision: {'bf16 mixed' if use_bf16 else 'fp32'}")
    print("="*70 + "\n")

    # Load model and tokenizer
    print("📦 Loading model and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

    # Ensure pad token is set
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        print(f"   Set pad_token to eos_token: {tokenizer.eos_token}")

    # Keep the trainable weights in float32. Loading float16 weights while
    # fp16/bf16 mixed precision is disabled would train the parameters and
    # the AdamW state purely in half precision. AdamW's default eps (1e-8)
    # underflows to 0 in fp16, so updates are prone to inf/NaN. Reduced
    # precision, when available, comes from bf16 autocast (see GRPOConfig).
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.float32,
    )

    print("   Model loaded successfully!\n")

    # Generate training data
    print("📊 Generating training data...")
    train_samples = generate_arithmetic_samples(NUM_SAMPLES)
    train_dataset = Dataset.from_list(train_samples)
    print(f"✅ {len(train_dataset)} training samples\n")

    # GRPO Config
    print("📝 Creating GRPO Config...")
    training_args = GRPOConfig(
        output_dir="./outputs",
        max_steps=MAX_STEPS,
        per_device_train_batch_size=2,
        num_generations=2,
        learning_rate=2e-4,
        beta=0.0,
        # Set explicitly: False on CPU and on GPUs without bf16 support.
        bf16=use_bf16,
        fp16=False,
        gradient_checkpointing=False,
        optim="adamw_torch",
        logging_steps=1,
        save_steps=MAX_STEPS,
        push_to_hub=False,
        report_to="none",
    )
    print("   GRPO Config created!\n")

    # Create trainer
    print("🔧 Creating GRPO Trainer...")
    trainer = GRPOTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        reward_funcs=[reward_func],
    )
    print("   Trainer created!\n")

    # Train
    print("🚀 Starting GRPO Training...")
    print("="*70 + "\n")

    trainer.train()

    print("\n" + "="*70)
    print("✅ Training complete!")
    print("="*70)

    # Save to Hub (push_to_hub=False above, so this is the only upload).
    print(f"\n📦 Pushing to Hub: {OUTPUT_MODEL}")
    trainer.model.push_to_hub(OUTPUT_MODEL)
    tokenizer.push_to_hub(OUTPUT_MODEL)
    print(f"✅ Model pushed to: https://huggingface.co/{OUTPUT_MODEL}")

if __name__ == "__main__":
    main()