#!/usr/bin/env python3
"""
GRPO + RLVR Training Script v7 - Clean & Simple

Just the essentials:
- 4-bit Quantization (BitsAndBytes)
- LoRA Adapters (QLoRA)
- Standard PyTorch training

No IPEX, no OpenVINO, no torch.compile - just reliable training.
"""

import os
import random
import re

# NOTE: the heavy third-party imports (torch, transformers, peft, trl,
# datasets) are deferred into main().  Only main() uses them, and keeping
# them out of module scope lets the data-generation and reward helpers be
# imported and unit-tested without the full ML stack installed.

# ============================================================================
# CONFIG
# ============================================================================

BASE_MODEL = "Qwen/Qwen3-0.6B-Base"
OUTPUT_MODEL = "mindchain/qwen3-0.6b-arithmetic-v7"
MAX_STEPS = 50
NUM_SAMPLES = 500
BATCH_SIZE = 4
NUM_GENERATIONS = 4

# LoRA Config
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05

# ============================================================================
# DATA GENERATION
# ============================================================================


def generate_arithmetic_samples(n_samples):
    """Generate simple arithmetic problems.

    Each sample is a dict with:
      - "prompt": e.g. "Calculate: 12 + 7 = "
      - "answer": the correct result as a string

    Subtraction operands are chosen with b <= a so answers are never
    negative, matching what extract_number() can parse (unsigned numbers).
    """
    samples = []
    for _ in range(n_samples):
        op = random.choice(['+', '-'])
        if op == '+':
            a = random.randint(1, 50)
            b = random.randint(1, 50)
            answer = a + b
        else:
            a = random.randint(10, 100)
            b = random.randint(1, a)  # guarantees a non-negative result
            answer = a - b
        prompt = f"Calculate: {a} {op} {b} = "
        samples.append({
            "prompt": prompt,
            "answer": str(answer),
        })
    return samples


# ============================================================================
# REWARD FUNCTION
# ============================================================================

# Compiled once at import time: extract_number() runs for every generated
# completion during training, so hoisting the patterns keeps the hot loop
# free of repeated pattern parsing/cache lookups.
_LATEX_NUMBER_RE = re.compile(r'\$\$(\d+(?:\.\d+)?)\$\$')
_ANSWER_NUMBER_RE = re.compile(r'Answer:\s*(\d+(?:\.\d+)?)', re.IGNORECASE)
_ANY_NUMBER_RE = re.compile(r'\d+(?:\.\d+)?')


def extract_number(text):
    """Extract a number from text, handling LaTeX format.

    Returns the matched number as a string, or None if no number is found.
    Search order:
      1. A number inside a $$...$$ LaTeX block
      2. A number following "Answer:" (case-insensitive)
      3. The last number anywhere in the text
    """
    latex_match = _LATEX_NUMBER_RE.search(text)
    if latex_match:
        return latex_match.group(1)

    answer_match = _ANSWER_NUMBER_RE.search(text)
    if answer_match:
        return answer_match.group(1)

    numbers = _ANY_NUMBER_RE.findall(text)
    if numbers:
        return numbers[-1]

    return None


def reward_func(completions, prompts, **kwargs):
    """Reward function for arithmetic tasks.

    Returns a list of floats (1.0 correct / 0.0 otherwise), one per
    completion.  Ground truth is looked up in kwargs under 'ground_truth',
    'answer', or 'solution' (the dataset column is forwarded by GRPOTrainer);
    if none is present every completion gets 0.0.
    """
    ground_truth = kwargs.get(
        'ground_truth', kwargs.get('answer', kwargs.get('solution', None))
    )
    if ground_truth is None:
        return [0.0] * len(completions)

    rewards = []
    for completion, truth in zip(completions, ground_truth):
        # Conversational format: a list of {"role": ..., "content": ...}
        # messages; flatten to a single text blob before extraction.
        if isinstance(completion, list):
            text = " ".join(
                m.get('content', '') if isinstance(m, dict) else str(m)
                for m in completion
            )
        else:
            text = str(completion)

        predicted = extract_number(text)
        if predicted is None:
            rewards.append(0.0)
            continue

        # Compare numerically so that e.g. "7.0" matches truth "7" —
        # plain string equality would wrongly penalize such completions.
        # Fall back to string equality if either side is not parseable.
        try:
            correct = float(predicted) == float(truth)
        except (TypeError, ValueError):
            correct = str(predicted) == str(truth)
        rewards.append(1.0 if correct else 0.0)

    return rewards


