# Source: rlm-arithmetic-training / train_arithmetic_v3.py
# Uploaded by mindchain via huggingface_hub (commit 786e916, verified).
#!/usr/bin/env python3
"""
GRPO + RLVR Training for Simple Arithmetic - v3 (Minimal)
Task: 2-digit addition and subtraction
Base Model: Qwen/Qwen3-0.6B-Base
Minimal version - no callbacks, no extra features
"""
import os
import re
import random
import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer
# ============================================================================
# CONFIG
# ============================================================================
# Hub ID of the pretrained checkpoint that is loaded and fine-tuned.
BASE_MODEL: str = "Qwen/Qwen3-0.6B-Base"
# Hub repository that receives the trained model and tokenizer.
OUTPUT_MODEL: str = "mindchain/qwen3-0.6b-arithmetic-v3"

# Number of optimizer steps to train for (also the checkpoint interval).
MAX_STEPS: int = 20
# Number of synthetic arithmetic problems in the training set.
NUM_SAMPLES: int = 500
# ============================================================================
# DATA GENERATION
# ============================================================================
def generate_arithmetic_samples(n_samples: int, rng=None) -> list[dict[str, str]]:
    """Generate random 2-digit addition and subtraction problems.

    Addition draws both operands from [10, 99]. Subtraction draws the
    minuend from [20, 99] and the subtrahend from [10, minuend - 1], so every
    difference is strictly positive.

    Args:
        n_samples: Number of problems to generate (0 yields an empty list).
        rng: Optional ``random.Random`` instance, for reproducible datasets.
            When omitted, the module-level ``random`` functions (global RNG
            state) are used.

    Returns:
        A list of dicts with two string fields:
        ``'prompt'``, e.g. "Solve: 12 + 34 = ?" followed by a newline and
        "Answer:"; and ``'answer'``, the exact result as a decimal string.
    """
    # The ``random`` module exposes the same choice/randint API as a
    # ``random.Random`` instance, so both can be used interchangeably here.
    rng = random if rng is None else rng
    samples = []
    for _ in range(n_samples):
        op = rng.choice(['+', '-'])
        if op == '+':
            a = rng.randint(10, 99)
            b = rng.randint(10, 99)
            answer = a + b
        else:
            a = rng.randint(20, 99)
            # b < a keeps the difference positive, so answers never carry a sign.
            b = rng.randint(10, a - 1)
            answer = a - b
        problem = f"{a} {op} {b} = ?"
        samples.append({
            'prompt': f"Solve: {problem}\nAnswer:",
            'answer': str(answer),
        })
    return samples
# ============================================================================
# REWARD FUNCTION (Improved)
# ============================================================================
# An integer or decimal number. A fractional part needs at least one digit, so
# sentence-final punctuation ("... is 42.") is not swallowed into the match.
_NUMBER_RE = re.compile(r'-?\d+(?:\.\d+)?')
_LATEX_BLOCK_RE = re.compile(r'\$\$(.*?)\$\$', re.DOTALL)
_ANSWER_RE = re.compile(r'Answer:\s*(-?\d+(?:\.\d+)?)', re.IGNORECASE)


def _normalize_number(token):
    """Strip trailing fractional zeros and a bare dot ("62.0" -> "62", "2.50" -> "2.5")."""
    if '.' in token:
        token = token.rstrip('0').rstrip('.')
    return token


def extract_answer(text):
    """
    Extract the final numeric answer from model output.

    Priority:
    1. Last number inside the last $$...$$ LaTeX block
    2. Number after the first "Answer:" (case-insensitive) that is directly
       followed by a number
    3. Last number anywhere in the text (fallback)

    A trailing period is never part of the returned number, and trailing
    fractional zeros are removed, so "42." -> "42" and "62.0" -> "62".
    Returns "" when the text contains no number.

    NOTE(review): the training prompts end with "Answer:", so the answer is
    normally the first number of the completion. If a base model keeps
    generating further "Solve: ... Answer: N" problems, rule 2 picks the first
    of those and rule 3 the last number — consider preferring the leading
    number instead.
    """
    # 1. LaTeX display blocks take precedence; use the last one.
    latex_blocks = _LATEX_BLOCK_RE.findall(text)
    if latex_blocks:
        numbers = _NUMBER_RE.findall(latex_blocks[-1])
        if numbers:
            return _normalize_number(numbers[-1])
    # 2. Explicit "Answer: <number>" marker.
    answer_match = _ANSWER_RE.search(text)
    if answer_match:
        return _normalize_number(answer_match.group(1))
    # 3. Fallback: last number anywhere in the text.
    numbers = _NUMBER_RE.findall(text)
    if numbers:
        return _normalize_number(numbers[-1])
    return ""
def _answers_match(predicted, truth):
    """Return True when *predicted* and *truth* denote the same answer.

    Compares numerically so formatting variants of a correct answer
    ("42.", "42.0", "042") are not scored as wrong; falls back to exact
    string equality when either side does not parse as a number.
    """
    predicted = str(predicted).strip()
    truth = str(truth).strip()
    try:
        return float(predicted) == float(truth)
    except ValueError:
        return predicted == truth


