# NOTE(review): the original export contained stray artifact lines here
# ("Spaces:", two "Runtime error" markers) — they were paste residue from
# the export tool, not program code; preserved as this comment only.
| #!/usr/bin/env python3 | |
| """ | |
| GRPO + RLVR Training for Simple Arithmetic - v3 (Minimal) | |
| Task: 2-digit addition and subtraction | |
| Base Model: Qwen/Qwen3-0.6B-Base | |
| Minimal version - no callbacks, no extra features | |
| """ | |
| import os | |
| import re | |
| import random | |
| import torch | |
| from datasets import Dataset | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from trl import GRPOConfig, GRPOTrainer | |
# ============================================================================
# CONFIG
# ============================================================================
# Hugging Face model id of the pretrained base checkpoint to fine-tune.
BASE_MODEL = "Qwen/Qwen3-0.6B-Base"
# Hub repo id the trained model/tokenizer are pushed to after training.
OUTPUT_MODEL = "mindchain/qwen3-0.6b-arithmetic-v3"
# Number of GRPO optimizer steps (kept tiny for a quick smoke run).
MAX_STEPS = 20
# Number of synthetic arithmetic training samples to generate.
NUM_SAMPLES = 500
| # ============================================================================ | |
| # DATA GENERATION | |
| # ============================================================================ | |
def generate_arithmetic_samples(n_samples):
    """Build a list of 2-digit addition/subtraction training samples.

    Each item is a dict with:
      'prompt': "Solve: <a> <op> <b> = ?\\nAnswer:"
      'answer': the result as a string (subtraction operands are ordered
                so the result is always >= 1).
    """
    out = []
    for _ in range(n_samples):
        # Pick the operation first, then draw operands for that branch.
        if random.choice(['+', '-']) == '+':
            x, y = random.randint(10, 99), random.randint(10, 99)
            result = x + y
            question = f"{x} + {y} = ?"
        else:
            x = random.randint(20, 99)
            y = random.randint(10, x - 1)  # y < x keeps the answer positive
            result = x - y
            question = f"{x} - {y} = ?"
        out.append({
            'prompt': f"Solve: {question}\nAnswer:",
            'answer': str(result),
        })
    return out
| # ============================================================================ | |
| # REWARD FUNCTION (Improved) | |
| # ============================================================================ | |
def extract_answer(text):
    """
    Extract the final numeric answer from model output.

    Priority:
      1. Last number inside the last $$...$$ LaTeX block.
      2. Number immediately after an "Answer:" marker (case-insensitive).
      3. Last standalone number anywhere in the text (fallback).

    Returns the number as a string, or "" when no number is found.
    """
    # BUGFIX: the old pattern -?\d+\.?\d* also captured a sentence-ending
    # period ("Answer: 42." -> "42."), which then failed the exact string
    # match in the reward function. (?:\.\d+)? only takes a decimal part
    # when digits actually follow the dot.
    number_pat = r'-?\d+(?:\.\d+)?'

    # 1) Numbers inside $$...$$ blocks; use the last block, last number.
    latex_blocks = re.findall(r'\$\$(.*?)\$\$', text, re.DOTALL)
    if latex_blocks:
        numbers = re.findall(number_pat, latex_blocks[-1])
        if numbers:
            return numbers[-1].strip()

    # 2) Number right after an "Answer:" marker.
    answer_match = re.search(r'Answer:\s*(' + number_pat + r')', text, re.IGNORECASE)
    if answer_match:
        return answer_match.group(1).strip()

    # 3) Fallback: last number anywhere in the text.
    numbers = re.findall(number_pat, text)
    if numbers:
        return numbers[-1].strip()
    return ""
