Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| """ | |
| GRPO + RLVR Training Script v7 - Clean & Simple | |
| Just the essentials: | |
| - 4-bit Quantization (BitsAndBytes) | |
| - LoRA Adapters (QLoRA) | |
| - Standard PyTorch training | |
| No IPEX, no OpenVINO, no torch.compile - just reliable training. | |
| """ | |
| import os | |
| import random | |
| import re | |
| import torch | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| BitsAndBytesConfig, | |
| ) | |
| from peft import LoraConfig, get_peft_model, TaskType | |
| from trl import GRPOConfig, GRPOTrainer | |
| from datasets import Dataset | |
# ============================================================================
# CONFIG
# ============================================================================
# Hub checkpoint to fine-tune, and the repo id / local directory the trained
# LoRA adapters and tokenizer are written to.
BASE_MODEL: str = "Qwen/Qwen3-0.6B-Base"
OUTPUT_MODEL: str = "mindchain/qwen3-0.6b-arithmetic-v7"

# Training schedule and data size.
MAX_STEPS: int = 50        # passed to GRPOConfig(max_steps=...)
NUM_SAMPLES: int = 500     # number of synthetic arithmetic problems
BATCH_SIZE: int = 4        # passed to GRPOConfig(per_device_train_batch_size=...)
NUM_GENERATIONS: int = 4   # passed to GRPOConfig(num_generations=...)

# LoRA adapter hyperparameters (rank, scaling alpha, dropout).
LORA_R: int = 16
LORA_ALPHA: int = 32
LORA_DROPOUT: float = 0.05
# ============================================================================
# DATA GENERATION
# ============================================================================
def generate_arithmetic_samples(n_samples):
    """Build ``n_samples`` random addition/subtraction problems.

    Addition draws both operands from 1..50. Subtraction draws the first
    operand from 10..100 and the second from 1..first, so its result is
    never negative. Randomness comes from the global ``random`` module.

    Args:
        n_samples: How many problems to produce (non-positive gives ``[]``).

    Returns:
        A list of dicts ``{"prompt": "Calculate: a op b = ", "answer": str}``.
    """
    def _one_problem():
        # Draw order (operator, then left, then right) matters for seeded runs.
        operator = random.choice(['+', '-'])
        if operator == '-':
            left = random.randint(10, 100)
            right = random.randint(1, left)
            result = left - right
        else:
            left = random.randint(1, 50)
            right = random.randint(1, 50)
            result = left + right
        return {
            "prompt": f"Calculate: {left} {operator} {right} = ",
            "answer": str(result),
        }

    return [_one_problem() for _ in range(n_samples)]
# ============================================================================
# REWARD FUNCTION
# ============================================================================
# An optionally negative integer or decimal. The sign is kept so that a
# completion of "-5" is not scored as the answer "5".
_NUMBER = r'-?\d+(?:\.\d+)?'
_LATEX_NUMBER_RE = re.compile(r'\$\$\s*(' + _NUMBER + r')\s*\$\$')
_ANSWER_NUMBER_RE = re.compile(r'Answer:\s*(' + _NUMBER + r')', re.IGNORECASE)
# The look-behind stops a hyphen directly after a digit (the operator in
# "20-5") from being read as a minus sign.
_ANY_NUMBER_RE = re.compile(r'(?<!\d)' + _NUMBER)


def extract_number(text):
    """Extract the most likely answer number from model output.

    Candidates are tried in priority order:
      1. the number inside a ``$$...$$`` LaTeX block (whitespace allowed),
      2. the number following ``Answer:`` (case-insensitive),
      3. the last number anywhere in the text.

    Args:
        text: The completion text to search.

    Returns:
        The matched number as a string (sign and decimal part preserved,
        e.g. ``"-3"`` or ``"3.25"``), or ``None`` if the text has no digits.
    """
    latex_match = _LATEX_NUMBER_RE.search(text)
    if latex_match:
        return latex_match.group(1)
    answer_match = _ANSWER_NUMBER_RE.search(text)
    if answer_match:
        return answer_match.group(1)
    numbers = _ANY_NUMBER_RE.findall(text)
    if numbers:
        return numbers[-1]
    return None
def reward_func(completions, prompts, **kwargs):
    """Binary verifiable reward for the arithmetic task.

    TRL's GRPOTrainer calls this with the generated ``completions`` and
    forwards extra dataset columns as keyword arguments. The reference
    answers are taken from the first of ``ground_truth``, ``answer`` or
    ``solution`` that is present.

    Args:
        completions: Generated completions, either plain strings or (for
            conversational data) lists of message dicts with a ``content``
            key.
        prompts: Prompts that produced the completions (unused).
        **kwargs: Extra per-completion columns; one should hold the
            reference answers.

    Returns:
        One float per completion: 1.0 when the number extracted from the
        completion equals the reference answer numerically (so "42.0"
        matches "42"), else 0.0. If no reference column is present, every
        completion scores 0.0.
    """
    ground_truth = kwargs.get('ground_truth', kwargs.get('answer', kwargs.get('solution', None)))
    if ground_truth is None:
        return [0.0] * len(completions)

    rewards = []
    for completion, truth in zip(completions, ground_truth):
        # Conversational format: concatenate the text of every message.
        # `or ''` guards against a message whose content is None.
        if isinstance(completion, list):
            text = " ".join(
                (m.get('content') or '') if isinstance(m, dict) else str(m)
                for m in completion
            )
        else:
            text = str(completion)

        predicted = extract_number(text)
        if predicted is not None and _numbers_match(predicted, truth):
            rewards.append(1.0)
        else:
            rewards.append(0.0)
    return rewards


