rlm-arithmetic-training / train_arithmetic.py
mindchain's picture
Fix hang: remove use_cpu parameter, reduce generations to 2, batch to 2, steps to 20
95008ad verified
#!/usr/bin/env python3
"""
GRPO + RLVR Training for Simple Arithmetic
Task: 2-digit addition and subtraction
Base Model: Qwen/Qwen3-0.6B-Base
"""
import os
import re
import random
import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer
# ============================================================================
# CONFIG
# ============================================================================
# Hugging Face identifiers: the checkpoint to fine-tune and the repo to publish to.
BASE_MODEL: str = "Qwen/Qwen3-0.6B-Base"
OUTPUT_MODEL: str = "mindchain/qwen3-0.6b-arithmetic"

# Training step budget; reduced for CPU testing.
MAX_STEPS: int = 20

# Number of generated problems: training set size, then baseline-test size.
NUM_SAMPLES: int = 500
EVAL_SAMPLES: int = 20
# ============================================================================
# DATA GENERATION
# ============================================================================
def generate_arithmetic_samples(n_samples):
    """Build ``n_samples`` random two-operand addition/subtraction problems.

    Each problem is a dict with two keys:

    * ``'prompt'``: the instruction text followed by ``"a + b = ?"`` or
      ``"a - b = ?"``.
    * ``'answer'``: the result, as a string.

    Addition draws both operands from 10..99. Subtraction draws the left
    operand from 20..99 and the right operand from 10..left-1, so the
    difference is always at least 1.

    The module-level ``random`` state is used, so callers that need
    reproducible problems must seed it beforehand.
    """

    def _one_problem():
        # Draw order is operator, left operand, right operand; the right
        # operand's upper bound depends on the left one for subtraction.
        symbol = random.choice(['+', '-'])
        adding = symbol == '+'
        left = random.randint(10 if adding else 20, 99)
        right = random.randint(10, 99 if adding else left - 1)
        result = left + right if adding else left - right
        question = f"{left} {symbol} {right} = ?"
        return {
            'prompt': f"Solve this arithmetic problem. Give only the answer as a number.\n\n{question}",
            'answer': str(result),
        }

    return [_one_problem() for _ in range(n_samples)]
# ============================================================================
# REWARD FUNCTION
# ============================================================================
# Matches a signed integer or decimal. The fractional part must contain at
# least one digit, so sentence punctuation ("... is 42.") is not captured as
# part of the number.
_NUMBER_PATTERN = re.compile(r'-?\d+(?:\.\d+)?')


def _completion_to_text(completion):
    """Flatten one completion (plain string or list of chat messages) to text."""
    if isinstance(completion, list):
        # Conversational format: join the content of every message. A message
        # whose content is missing or None contributes an empty string.
        return " ".join(
            str(m.get('content') or '') if isinstance(m, dict) else str(m)
            for m in completion
        )
    return str(completion)


def reward_func(completions, prompts, **kwargs):
    """
    Reward function for arithmetic.

    Extracts the last number that appears in each completion and compares it,
    as a string, with the corresponding ground truth.

    Args:
        completions: Generated completions; each is either a string or a list
            of chat messages (dicts with a ``'content'`` key).
        prompts: The prompts that produced the completions. Unused, but part
            of the signature this function is called with.
        **kwargs: Must carry the ground truths under ``'answer'`` (preferred)
            or ``'ground_truth'``, aligned one-to-one with ``completions``.

    Returns:
        A list of floats: 1.0 where the last number in the completion equals
        the ground truth exactly, 0.0 otherwise. If no ground truth is
        supplied, every completion scores 0.0.
    """
    answers = kwargs.get('answer', kwargs.get('ground_truth', None))
    if answers is None:
        return [0.0] * len(completions)

    rewards = []
    for completion, truth in zip(completions, answers):
        numbers = _NUMBER_PATTERN.findall(_completion_to_text(completion))
        # The final number is treated as the model's answer; no number at all
        # yields an empty prediction, which cannot match a numeric truth.
        predicted = numbers[-1] if numbers else ""
        rewards.append(1.0 if predicted == str(truth).strip() else 0.0)
    return rewards