def reward_func(completions, prompts=None, **kwargs):
    """
    Binary reward for arithmetic: 1.0 if the extracted answer matches the
    ground truth numerically, else 0.0.

    Args:
        completions: Model outputs, either plain strings or (conversational
            format) lists of messages; dict messages contribute their
            'content' field, other items are str()-ed, joined by spaces.
        prompts: Unused here.
        **kwargs: Extra dataset columns. The ground truth is taken from the
            first key among 'answer', 'ground_truth', 'solution', 'label'
            that is present and not None.

    Returns:
        A list with one float per (completion, truth) pair (zip stops at the
        shorter sequence). If no ground-truth column is found, a warning is
        printed and every completion receives 0.0.
    """
    # Try multiple column names for ground truth.
    answers = None
    for key in ['answer', 'ground_truth', 'solution', 'label']:
        if key in kwargs and kwargs[key] is not None:
            answers = kwargs[key]
            break
    if answers is None:
        print("⚠️ WARNING: No ground truth found in kwargs!")
        print(f" Available keys: {list(kwargs.keys())}")
        return [0.0] * len(completions)

    rewards = []
    for i, (completion, truth) in enumerate(zip(completions, answers)):
        # Conversational completions arrive as a list of messages.
        if isinstance(completion, list):
            text = " ".join([m.get('content', '') if isinstance(m, dict) else str(m) for m in completion])
        else:
            text = str(completion)
        predicted = extract_answer(text)
        # Numeric comparison: a correct "42." or "42.0" must not score 0.0.
        is_correct = _answers_match(predicted, truth)
        rewards.append(1.0 if is_correct else 0.0)
        # Debug output for the first 2 samples of every call.
        if i < 2:
            status = "✅" if is_correct else "❌"
            print(f" [{i+1}] {status} Truth={truth} | Pred={predicted} | Text={text[:60]}...")
    return rewards
# ============================================================================
# MAIN TRAINING
# ============================================================================
def main():
    """Fine-tune BASE_MODEL on synthetic arithmetic with GRPO and push it to the Hub.

    Loads the model and tokenizer, generates NUM_SAMPLES problems, trains for
    MAX_STEPS steps with ``reward_func`` as the verifiable reward, then pushes
    the trained model and tokenizer to OUTPUT_MODEL.
    """
    seed = 42
    use_cuda = torch.cuda.is_available()
    # bf16 autocast only where the GPU supports it natively.
    use_bf16 = use_cuda and torch.cuda.is_bf16_supported()

    print("=" * 70)
    print("🔒 GRPO + RLVR Arithmetic Training - v3 (Minimal)")
    print("=" * 70)
    print(f"Base Model: {BASE_MODEL}")
    print(f"Output: {OUTPUT_MODEL}")
    print(f"Steps: {MAX_STEPS}")
    print(f"Device: {'cuda' if use_cuda else 'cpu'}")
    print("=" * 70 + "\n")

    # Load model and tokenizer
    print("📦 Loading model and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    # Ensure pad token is set
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        print(f" Set pad_token to eos_token: {tokenizer.eos_token}")
    # Decoder-only batched generation needs prompts padded on the left.
    tokenizer.padding_side = "left"

    # Loading float16 weights while training with fp16=False/bf16=False means
    # pure-fp16 AdamW (no loss scaling, fp16 optimizer state), which is
    # numerically unstable. Use bfloat16 where supported (paired with
    # bf16=True below) and float32 everywhere else.
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.bfloat16 if use_bf16 else torch.float32,
    )
    print(" Model loaded successfully!\n")

    # Generate training data (seeded so the dataset is reproducible)
    print("📊 Generating training data...")
    random.seed(seed)
    train_samples = generate_arithmetic_samples(NUM_SAMPLES)
    train_dataset = Dataset.from_list(train_samples)
    print(f"✅ {len(train_dataset)} training samples\n")

    # GRPO Config
    print("📝 Creating GRPO Config...")
    training_args = GRPOConfig(
        output_dir="./outputs",
        max_steps=MAX_STEPS,
        per_device_train_batch_size=2,
        num_generations=2,
        # NOTE(review): 2e-4 is high for full-parameter RL fine-tuning;
        # consider 1e-6..1e-5 if rewards collapse or loss diverges.
        learning_rate=2e-4,
        beta=0.0,
        bf16=use_bf16,  # False on CPU and on GPUs without native bf16
        fp16=False,
        gradient_checkpointing=False,
        optim="adamw_torch",
        logging_steps=1,
        save_steps=MAX_STEPS,
        seed=seed,
        push_to_hub=False,
        report_to="none",
    )
    print(" GRPO Config created!\n")

    # Create trainer; pass the tokenizer configured above so the trainer uses
    # the same pad-token setup instead of loading its own copy.
    print("🔧 Creating GRPO Trainer...")
    trainer = GRPOTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        reward_funcs=[reward_func],
        processing_class=tokenizer,
    )
    print(" Trainer created!\n")

    # Train
    print("🚀 Starting GRPO Training...")
    print("=" * 70 + "\n")
    trainer.train()
    print("\n" + "=" * 70)
    print("✅ Training complete!")
    print("=" * 70)

    # Save to Hub
    print(f"\n📦 Pushing to Hub: {OUTPUT_MODEL}")
    trainer.model.push_to_hub(OUTPUT_MODEL)
    tokenizer.push_to_hub(OUTPUT_MODEL)
    print(f"✅ Model pushed to: https://huggingface.co/{OUTPUT_MODEL}")


if __name__ == "__main__":
    main()