def reward_func(completions, prompts=None, **kwargs):
    """
    Binary exact-match reward for arithmetic completions.

    Looks up the ground-truth column under several candidate names,
    extracts a numeric prediction from each completion via
    extract_answer(), and scores 1.0 on exact string match, else 0.0.
    Returns [0.0] * len(completions) when no ground truth is found.
    """
    # Locate the ground-truth column under any of the supported names.
    truths = next(
        (kwargs[k] for k in ('answer', 'ground_truth', 'solution', 'label')
         if kwargs.get(k) is not None),
        None,
    )
    if truths is None:
        print("β οΈ WARNING: No ground truth found in kwargs!")
        print(f" Available keys: {list(kwargs.keys())}")
        return [0.0] * len(completions)

    rewards = []
    for idx, (completion, truth) in enumerate(zip(completions, truths)):
        # Conversational completions arrive as lists of message dicts;
        # flatten them to one text blob.
        if isinstance(completion, list):
            pieces = [m.get('content', '') if isinstance(m, dict) else str(m)
                      for m in completion]
            text = " ".join(pieces)
        else:
            text = str(completion)

        predicted = extract_answer(text)
        hit = predicted == str(truth).strip()
        rewards.append(1.0 if hit else 0.0)

        # Log the first two samples of each batch for quick inspection.
        if idx < 2:
            status = "β " if hit else "β"
            print(f" [{idx+1}] {status} Truth={truth} | Pred={predicted} | Text={text[:60]}...")
    return rewards
| # ============================================================================ | |
| # MAIN TRAINING | |
| # ============================================================================ | |
def main():
    """Run the full GRPO + RLVR training pipeline end to end.

    Steps: load tokenizer/model, generate synthetic arithmetic data,
    build the GRPO config and trainer, train for MAX_STEPS steps, then
    push the trained model and tokenizer to the Hugging Face Hub.
    """
    # Checked once and reused (the original computed it repeatedly and
    # also left an unused `is_cpu` local behind).
    use_cuda = torch.cuda.is_available()

    print("="*70)
    print("π’ GRPO + RLVR Arithmetic Training - v3 (Minimal)")
    print("="*70)
    print(f"Base Model: {BASE_MODEL}")
    print(f"Output: {OUTPUT_MODEL}")
    print(f"Steps: {MAX_STEPS}")
    print(f"Device: {'cuda' if use_cuda else 'cpu'}")
    print("="*70 + "\n")

    # Load tokenizer first so the pad token can be fixed up before training.
    print("π¦ Loading model and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tokenizer.pad_token is None:
        # GRPO batching needs a pad token; EOS is the usual fallback
        # for causal LMs that ship without one.
        tokenizer.pad_token = tokenizer.eos_token
        print(f" Set pad_token to eos_token: {tokenizer.eos_token}")

    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        # fp16 only on GPU; CPU math stays in fp32 for numerical stability.
        torch_dtype=torch.float16 if use_cuda else torch.float32,
    )
    print(" Model loaded successfully!\n")

    # Synthetic training data: 2-digit addition/subtraction prompts.
    print("π Generating training data...")
    train_samples = generate_arithmetic_samples(NUM_SAMPLES)
    train_dataset = Dataset.from_list(train_samples)
    print(f"β {len(train_dataset)} training samples\n")

    print("π Creating GRPO Config...")
    training_args = GRPOConfig(
        output_dir="./outputs",
        max_steps=MAX_STEPS,
        per_device_train_batch_size=2,
        num_generations=2,        # completions sampled per prompt (GRPO group size)
        learning_rate=2e-4,
        beta=0.0,                 # no KL penalty against a reference model
        bf16=False,  # Always False for CPU safety
        fp16=False,
        gradient_checkpointing=False,
        optim="adamw_torch",
        logging_steps=1,
        save_steps=MAX_STEPS,     # single checkpoint at the very end
        push_to_hub=False,        # pushed manually below instead
        report_to="none",
    )
    print(" GRPO Config created!\n")

    print("π§ Creating GRPO Trainer...")
    trainer = GRPOTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        reward_funcs=[reward_func],
    )
    print(" Trainer created!\n")

    print("π Starting GRPO Training...")
    print("="*70 + "\n")
    trainer.train()
    print("\n" + "="*70)
    print("β Training complete!")
    print("="*70)

    # Publish the trained weights and the tokenizer to the Hub.
    print(f"\nπ¦ Pushing to Hub: {OUTPUT_MODEL}")
    trainer.model.push_to_hub(OUTPUT_MODEL)
    tokenizer.push_to_hub(OUTPUT_MODEL)
    print(f"β Model pushed to: https://huggingface.co/{OUTPUT_MODEL}")
# Run training only when executed as a script (not on import).
if __name__ == "__main__":
    main()