#!/usr/bin/env python3
"""
GRPO + RLVR Training for Simple Arithmetic - v3 (Minimal)

Task: 2-digit addition and subtraction
Base Model: Qwen/Qwen3-0.6B-Base

Minimal version - no callbacks, no extra features
"""

import os
import re
import random

import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer

# ============================================================================
# CONFIG
# ============================================================================
BASE_MODEL = "Qwen/Qwen3-0.6B-Base"
OUTPUT_MODEL = "mindchain/qwen3-0.6b-arithmetic-v3"
MAX_STEPS = 20
NUM_SAMPLES = 500

# Compiled once at import time — extract_answer runs for every completion in
# every training batch, so hoisting the patterns out of the hot path is free.
_NUMBER_RE = re.compile(r'-?\d+\.?\d*')
_LATEX_RE = re.compile(r'\$\$(.*?)\$\$', re.DOTALL)
_ANSWER_RE = re.compile(r'Answer:\s*(-?\d+\.?\d*)', re.IGNORECASE)


# ============================================================================
# DATA GENERATION
# ============================================================================
def generate_arithmetic_samples(n_samples):
    """Generate simple arithmetic problems.

    Args:
        n_samples: number of problems to generate.

    Returns:
        List of dicts with keys 'prompt' (e.g. "Solve: 12 + 34 = ?\\nAnswer:")
        and 'answer' (the ground-truth result as a string).
    """
    samples = []
    for _ in range(n_samples):
        op = random.choice(['+', '-'])
        if op == '+':
            a = random.randint(10, 99)
            b = random.randint(10, 99)
            answer = a + b
            problem = f"{a} + {b} = ?"
        else:
            # Constrain b < a so subtraction results stay non-negative.
            a = random.randint(20, 99)
            b = random.randint(10, a - 1)
            answer = a - b
            problem = f"{a} - {b} = ?"

        samples.append({
            'prompt': f"Solve: {problem}\nAnswer:",
            'answer': str(answer),
        })
    return samples


# ============================================================================
# REWARD FUNCTION (Improved)
# ============================================================================
def _normalize_number(s):
    """Strip whitespace and a dangling decimal point left by the regex.

    The pattern `-?\\d+\\.?\\d*` can match "45." — without this cleanup the
    exact-match reward would score a correct "45." completion as wrong.
    """
    s = s.strip()
    if s.endswith('.'):
        s = s[:-1]
    return s


def extract_answer(text):
    """
    Extract the final answer from model output.

    Priority:
      1. Number in $$...$$ LaTeX blocks (last one)
      2. Number after "Answer:" pattern
      3. Last standalone number (fallback)

    Returns the normalized number as a string, or "" if none is found.
    """
    # Try to find numbers in $$...$$ blocks first
    latex_blocks = _LATEX_RE.findall(text)
    if latex_blocks:
        # Get the last LaTeX block and extract number
        numbers = _NUMBER_RE.findall(latex_blocks[-1])
        if numbers:
            return _normalize_number(numbers[-1])

    # Try to find number after "Answer:" pattern
    answer_match = _ANSWER_RE.search(text)
    if answer_match:
        return _normalize_number(answer_match.group(1))

    # Fallback: last number in text
    numbers = _NUMBER_RE.findall(text)
    if numbers:
        return _normalize_number(numbers[-1])

    return ""


def _answers_match(predicted, truth):
    """True if the predicted string matches the ground truth.

    Tries exact string match first, then a numeric comparison so that
    e.g. "45.0" is credited against truth "45" instead of scoring 0.
    """
    truth_str = str(truth).strip()
    if predicted == truth_str:
        return True
    try:
        return float(predicted) == float(truth_str)
    except (ValueError, TypeError):
        return False