# ============================================================================
# MAIN
# ============================================================================


def main():
    """Run the full GRPO pipeline: load, adapt, train, save, push."""
    # Deferred heavy imports — see module-level NOTE.
    import torch
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        BitsAndBytesConfig,
    )
    from peft import LoraConfig, get_peft_model, TaskType
    from trl import GRPOConfig, GRPOTrainer
    from datasets import Dataset

    print("=" * 70)
    print("šŸš€ GRPO + RLVR v7 - Clean & Simple")
    print("=" * 70)
    print(f"Base Model: {BASE_MODEL}")
    print(f"Output: {OUTPUT_MODEL}")
    print(f"Steps: {MAX_STEPS}")
    print("=" * 70)

    # Check CPU threads
    print(f"\nšŸ“Š CPU Threads: {os.cpu_count()}")

    # Show optimizations
    print("\nšŸ“Š Optimizations:")
    print(f" 4-bit Quantization: āœ…")
    print(f" LoRA Adapters: āœ… (R={LORA_R})")
    print(f" IPEX: āŒ (skipped for stability)")
    print(f" OpenVINO: āŒ (skipped for stability)")
    print(f" torch.compile: āŒ (skipped for stability)")
    print("=" * 70)

    # Load tokenizer
    print("\nšŸ“¦ Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    # Qwen base models ship without a pad token; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token

    # Load model with quantization
    print("\nšŸ“¦ Loading model with 4-bit quantization...")
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=quantization_config,
        device_map="auto",
        trust_remote_code=True,
    )

    # Add LoRA adapters (standard QLoRA target set for Qwen-style
    # architectures: all attention + MLP projections).
    print("\nšŸ“¦ Adding LoRA adapters...")
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        bias="none",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # Generate training data
    print(f"\nšŸ“Š Generating {NUM_SAMPLES} training samples...")
    samples = generate_arithmetic_samples(NUM_SAMPLES)

    # Create dataset; "ground_truth" is forwarded to reward_func via kwargs.
    dataset = Dataset.from_list([
        {
            "prompt": s["prompt"],
            "ground_truth": s["answer"],
        }
        for s in samples
    ])

    # GRPO Config
    training_args = GRPOConfig(
        output_dir="./results",
        num_train_epochs=1,
        max_steps=MAX_STEPS,
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=2,
        num_generations=NUM_GENERATIONS,
        learning_rate=5e-5,
        bf16=False,  # CPU doesn't support BF16
        fp16=False,  # 4-bit quantization is enough
        gradient_checkpointing=True,
        optim="paged_adamw_8bit",
        logging_steps=1,
        save_steps=25,
        save_total_limit=2,
        report_to="none",
        # Keep the "ground_truth" column so it reaches reward_func.
        remove_unused_columns=False,
    )

    # Create trainer
    print("\nšŸ“¦ Creating GRPO trainer...")
    trainer = GRPOTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
        reward_funcs=[reward_func],
    )

    # Train
    print("\nšŸš€ Starting GRPO Training...")
    trainer.train()

    # Save LoRA adapters
    print("\nšŸ“¦ Saving LoRA adapters...")
    model.save_pretrained(OUTPUT_MODEL)
    tokenizer.save_pretrained(OUTPUT_MODEL)

    # Push to Hub (requires HF_TOKEN in the environment)
    print(f"\nšŸ“¦ Pushing to Hub: {OUTPUT_MODEL}")
    model.push_to_hub(OUTPUT_MODEL, token=os.environ.get("HF_TOKEN"))
    tokenizer.push_to_hub(OUTPUT_MODEL, token=os.environ.get("HF_TOKEN"))

    print("\nāœ… Training complete!")
    print(f"Output: {OUTPUT_MODEL}")


if __name__ == "__main__":
    main()