# NOTE: removed paste/extraction artifacts ("Spaces:", "Runtime error" x2)
# that preceded the script and were not part of the source.
| #!/usr/bin/env python3 | |
| """ | |
| GRPO + RLVR Training for Simple Arithmetic | |
| Task: 2-digit addition and subtraction | |
| Base Model: Qwen/Qwen3-0.6B-Base | |
| """ | |
| import os | |
| import re | |
| import random | |
| import torch | |
| from datasets import Dataset | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from trl import GRPOConfig, GRPOTrainer | |
# ============================================================================
# CONFIG
# ============================================================================
# Hugging Face model id of the untrained base checkpoint to fine-tune.
BASE_MODEL = "Qwen/Qwen3-0.6B-Base"
# Hub repo id the trained model/tokenizer are pushed to at the end of main().
OUTPUT_MODEL = "mindchain/qwen3-0.6b-arithmetic"
MAX_STEPS = 20  # Optimizer steps; deliberately tiny so a CPU smoke-test finishes.
NUM_SAMPLES = 500  # Number of synthetic training prompts to generate.
EVAL_SAMPLES = 20  # Number of problems used for the pre-training baseline test.
| # ============================================================================ | |
| # DATA GENERATION | |
| # ============================================================================ | |
def generate_arithmetic_samples(n_samples):
    """Build *n_samples* synthetic 2-digit addition/subtraction problems.

    Each sample is a dict with:
      * 'prompt' — a natural-language instruction followed by the problem.
      * 'answer' — the ground-truth result as a decimal string.

    Subtraction operands are drawn so the difference is always positive.
    Uses the module-level ``random`` state; seed it for reproducibility.
    """
    dataset = []
    for _ in range(n_samples):
        operator = random.choice(['+', '-'])
        if operator == '+':
            lhs = random.randint(10, 99)
            rhs = random.randint(10, 99)
            result = lhs + rhs
        else:
            lhs = random.randint(20, 99)
            # Upper bound of lhs - 1 keeps the difference strictly positive.
            rhs = random.randint(10, lhs - 1)
            result = lhs - rhs
        question = f"{lhs} {operator} {rhs} = ?"
        dataset.append({
            'prompt': "Solve this arithmetic problem. Give only the answer as a number.\n\n" + question,
            'answer': str(result),
        })
    return dataset
| # ============================================================================ | |
| # REWARD FUNCTION | |
| # ============================================================================ | |
def reward_func(completions, prompts, **kwargs):
    """Verifiable reward for arithmetic completions (RLVR).

    Extracts the last number in each completion and compares it to the
    ground truth, returning 1.0 for a match and 0.0 otherwise.

    Args:
        completions: List of model outputs; each item is either a plain
            string or a conversational list of ``{'role', 'content'}`` dicts.
        prompts: Parallel list of prompts (unused, required by the TRL
            reward-function signature).
        **kwargs: Dataset columns forwarded by GRPOTrainer; the ground
            truth is read from 'answer' (or 'ground_truth' as a fallback).

    Returns:
        List of floats, one reward per completion.
    """
    answers = kwargs.get('answer', kwargs.get('ground_truth', None))
    if answers is None:
        # No ground truth available: neutral zero reward for every sample.
        return [0.0] * len(completions)
    rewards = []
    for completion, truth in zip(completions, answers):
        # Flatten conversational (list-of-messages) completions to one string.
        if isinstance(completion, list):
            text = " ".join(m.get('content', '') if isinstance(m, dict) else str(m) for m in completion)
        else:
            text = str(completion)
        # Take the LAST number in the text as the model's final answer.
        numbers = re.findall(r'-?\d+\.?\d*', text)
        predicted = numbers[-1] if numbers else ""
        truth_str = str(truth).strip()
        # Compare numerically so sentence-final punctuation captured by the
        # regex ("108.") or float renderings ("108.0") still earn the reward.
        # Fall back to exact string equality when either side is non-numeric.
        try:
            is_correct = float(predicted) == float(truth_str)
        except ValueError:
            is_correct = predicted == truth_str
        rewards.append(1.0 if is_correct else 0.0)
    return rewards
