#!/usr/bin/env python3 """ GRPO + RLVR Training for Simple Arithmetic Task: 2-digit addition and subtraction Base Model: Qwen/Qwen3-0.6B-Base """ import os import re import random import torch from datasets import Dataset from transformers import AutoModelForCausalLM, AutoTokenizer from trl import GRPOConfig, GRPOTrainer # ============================================================================ # CONFIG # ============================================================================ BASE_MODEL = "Qwen/Qwen3-0.6B-Base" OUTPUT_MODEL = "mindchain/qwen3-0.6b-arithmetic" MAX_STEPS = 20 # Reduced for CPU testing NUM_SAMPLES = 500 # Training samples EVAL_SAMPLES = 20 # For baseline test # ============================================================================ # DATA GENERATION # ============================================================================ def generate_arithmetic_samples(n_samples): """Generate simple arithmetic problems""" samples = [] for _ in range(n_samples): # Random operation op = random.choice(['+', '-']) if op == '+': a = random.randint(10, 99) b = random.randint(10, 99) answer = a + b problem = f"{a} + {b} = ?" else: a = random.randint(20, 99) b = random.randint(10, a-1) # Ensure positive result answer = a - b problem = f"{a} - {b} = ?" samples.append({ 'prompt': f"Solve this arithmetic problem. Give only the answer as a number.\n\n{problem}", 'answer': str(answer) }) return samples # ============================================================================ # REWARD FUNCTION # ============================================================================ def reward_func(completions, prompts, **kwargs): """ Reward function for arithmetic. Extract the last number from completion, compare to ground truth. """ answers = kwargs.get('answer', kwargs.get('ground_truth', None)) if answers is None: return [0.0] * len(completions) rewards = [] for completion, truth in zip(completions, answers): # Handle list format (conversational) if isinstance(completion, list): text = " ".join([m.get('content', '') if isinstance(m, dict) else str(m) for m in completion]) else: text = str(completion) # Extract the last number numbers = re.findall(r'-?\d+\.?\d*', text) if numbers: predicted = numbers[-1].strip() else: predicted = "" # Exact match reward if predicted == str(truth).strip(): rewards.append(1.0) else: rewards.append(0.0) return rewards # ============================================================================ # BASELINE TEST # ============================================================================ def test_base_model(model, tokenizer, n_samples=20): """Test base model performance before training""" print("\n" + "="*70) print("šŸ“Š TESTING BASE MODEL PERFORMANCE") print("="*70) test_samples = generate_arithmetic_samples(n_samples) correct = 0 model.eval() with torch.no_grad(): for i, sample in enumerate(test_samples): inputs = tokenizer(sample['prompt'], return_tensors='pt') # Handle device placement if hasattr(model, 'device') and model.device is not None: inputs = {k: v.to(model.device) for k, v in inputs.items()} outputs = model.generate( **inputs, max_new_tokens=20, do_sample=False, temperature=1.0 ) # Safely decode response input_ids = inputs.get('input_ids') if input_ids is not None and hasattr(input_ids, 'shape'): response = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True) else: response = tokenizer.decode(outputs[0], skip_special_tokens=True) # Extract answer numbers = re.findall(r'-?\d+\.?\d*', response) predicted = numbers[-1].strip() if numbers else "" truth = sample['answer'].strip() is_correct = predicted == truth if is_correct: correct += 1 status = "āœ…" if is_correct else "āŒ" print(f"[{i+1}] {status} {sample['prompt'].split('= ?')[0].split()[-1]} = {truth} | Predicted: {predicted} | Response: {response[:50]}...") accuracy = correct / n_samples * 100 print(f"\nšŸ“Š Base Model Accuracy: {accuracy:.1f}% ({correct}/{n_samples})") if accuracy > 90: print("āš ļø WARNING: Base model already performs well! Task may be too easy.") elif accuracy < 50: print("āœ… Good! Base model performs poorly. Room for improvement!") print("="*70 + "\n") return accuracy # ============================================================================ # MAIN TRAINING # ============================================================================ def main(): print("="*70) print("šŸ”¢ GRPO + RLVR Arithmetic Training") print("="*70) print(f"Base Model: {BASE_MODEL}") print(f"Output: {OUTPUT_MODEL}") print(f"Steps: {MAX_STEPS}") print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}") print("="*70 + "\n") # Load model and tokenizer print("šŸ“¦ Loading model and tokenizer...") tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL) model = AutoModelForCausalLM.from_pretrained( BASE_MODEL, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, ) # Test base model first baseline_accuracy = test_base_model(model, tokenizer, n_samples=EVAL_SAMPLES) # Generate training data print("šŸ“Š Generating training data...") train_samples = generate_arithmetic_samples(NUM_SAMPLES) train_dataset = Dataset.from_list(train_samples) print(f"āœ… {len(train_dataset)} training samples\n") # GRPO Config is_cpu = not torch.cuda.is_available() training_args = GRPOConfig( output_dir="./outputs", max_steps=MAX_STEPS, per_device_train_batch_size=2, # Reduced for CPU num_generations=2, # Reduced for CPU (faster) learning_rate=2e-4, beta=0.0, # No KL penalty for this task bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(), fp16=False, gradient_checkpointing=not is_cpu, # Disable on CPU optim="adamw_torch" if is_cpu else "adamw_8bit", # Use standard optimizer on CPU logging_steps=1, save_steps=MAX_STEPS, # Save at end push_to_hub=False, # We'll push manually report_to="none", ) print("šŸš€ Starting GRPO Training...") print(f"Baseline accuracy: {baseline_accuracy:.1f}%\n") # Train trainer = GRPOTrainer( model=model, args=training_args, train_dataset=train_dataset, reward_funcs=[reward_func], # Note: plural 'reward_funcs' as list ) trainer.train() print("\nāœ… Training complete!") # Save to Hub print(f"\nšŸ“¦ Pushing to Hub: {OUTPUT_MODEL}") trainer.model.push_to_hub(OUTPUT_MODEL) tokenizer.push_to_hub(OUTPUT_MODEL) print(f"āœ… Model pushed to: https://huggingface.co/{OUTPUT_MODEL}") print("="*70) if __name__ == "__main__": main()