#!/usr/bin/env python3
"""
GRPO + RLVR Training - v5 (CPU Optimized + Quantized)

Trains LoRA adapters on a small causal LM with GRPO, using a verifiable
correctness reward on synthetic arithmetic problems. Targets HF Spaces CPU.

Features (optional backends are probed at import time and applied
best-effort; a backend that fails to import or apply is reported and skipped):
- 4-bit quantization (BitsAndBytes NF4), falling back to FP32 if loading fails
- LoRA adapters (QLoRA-style) - parameter-efficient training
- Intel Extension for PyTorch (IPEX) - CPU optimization
- torch.compile() JIT compilation
- BetterTransformer (optimum)
- LaTeX-aware answer extraction
"""
import os
import re
import random

import torch
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, TaskType
from trl import GRPOConfig, GRPOTrainer

# ============================================================================
# OPTIMIZATION FLAGS
# ============================================================================
# Backend availability flags. IPEX / BetterTransformer start False and are
# flipped to True only if their import below succeeds.
USE_IPEX = False
USE_COMPILE = hasattr(torch, 'compile')
USE_BETTER_TRANSFORMER = False
USE_QUANTIZATION = True  # Enable 4-bit quantization

# Optional imports are deliberately best-effort: any failure (missing package,
# version mismatch, ...) just disables the feature instead of aborting.
try:
    import intel_extension_for_pytorch as ipex
    USE_IPEX = True
    print("✅ IPEX available")
except Exception as e:
    print(f"⚠️ IPEX not available: {e}")

try:
    from optimum.bettertransformer import BetterTransformer
    USE_BETTER_TRANSFORMER = True
    print("✅ BetterTransformer available")
except Exception as e:
    print(f"⚠️ BetterTransformer not available: {e}")

# ============================================================================
# CONFIG
# ============================================================================
BASE_MODEL = "Qwen/Qwen3-0.6B-Base"
OUTPUT_MODEL = "mindchain/qwen3-0.6b-arithmetic-v5-quantized"
MAX_STEPS = 50
NUM_SAMPLES = 500
BATCH_SIZE = 4  # Larger batch with quantization
NUM_GENERATIONS = 4  # More generations

# LoRA Config
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05

# Quantization Config
# Single source of truth: the model-loading code reads USE_4BIT, so it is tied
# to USE_QUANTIZATION above rather than being a second, independent flag that
# could silently disagree with it.
USE_4BIT = USE_QUANTIZATION  # Use 4-bit quantization
# ============================================================================
# DATA GENERATION
# ============================================================================

def generate_arithmetic_samples(n_samples, seed=None):
    """Generate simple two-digit addition/subtraction problems.

    Args:
        n_samples: Number of problems to generate.
        seed: Optional seed. When given, a private ``random.Random(seed)`` is
            used so the dataset is reproducible without disturbing the global
            RNG. When ``None`` (default), the module-level ``random`` is used,
            exactly as before.

    Returns:
        A list of dicts, each with ``'prompt'`` (e.g.
        ``"Solve: 12 + 34 = ?\\nAnswer:"``) and ``'answer'`` (the exact
        result as a decimal string).
    """
    rng = random if seed is None else random.Random(seed)
    samples = []
    for _ in range(n_samples):
        op = rng.choice(['+', '-'])
        if op == '+':
            a = rng.randint(10, 99)
            b = rng.randint(10, 99)
            answer = a + b
            problem = f"{a} + {b} = ?"
        else:
            # a >= 20 keeps the range [10, a-1] non-empty, so the difference
            # is always a positive integer.
            a = rng.randint(20, 99)
            b = rng.randint(10, a - 1)
            answer = a - b
            problem = f"{a} - {b} = ?"
        samples.append({
            'prompt': f"Solve: {problem}\nAnswer:",
            'answer': str(answer),
        })
    return samples


# ============================================================================
# REWARD FUNCTION (LaTeX-aware)
# ============================================================================

# A number is an optional minus sign, digits, and an optional fractional part
# that must contain at least one digit. Requiring a digit after the '.' stops
# a sentence-final period ("... is 42.") from being captured as "42.".
_NUMBER_RE = re.compile(r'-?\d+(?:\.\d+)?')
_LATEX_BLOCK_RE = re.compile(r'\$\$(.*?)\$\$', re.DOTALL)
_ANSWER_RE = re.compile(r'Answer:\s*(-?\d+(?:\.\d+)?)', re.IGNORECASE)