def reward_func(completions, prompts=None, **kwargs):
    """
    Reward function for arithmetic with improved extraction.

    Args:
        completions: model outputs; either strings or conversational lists
            of {'content': ...} message dicts.
        prompts: unused, accepted for TRL's calling convention.
        **kwargs: dataset columns; ground truth is looked up under several
            common column names.

    Returns:
        List of floats (1.0 for a correct answer, 0.0 otherwise), one per
        completion. All zeros if no ground-truth column is found.
    """
    # Try multiple column names for ground truth
    answers = None
    for key in ['answer', 'ground_truth', 'solution', 'label']:
        if key in kwargs and kwargs[key] is not None:
            answers = kwargs[key]
            break

    if answers is None:
        print("āš ļø WARNING: No ground truth found in kwargs!")
        print(f"   Available keys: {list(kwargs.keys())}")
        return [0.0] * len(completions)

    rewards = []
    for i, (completion, truth) in enumerate(zip(completions, answers)):
        # Handle list format (conversational)
        if isinstance(completion, list):
            text = " ".join([m.get('content', '') if isinstance(m, dict) else str(m)
                             for m in completion])
        else:
            text = str(completion)

        # Extract answer using improved method
        predicted = extract_answer(text)

        # Exact match first, numeric fallback for forms like "45.0"
        is_correct = _answers_match(predicted, truth)
        rewards.append(1.0 if is_correct else 0.0)

        # Debug first 2 samples per batch
        if i < 2:
            status = "āœ…" if is_correct else "āŒ"
            print(f"   [{i+1}] {status} Truth={truth} | Pred={predicted} | Text={text[:60]}...")

    return rewards


# ============================================================================
# MAIN TRAINING
# ============================================================================
def main():
    """Run the full pipeline: load model, build data, train with GRPO, push to Hub."""
    print("=" * 70)
    print("šŸ”¢ GRPO + RLVR Arithmetic Training - v3 (Minimal)")
    print("=" * 70)
    print(f"Base Model: {BASE_MODEL}")
    print(f"Output: {OUTPUT_MODEL}")
    print(f"Steps: {MAX_STEPS}")
    print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")
    print("=" * 70 + "\n")

    # Load model and tokenizer
    print("šŸ“¦ Loading model and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

    # Ensure pad token is set — generation/padding fails without one
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        print(f"   Set pad_token to eos_token: {tokenizer.eos_token}")

    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )
    print("   Model loaded successfully!\n")

    # Generate training data
    print("šŸ“Š Generating training data...")
    train_samples = generate_arithmetic_samples(NUM_SAMPLES)
    train_dataset = Dataset.from_list(train_samples)
    print(f"āœ… {len(train_dataset)} training samples\n")

    # GRPO Config
    print("šŸ“ Creating GRPO Config...")
    training_args = GRPOConfig(
        output_dir="./outputs",
        max_steps=MAX_STEPS,
        per_device_train_batch_size=2,
        num_generations=2,
        learning_rate=2e-4,
        beta=0.0,  # no KL penalty against a reference model
        bf16=False,  # Always False for CPU safety
        fp16=False,
        gradient_checkpointing=False,
        optim="adamw_torch",
        logging_steps=1,
        save_steps=MAX_STEPS,
        push_to_hub=False,  # pushed manually after training below
        report_to="none",
    )
    print("   GRPO Config created!\n")

    # Create trainer
    print("šŸ”§ Creating GRPO Trainer...")
    trainer = GRPOTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        reward_funcs=[reward_func],
    )
    print("   Trainer created!\n")

    # Train
    print("šŸš€ Starting GRPO Training...")
    print("=" * 70 + "\n")
    trainer.train()

    print("\n" + "=" * 70)
    print("āœ… Training complete!")
    print("=" * 70)

    # Save to Hub
    print(f"\nšŸ“¦ Pushing to Hub: {OUTPUT_MODEL}")
    trainer.model.push_to_hub(OUTPUT_MODEL)
    tokenizer.push_to_hub(OUTPUT_MODEL)
    print(f"āœ… Model pushed to: https://huggingface.co/{OUTPUT_MODEL}")


if __name__ == "__main__":
    main()