| # ============================================================================ | |
| # BASELINE TEST | |
| # ============================================================================ | |
def test_base_model(model, tokenizer, n_samples=20):
    """Measure zero-shot accuracy of *model* on fresh arithmetic problems.

    Greedily decodes up to 20 new tokens per prompt, extracts the last
    number from the response, and compares it numerically to the ground
    truth (matching the scoring rule used by the reward function).

    Args:
        model: A causal-LM with ``.eval()`` / ``.generate()`` (transformers).
        tokenizer: Matching tokenizer with ``__call__`` and ``decode``.
        n_samples: Number of freshly generated problems to evaluate on.

    Returns:
        Accuracy as a percentage (0.0-100.0).
    """
    print("\n" + "="*70)
    print("π TESTING BASE MODEL PERFORMANCE")
    print("="*70)
    test_samples = generate_arithmetic_samples(n_samples)
    correct = 0
    model.eval()
    with torch.no_grad():
        for i, sample in enumerate(test_samples):
            inputs = tokenizer(sample['prompt'], return_tensors='pt')
            # Move input tensors to the model's device when one is exposed.
            if hasattr(model, 'device') and model.device is not None:
                inputs = {k: v.to(model.device) for k, v in inputs.items()}
            # Greedy decoding. `temperature` is ignored (and triggers a
            # transformers warning) when do_sample=False, so it is omitted.
            outputs = model.generate(
                **inputs,
                max_new_tokens=20,
                do_sample=False,
            )
            # Decode only the newly generated tokens when the prompt length
            # is known; otherwise fall back to decoding the full sequence.
            input_ids = inputs.get('input_ids')
            if input_ids is not None and hasattr(input_ids, 'shape'):
                response = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
            else:
                response = tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Extract the last number as the model's final answer.
            numbers = re.findall(r'-?\d+\.?\d*', response)
            predicted = numbers[-1] if numbers else ""
            truth = sample['answer'].strip()
            # Numeric comparison so "108." / "108.0" count as 108, keeping
            # this evaluation consistent with reward_func's scoring.
            try:
                is_correct = float(predicted) == float(truth)
            except ValueError:
                is_correct = predicted == truth
            if is_correct:
                correct += 1
            status = "β " if is_correct else "β"
            # Show the whole problem (the prompt's last paragraph); the old
            # slicing printed only the second operand.
            problem = sample['prompt'].split("\n\n")[-1]
            print(f"[{i+1}] {status} {problem} -> {truth} | Predicted: {predicted} | Response: {response[:50]}...")
    accuracy = correct / n_samples * 100
    print(f"\nπ Base Model Accuracy: {accuracy:.1f}% ({correct}/{n_samples})")
    if accuracy > 90:
        print("β οΈ WARNING: Base model already performs well! Task may be too easy.")
    elif accuracy < 50:
        print("β Good! Base model performs poorly. Room for improvement!")
    print("="*70 + "\n")
    return accuracy
| # ============================================================================ | |
| # MAIN TRAINING | |
| # ============================================================================ | |
def main():
    """End-to-end pipeline: load model, baseline-test, GRPO-train, push to Hub.

    Steps:
      1. Load the base model/tokenizer (fp16 on GPU, fp32 on CPU).
      2. Evaluate untrained accuracy so the RL improvement is quantifiable.
      3. Generate a synthetic arithmetic dataset and run GRPO with the
         verifiable reward function.
      4. Push the trained model and tokenizer to the Hub.
    """
    print("="*70)
    print("π’ GRPO + RLVR Arithmetic Training")
    print("="*70)
    print(f"Base Model: {BASE_MODEL}")
    print(f"Output: {OUTPUT_MODEL}")
    print(f"Steps: {MAX_STEPS}")
    print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")
    print("="*70 + "\n")
    # Load model and tokenizer
    print("π¦ Loading model and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        # fp16 halves GPU memory; CPU kernels generally require fp32.
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )
    # Test base model first (baseline for comparing the post-training gain).
    baseline_accuracy = test_base_model(model, tokenizer, n_samples=EVAL_SAMPLES)
    # Generate training data
    print("π Generating training data...")
    train_samples = generate_arithmetic_samples(NUM_SAMPLES)
    train_dataset = Dataset.from_list(train_samples)
    print(f"β {len(train_dataset)} training samples\n")
    # GRPO Config — several settings branch on CPU vs GPU availability.
    is_cpu = not torch.cuda.is_available()
    training_args = GRPOConfig(
        output_dir="./outputs",
        max_steps=MAX_STEPS,
        per_device_train_batch_size=2,  # Reduced for CPU
        num_generations=2,  # Completions sampled per prompt; reduced for CPU (faster)
        learning_rate=2e-4,
        beta=0.0,  # No KL penalty against the reference policy for this task
        bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
        fp16=False,
        gradient_checkpointing=not is_cpu,  # Disable on CPU
        # NOTE(review): "adamw_8bit" requires bitsandbytes at runtime — confirm
        # it is installed in the GPU environment.
        optim="adamw_torch" if is_cpu else "adamw_8bit",  # Use standard optimizer on CPU
        logging_steps=1,
        save_steps=MAX_STEPS,  # Save once, at the final step
        push_to_hub=False,  # We'll push manually below
        report_to="none",
    )
    print("π Starting GRPO Training...")
    print(f"Baseline accuracy: {baseline_accuracy:.1f}%\n")
    # Train
    trainer = GRPOTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        reward_funcs=[reward_func],  # Note: plural 'reward_funcs' as list
    )
    trainer.train()
    print("\nβ Training complete!")
    # Save to Hub (requires a logged-in HF token with write access).
    print(f"\nπ¦ Pushing to Hub: {OUTPUT_MODEL}")
    trainer.model.push_to_hub(OUTPUT_MODEL)
    tokenizer.push_to_hub(OUTPUT_MODEL)
    print(f"β Model pushed to: https://huggingface.co/{OUTPUT_MODEL}")
    print("="*70)
# Script entry point: run the full pipeline only when executed directly.
if __name__ == "__main__":
    main()