def extract_answer(text):
    """
    Extract the final answer from model output.

    Priority:
    1. Last number inside the last $$...$$ LaTeX block
    2. Number after the first "Answer:" (case-insensitive)
    3. Last standalone number (fallback)

    Returns the matched number as a string, or "" if the text has no number.
    """
    latex_blocks = _LATEX_BLOCK_RE.findall(text)
    if latex_blocks:
        numbers = _NUMBER_RE.findall(latex_blocks[-1])
        if numbers:
            return numbers[-1]

    answer_match = _ANSWER_RE.search(text)
    if answer_match:
        return answer_match.group(1)

    numbers = _NUMBER_RE.findall(text)
    if numbers:
        return numbers[-1]

    return ""


def _completion_text(completion):
    """Flatten a completion (plain string or list of chat messages) into one string."""
    if isinstance(completion, list):
        return " ".join(
            m.get('content', '') if isinstance(m, dict) else str(m)
            for m in completion
        )
    return str(completion)


def _answers_match(predicted, truth):
    """Return True if *predicted* and *truth* are equal strings or the same number.

    Numeric comparison uses Decimal (exact), so "42.0" matches "42" without
    float rounding issues. Non-numeric strings only match on exact equality.
    """
    from decimal import Decimal, InvalidOperation

    predicted = str(predicted).strip()
    truth = str(truth).strip()
    if predicted == truth:
        return True
    try:
        return Decimal(predicted) == Decimal(truth)
    except (InvalidOperation, ValueError):
        return False


def reward_func(completions, prompts=None, **kwargs):
    """Binary correctness reward with LaTeX-aware answer extraction.

    Args:
        completions: Model outputs, either strings or lists of chat-message
            dicts (their ``'content'`` fields are joined).
        prompts: Accepted for trainer compatibility; unused.
        **kwargs: Dataset columns. Ground truth is taken from the first
            non-None of 'answer', 'ground_truth', 'solution', 'label'.

    Returns:
        A list of floats: 1.0 where the extracted answer matches the ground
        truth (string or numeric equality), else 0.0. If no ground-truth
        column is present, every completion gets 0.0.
    """
    answers = None
    for key in ['answer', 'ground_truth', 'solution', 'label']:
        if kwargs.get(key) is not None:
            answers = kwargs[key]
            break

    if answers is None:
        return [0.0] * len(completions)

    rewards = []
    for i, (completion, truth) in enumerate(zip(completions, answers)):
        predicted = extract_answer(_completion_text(completion))
        is_correct = _answers_match(predicted, truth)
        rewards.append(1.0 if is_correct else 0.0)

        # Log only the first couple of samples per call to keep output short.
        if i < 2:
            status = "✅" if is_correct else "❌"
            print(f" [{i+1}] {status} Truth={truth} | Pred={predicted}")

    return rewards


# ============================================================================
# MAIN TRAINING
# ============================================================================

def _print_banner():
    """Print the run configuration and which optional optimizations are enabled."""
    print("="*70)
    print("🚀 GRPO + RLVR v5 - Ultimate CPU Optimized + Quantized")
    print("="*70)
    print(f"Base Model: {BASE_MODEL}")
    print(f"Output: {OUTPUT_MODEL}")
    print(f"Steps: {MAX_STEPS}")
    print("="*70)

    print("\n📊 Optimizations:")
    print(f" 4-bit Quantization: {'✅' if USE_4BIT else '❌'}")
    print(f" LoRA Adapters: ✅ (R={LORA_R})")
    print(f" IPEX: {'✅' if USE_IPEX else '❌'}")
    print(f" torch.compile: {'✅' if USE_COMPILE else '❌'}")
    print(f" BetterTransformer: {'✅' if USE_BETTER_TRANSFORMER else '❌'}")
    print("="*70 + "\n")


def _load_tokenizer():
    """Load the base tokenizer, ensuring a pad token and left padding."""
    print("📦 Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Batched generation with a decoder-only model needs left padding. TRL's
    # GRPOTrainer uses left padding when it loads a tokenizer itself; we pass
    # this one explicitly, so set it here.
    tokenizer.padding_side = "left"
    return tokenizer


def _load_base_model():
    """Load BASE_MODEL in 4-bit NF4 when USE_4BIT is set; otherwise, or if that fails, in FP32."""
    if USE_4BIT:
        print("\n📦 Loading model with 4-bit quantization...")
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float32,  # CPU uses float32
            bnb_4bit_use_double_quant=True,
        )
        try:
            model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL,
                quantization_config=quantization_config,
                device_map="auto",
            )
            print(" Model loaded in 4-bit!")
            return model
        except Exception as e:
            # Deliberate best-effort: 4-bit loading may be unsupported on this
            # machine, so fall through to a plain FP32 load.
            print(f" ⚠️ 4-bit failed: {e}")
            print(" Falling back to FP32...")
    else:
        print("\n📦 Loading model in FP32...")

    return AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.float32,
    )