def _numbers_match(predicted, truth):
    """Return True when ``predicted`` and ``truth`` denote the same number.

    Whitespace-stripped string equality is checked first. Otherwise both
    values are parsed with ``float``, so e.g. "42.0" matches "42". Values
    that do not parse as numbers match only by string equality.
    """
    predicted_text = str(predicted).strip()
    truth_text = str(truth).strip()
    if predicted_text == truth_text:
        return True
    try:
        return float(predicted_text) == float(truth_text)
    except ValueError:
        return False
# ============================================================================
# MAIN
# ============================================================================
def main():
    """Fine-tune BASE_MODEL with GRPO on synthetic arithmetic and publish it.

    Steps:
      1. Load the tokenizer and model, and attach LoRA adapters.
      2. Build a dataset of NUM_SAMPLES prompts with a ``ground_truth``
         column.
      3. Train for MAX_STEPS GRPO steps scored by ``reward_func``.
      4. Save the adapters and tokenizer to the local directory
         OUTPUT_MODEL, then push both to the Hub under the same id.

    4-bit BitsAndBytes quantization and the ``paged_adamw_8bit`` optimizer
    are used only when a CUDA GPU is available. Both depend on CUDA in
    common bitsandbytes/transformers versions, and requesting them on a
    CPU-only machine fails. Without a GPU the model is loaded in float32 and
    trained with ``adamw_torch``.
    """
    print("=" * 70)
    print("π GRPO + RLVR v7 - Clean & Simple")
    print("=" * 70)
    print(f"Base Model: {BASE_MODEL}")
    print(f"Output: {OUTPUT_MODEL}")
    print(f"Steps: {MAX_STEPS}")
    print("=" * 70)

    # Decide once whether the CUDA-only bitsandbytes features can be used.
    use_4bit = torch.cuda.is_available()

    # Check CPU threads
    print(f"\nπ CPU Threads: {os.cpu_count()}")

    # Show optimizations
    print("\nπ Optimizations:")
    if use_4bit:
        print(" 4-bit Quantization: β ")
    else:
        print(" 4-bit Quantization: β (no CUDA GPU, using float32)")
    print(f" LoRA Adapters: β (R={LORA_R})")
    print(" IPEX: β (skipped for stability)")
    print(" OpenVINO: β (skipped for stability)")
    print(" torch.compile: β (skipped for stability)")
    print("=" * 70)

    # Load tokenizer
    print("\nπ¦ Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    tokenizer.pad_token = tokenizer.eos_token

    if use_4bit:
        # Load model with quantization (QLoRA path, GPU only).
        print("\nπ¦ Loading model with 4-bit quantization...")
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            quantization_config=quantization_config,
            device_map="auto",
            trust_remote_code=True,
        )
        # Standard PEFT QLoRA preparation before adding adapters
        # (upcasts non-quantized params, enables input grads for checkpointing).
        from peft import prepare_model_for_kbit_training
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
    else:
        # CPU fallback: plain full-precision weights, LoRA still applied below.
        print("\nπ¦ Loading model (float32, no quantization)...")
        model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.float32,
            trust_remote_code=True,
        )

    # Add LoRA adapters
    print("\nπ¦ Adding LoRA adapters...")
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        bias="none",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # Generate training data
    print(f"\nπ Generating {NUM_SAMPLES} training samples...")
    samples = generate_arithmetic_samples(NUM_SAMPLES)

    # Create dataset; "ground_truth" is the column reward_func reads.
    dataset = Dataset.from_list([
        {
            "prompt": s["prompt"],
            "ground_truth": s["answer"],
        }
        for s in samples
    ])

    # GRPO Config
    training_args = GRPOConfig(
        output_dir="./results",
        num_train_epochs=1,
        max_steps=MAX_STEPS,
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=2,
        num_generations=NUM_GENERATIONS,
        learning_rate=5e-5,
        bf16=False,  # CPU doesn't support BF16
        fp16=False,  # no fp16 mixed precision
        gradient_checkpointing=True,
        # The paged 8-bit optimizer comes from bitsandbytes and needs CUDA.
        optim="paged_adamw_8bit" if use_4bit else "adamw_torch",
        logging_steps=1,
        save_steps=25,
        save_total_limit=2,
        report_to="none",
        # Presumably set so the extra "ground_truth" column is not dropped
        # before it reaches reward_func.
        remove_unused_columns=False,
    )

    # Create trainer
    print("\nπ¦ Creating GRPO trainer...")
    trainer = GRPOTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
        reward_funcs=[reward_func],
    )

    # Train
    print("\nπ Starting GRPO Training...")
    trainer.train()

    # Save LoRA adapters
    print("\nπ¦ Saving LoRA adapters...")
    model.save_pretrained(OUTPUT_MODEL)
    tokenizer.save_pretrained(OUTPUT_MODEL)

    # Push to Hub
    print(f"\nπ¦ Pushing to Hub: {OUTPUT_MODEL}")
    model.push_to_hub(OUTPUT_MODEL, token=os.environ.get("HF_TOKEN"))
    tokenizer.push_to_hub(OUTPUT_MODEL, token=os.environ.get("HF_TOKEN"))

    print("\nβ Training complete!")
    print(f"Output: {OUTPUT_MODEL}")


if __name__ == "__main__":
    main()