# ============================================================================
# BASELINE TEST
# ============================================================================
def test_base_model(model, tokenizer, n_samples=20):
    """Test base model performance before training.

    Generates ``n_samples`` fresh arithmetic problems, greedily decodes up to
    20 new tokens for each prompt, takes the last number in the response as
    the prediction, and prints a per-sample line plus an overall accuracy.

    Args:
        model: Causal LM exposing ``eval()``, ``generate()`` and (optionally)
            a ``device`` attribute. It is left in eval mode on return.
        tokenizer: Tokenizer matching ``model``; called with
            ``return_tensors='pt'`` and used to decode the output ids.
        n_samples: Number of problems to evaluate.

    Returns:
        Accuracy as a percentage in [0, 100]. Returns 0.0 without running the
        model when no samples are requested.
    """
    print("\n" + "="*70)
    print("📊 TESTING BASE MODEL PERFORMANCE")
    print("="*70)
    test_samples = generate_arithmetic_samples(n_samples)
    if not test_samples:
        # Nothing to evaluate; avoid dividing by zero below.
        print("⚠️ No evaluation samples requested; skipping baseline test.")
        print("="*70 + "\n")
        return 0.0

    correct = 0
    model.eval()
    with torch.no_grad():
        for i, sample in enumerate(test_samples, start=1):
            inputs = tokenizer(sample['prompt'], return_tensors='pt')
            # Handle device placement
            if hasattr(model, 'device') and model.device is not None:
                inputs = {k: v.to(model.device) for k, v in inputs.items()}
            # Greedy decoding. NOTE(review): temperature has no effect when
            # do_sample=False; it is presumably passed to override a sampling
            # temperature in the checkpoint's generation config - confirm.
            outputs = model.generate(
                **inputs,
                max_new_tokens=20,
                do_sample=False,
                temperature=1.0
            )
            # Decode only the newly generated tokens when the prompt length
            # is known; otherwise fall back to decoding the whole sequence.
            input_ids = inputs.get('input_ids')
            if input_ids is not None and hasattr(input_ids, 'shape'):
                response = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
            else:
                response = tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Extract the last number. The fractional part needs at least one
            # digit so a sentence-ending period ("46.") is not captured.
            numbers = re.findall(r'-?\d+(?:\.\d+)?', response)
            predicted = numbers[-1] if numbers else ""
            truth = sample['answer'].strip()
            is_correct = predicted == truth
            if is_correct:
                correct += 1
            status = "✅" if is_correct else "❌"
            # The problem is the last line of the prompt, e.g. "45 + 23 = ?".
            problem = sample['prompt'].rsplit('\n', 1)[-1].split(' = ?')[0]
            print(f"[{i}] {status} {problem} = {truth} | Predicted: {predicted} | Response: {response[:50]}...")

    total = len(test_samples)
    accuracy = correct / total * 100
    print(f"\n📊 Base Model Accuracy: {accuracy:.1f}% ({correct}/{total})")
    if accuracy > 90:
        print("⚠️ WARNING: Base model already performs well! Task may be too easy.")
    elif accuracy < 50:
        print("✅ Good! Base model performs poorly. Room for improvement!")
    print("="*70 + "\n")
    return accuracy
# ============================================================================
# MAIN TRAINING
# ============================================================================
def main():
    """Run the full pipeline: baseline test, GRPO training, Hub upload.

    Steps:
        1. Load the tokenizer and base model and place the model on the GPU
           when one is available.
        2. Measure baseline accuracy with ``test_base_model``.
        3. Generate ``NUM_SAMPLES`` training problems and train with
           ``GRPOTrainer`` for ``MAX_STEPS`` steps using ``reward_func``.
        4. Push the trained model and tokenizer to ``OUTPUT_MODEL``.

    Side effects: downloads the base model, writes checkpoints to
    ``./outputs`` and uploads to the Hugging Face Hub.
    """
    has_cuda = torch.cuda.is_available()
    device = 'cuda' if has_cuda else 'cpu'
    use_bf16 = has_cuda and torch.cuda.is_bf16_supported()
    is_cpu = not has_cuda

    print("="*70)
    print("🔒 GRPO + RLVR Arithmetic Training")
    print("="*70)
    print(f"Base Model: {BASE_MODEL}")
    print(f"Output: {OUTPUT_MODEL}")
    print(f"Steps: {MAX_STEPS}")
    print(f"Device: {device}")
    print("="*70 + "\n")

    # Load model and tokenizer
    print("📦 Loading model and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    # Load weights in the precision training will actually use: bf16 when the
    # GPU supports it, otherwise fp32. fp16 weights are avoided because the
    # config below sets fp16=False, so there would be no loss scaling.
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.bfloat16 if use_bf16 else torch.float32,
    )
    # from_pretrained is called without a device map, so move the model
    # explicitly; otherwise the baseline test would run on CPU even though
    # the banner above reports CUDA.
    model = model.to(device)

    # Test base model first
    baseline_accuracy = test_base_model(model, tokenizer, n_samples=EVAL_SAMPLES)

    # Generate training data
    print("📊 Generating training data...")
    train_samples = generate_arithmetic_samples(NUM_SAMPLES)
    train_dataset = Dataset.from_list(train_samples)
    print(f"✅ {len(train_dataset)} training samples\n")

    # GRPO Config
    training_args = GRPOConfig(
        output_dir="./outputs",
        max_steps=MAX_STEPS,
        per_device_train_batch_size=2,  # Reduced for CPU
        num_generations=2,  # Reduced for CPU (faster)
        learning_rate=2e-4,
        beta=0.0,  # No KL penalty for this task
        bf16=use_bf16,
        fp16=False,
        gradient_checkpointing=not is_cpu,  # Disable on CPU
        optim="adamw_torch" if is_cpu else "adamw_8bit",  # Use standard optimizer on CPU
        logging_steps=1,
        save_steps=MAX_STEPS,  # Save at end
        push_to_hub=False,  # We'll push manually
        report_to="none",
    )
    print("🚀 Starting GRPO Training...")
    print(f"Baseline accuracy: {baseline_accuracy:.1f}%\n")

    # Train
    trainer = GRPOTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        reward_funcs=[reward_func],  # Note: plural 'reward_funcs' as list
    )
    trainer.train()
    print("\n✅ Training complete!")

    # Save to Hub
    print(f"\n📦 Pushing to Hub: {OUTPUT_MODEL}")
    trainer.model.push_to_hub(OUTPUT_MODEL)
    tokenizer.push_to_hub(OUTPUT_MODEL)
    print(f"✅ Model pushed to: https://huggingface.co/{OUTPUT_MODEL}")
    print("="*70)


# Run the pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()