def _add_lora(model):
    """Wrap *model* with LoRA adapters on the attention projections and return the PEFT model."""
    print("\n🔧 Adding LoRA adapters...")
    lora_config = LoraConfig(
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=LORA_DROPOUT,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    return model


def _apply_cpu_optimizations(model):
    """Best-effort IPEX and BetterTransformer passes; a failing pass is reported and skipped.

    Returns the (possibly replaced) model. Callers must use the return value.
    """
    if USE_IPEX:
        print("\n🔧 Applying IPEX...")
        try:
            # NOTE(review): ipex.optimize() is called without an optimizer, which
            # IPEX documents as the inference path. For a model in training mode
            # IPEX expects an optimizer, so this may just fail and be skipped
            # here. Confirm whether IPEX actually takes effect for training.
            model = ipex.optimize(model, dtype=torch.float32)
            print(" IPEX applied!")
        except Exception as e:
            print(f" ⚠️ IPEX failed: {e}")

    if USE_BETTER_TRANSFORMER:
        print("\n🔧 Applying BetterTransformer...")
        try:
            model = BetterTransformer.transform(model)
            print(" BetterTransformer applied!")
        except Exception as e:
            print(f" ⚠️ BetterTransformer failed: {e}")

    return model


def _build_training_args():
    """Build the GRPOConfig for a CPU-only run (fp32, no workers, no checkpoint pushes)."""
    if USE_COMPILE:
        # Use the Trainer's supported compile hook instead of swapping
        # trainer.model for a compiled wrapper after construction, which
        # bypasses the Trainer's own model bookkeeping.
        print("🔧 torch.compile() enabled via GRPOConfig(torch_compile=True)\n")

    return GRPOConfig(
        output_dir="./outputs",
        max_steps=MAX_STEPS,
        per_device_train_batch_size=BATCH_SIZE,
        num_generations=NUM_GENERATIONS,
        learning_rate=2e-4,
        beta=0.0,
        bf16=False,
        fp16=False,
        gradient_checkpointing=False,
        optim="adamw_torch",
        logging_steps=1,
        save_steps=MAX_STEPS,
        push_to_hub=False,  # the final adapters are pushed explicitly after training
        report_to="none",
        dataloader_num_workers=0,
        dataloader_pin_memory=False,
        torch_compile=USE_COMPILE,
    )


def _save_and_push(model, tokenizer):
    """Save adapters and tokenizer to a local directory named OUTPUT_MODEL, then push them to the Hub."""
    print(f"\n📦 Saving LoRA adapters to: {OUTPUT_MODEL}")
    # OUTPUT_MODEL is a Hub repo id; it is also used here as the local
    # save directory path.
    model.save_pretrained(OUTPUT_MODEL)
    tokenizer.save_pretrained(OUTPUT_MODEL)

    # NOTE(review): pushing requires Hugging Face credentials in the
    # environment (presumably an HF token secret on Spaces) - confirm.
    print(f"\n📦 Pushing to Hub: {OUTPUT_MODEL}")
    model.push_to_hub(OUTPUT_MODEL)
    tokenizer.push_to_hub(OUTPUT_MODEL)
    print(f"✅ Model pushed to: https://huggingface.co/{OUTPUT_MODEL}")


def main():
    """Run the pipeline: load tokenizer and model, add LoRA, train with GRPO, then save and push."""
    _print_banner()

    # NOTE(review): os.cpu_count() may report host CPUs rather than a
    # container's CPU quota; verify on the target Space.
    torch.set_num_threads(os.cpu_count() or 4)
    print(f"📊 CPU Threads: {torch.get_num_threads()}\n")

    tokenizer = _load_tokenizer()
    model = _load_base_model()
    model = _add_lora(model)
    model = _apply_cpu_optimizations(model)

    print("\n📊 Generating training data...")
    train_dataset = Dataset.from_list(generate_arithmetic_samples(NUM_SAMPLES))
    print(f"✅ {len(train_dataset)} training samples\n")

    training_args = _build_training_args()

    print("🚀 Starting GRPO Training...")
    print("="*70 + "\n")

    trainer = GRPOTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        reward_funcs=[reward_func],
        # Pass the configured tokenizer (pad token, left padding) instead of
        # letting the trainer reload one from the model config.
        processing_class=tokenizer,
    )

    trainer.train()

    print("\n" + "="*70)
    print("✅ Training complete!")
    print("="*70)

    # The Trainer updates `model`'s parameters in place, so saving `model`
    # stores the trained adapters.
    _save_and_push(model, tokenizer)


if __name__ == "__main__":
    main()