Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| """ | |
| GRPO + RLVR Training - v5 (Ultimate CPU Optimized + Quantized) | |
| Optimized for HF Spaces CPU with 4-bit quantization | |
| Features: | |
| - 4-bit Quantization (BitsAndBytes) - faster inference | |
| - LoRA Adapters (QLoRA) - efficient training | |
| - Intel Extension for PyTorch (IPEX) - CPU optimization | |
| - torch.compile() JIT compilation | |
| - BetterTransformer (optimized attention) | |
| - LaTeX-aware answer extraction | |
| - All optimizations combined! | |
| """ | |
| import os | |
| import re | |
| import random | |
| import torch | |
| from datasets import Dataset | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| BitsAndBytesConfig, | |
| ) | |
| from peft import LoraConfig, get_peft_model, TaskType | |
| from trl import GRPOConfig, GRPOTrainer | |
| # ============================================================================ | |
| # OPTIMIZATION FLAGS | |
| # ============================================================================ | |
# Feature switches for the optional CPU accelerators. The two import probes
# below run at import time; each one binds its flag according to whether the
# package could be imported and prints a one-line status message.
USE_COMPILE = hasattr(torch, 'compile')
USE_QUANTIZATION = True  # Enable 4-bit quantization

# Intel Extension for PyTorch: `ipex` is only bound when the import succeeds.
try:
    import intel_extension_for_pytorch as ipex
except Exception as exc:
    USE_IPEX = False
    print(f"β οΈ IPEX not available: {exc}")
else:
    USE_IPEX = True
    print("β IPEX available")

# Optimum BetterTransformer: `BetterTransformer` is only bound on success.
try:
    from optimum.bettertransformer import BetterTransformer
except Exception as exc:
    USE_BETTER_TRANSFORMER = False
    print(f"β οΈ BetterTransformer not available: {exc}")
else:
    USE_BETTER_TRANSFORMER = True
    print("β BetterTransformer available")
| # ============================================================================ | |
| # CONFIG | |
| # ============================================================================ | |
# Model id handed to AutoTokenizer / AutoModelForCausalLM.from_pretrained().
BASE_MODEL = "Qwen/Qwen3-0.6B-Base"
# Used by main() both as the save_pretrained() target and as the
# push_to_hub() repository id.
OUTPUT_MODEL = "mindchain/qwen3-0.6b-arithmetic-v5-quantized"
MAX_STEPS = 50  # GRPOConfig max_steps; also reused as save_steps
NUM_SAMPLES = 500  # number of synthetic arithmetic prompts to generate
BATCH_SIZE = 4  # Larger batch with quantization (per_device_train_batch_size)
NUM_GENERATIONS = 4  # More generations (GRPOConfig num_generations)
# LoRA Config
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05
# Quantization Config
# NOTE(review): this is the switch main() actually reads when choosing the
# loading path; USE_QUANTIZATION in the flags section is not read anywhere in
# the visible code.
USE_4BIT = True  # Use 4-bit quantization
| # ============================================================================ | |
| # DATA GENERATION | |
| # ============================================================================ | |
def generate_arithmetic_samples(n_samples, seed=None):
    """Generate simple two-operand addition / subtraction problems.

    Addition operands are both drawn from 10..99. For subtraction the
    minuend is drawn from 20..99 and the subtrahend from 10..minuend-1, so
    the expected answer is always a positive integer.

    Args:
        n_samples: Number of problems to create. Zero or a negative value
            yields an empty list.
        seed: Optional seed for reproducible data. When given, a private
            ``random.Random(seed)`` is used and the global ``random`` state
            is left untouched. When ``None`` (the default) the module-level
            generator is used, exactly as before this parameter existed.

    Returns:
        A list of dicts with two string fields: ``'prompt'``
        (``"Solve: A op B = ?\\nAnswer:"``) and ``'answer'``.
    """
    # The `random` module and a `random.Random` instance expose the same
    # choice()/randint() API, so either can serve as the generator.
    rng = random if seed is None else random.Random(seed)

    samples = []
    for _ in range(n_samples):
        op = rng.choice(['+', '-'])
        if op == '+':
            a = rng.randint(10, 99)
            b = rng.randint(10, 99)
            answer = a + b
        else:
            a = rng.randint(20, 99)
            # Upper bound a-1 keeps the difference strictly positive.
            b = rng.randint(10, a - 1)
            answer = a - b
        samples.append({
            'prompt': f"Solve: {a} {op} {b} = ?\nAnswer:",
            'answer': str(answer),
        })
    return samples
| # ============================================================================ | |
| # REWARD FUNCTION (LaTeX-aware) | |
| # ============================================================================ | |
# A signed integer or decimal. The fractional part requires at least one
# digit after the dot, so a sentence-ending period ("... is 42.") is not
# captured as part of the number.
_NUMBER_PATTERN = r'-?\d+(?:\.\d+)?'
_NUMBER_RE = re.compile(_NUMBER_PATTERN)
_LATEX_BLOCK_RE = re.compile(r'\$\$(.*?)\$\$', re.DOTALL)
_ANSWER_RE = re.compile(r'Answer:\s*(' + _NUMBER_PATTERN + r')', re.IGNORECASE)


def extract_answer(text):
    """Extract the final numeric answer from model output.

    Priority:
    1. Last number inside the last ``$$...$$`` LaTeX block.
    2. Number directly after the first "Answer:" marker (case-insensitive).
    3. Last standalone number anywhere in the text (fallback).

    Args:
        text: Raw completion text. ``None`` or an empty string is accepted.

    Returns:
        The matched number as a string (e.g. ``"42"``, ``"-5"``, ``"3.5"``),
        or ``""`` when no number is found.
    """
    if not text:
        return ""

    # Try LaTeX blocks first; a block without digits falls through to the
    # next strategy instead of ending the search.
    latex_blocks = _LATEX_BLOCK_RE.findall(text)
    if latex_blocks:
        numbers = _NUMBER_RE.findall(latex_blocks[-1])
        if numbers:
            return numbers[-1]

    # Try "Answer:" pattern
    answer_match = _ANSWER_RE.search(text)
    if answer_match:
        return answer_match.group(1)

    # Fallback: last number
    numbers = _NUMBER_RE.findall(text)
    if numbers:
        return numbers[-1]
    return ""
def _answers_match(predicted, truth):
    """Return True when *predicted* equals *truth* textually or numerically."""
    # Function-scope import keeps this block self-contained.
    from decimal import Decimal, InvalidOperation

    expected = str(truth).strip()
    if predicted == expected:
        return True
    # Exact decimal comparison so "43.0", "043" or "43." still match "43"
    # without the rounding a float comparison would introduce.
    try:
        return Decimal(predicted) == Decimal(expected)
    except (InvalidOperation, ValueError):
        # Empty prediction or non-numeric ground truth: not a match.
        return False


def reward_func(completions, prompts=None, **kwargs):
    """Score completions with LaTeX-aware answer extraction.

    Args:
        completions: One entry per generated sample. Each entry is either a
            plain string or a list of messages (dicts with a 'content' key,
            or arbitrary objects that are stringified); list entries are
            joined with single spaces before extraction.
        prompts: Accepted for interface compatibility; not used.
        **kwargs: Must carry the ground truths under the first non-None key
            among 'answer', 'ground_truth', 'solution', 'label'.

    Returns:
        A list of floats, 1.0 where the extracted answer matches the ground
        truth and 0.0 otherwise. When no ground-truth key is present, a list
        of 0.0 with one entry per completion. Pairs are formed with ``zip``,
        so the result is as long as the shorter of the two inputs.
    """
    answers = None
    for key in ['answer', 'ground_truth', 'solution', 'label']:
        if kwargs.get(key) is not None:
            answers = kwargs[key]
            break
    if answers is None:
        return [0.0] * len(completions)

    rewards = []
    for i, (completion, truth) in enumerate(zip(completions, answers)):
        if isinstance(completion, list):
            text = " ".join(
                m.get('content', '') if isinstance(m, dict) else str(m)
                for m in completion
            )
        else:
            text = str(completion)

        predicted = extract_answer(text)
        is_correct = _answers_match(predicted, truth)
        rewards.append(1.0 if is_correct else 0.0)

        # Only the first two samples of each call are echoed, to keep the
        # training log readable.
        if i < 2:
            status = "β " if is_correct else "β"
            print(f" [{i+1}] {status} Truth={truth} | Pred={predicted}")
    return rewards
| # ============================================================================ | |
| # MAIN TRAINING | |
| # ============================================================================ | |
def _load_base_model():
    """Load BASE_MODEL, 4-bit quantized when USE_4BIT is set, otherwise FP32."""
    if not USE_4BIT:
        print("\nπ¦ Loading model in FP32...")
        return AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.float32,
        )

    print("\nπ¦ Loading model with 4-bit quantization...")
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float32,  # CPU uses float32
        bnb_4bit_use_double_quant=True,
    )
    try:
        model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            quantization_config=quantization_config,
            device_map="auto",
        )
        print(" Model loaded in 4-bit!")
    except Exception as e:
        # Deliberate best-effort: any failure of the quantized load falls
        # back to a plain FP32 load.
        print(f" β οΈ 4-bit failed: {e}")
        print(" Falling back to FP32...")
        model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.float32,
        )
    return model


def _apply_cpu_accelerators(model):
    """Apply IPEX and BetterTransformer when available; failures are reported and skipped."""
    if USE_IPEX:
        print("\nπ§ Applying IPEX...")
        try:
            # Note: IPEX with PEFT models may need special handling
            model = ipex.optimize(model, dtype=torch.float32)
            print(" IPEX applied!")
        except Exception as e:
            print(f" β οΈ IPEX failed: {e}")

    if USE_BETTER_TRANSFORMER:
        print("\nπ§ Applying BetterTransformer...")
        try:
            model = BetterTransformer.transform(model)
            print(" BetterTransformer applied!")
        except Exception as e:
            print(f" β οΈ BetterTransformer failed: {e}")
    return model


def _enable_torch_compile(trainer):
    """Wrap trainer.model with torch.compile(), falling back to eager on late compile errors."""
    print("π§ Applying torch.compile()...")
    try:
        # torch.compile() only installs a lazy wrapper: the real compilation
        # runs on the first forward pass inside trainer.train(), which is
        # outside this try block. suppress_errors makes such late failures
        # fall back to eager execution instead of aborting training.
        from torch import _dynamo as dynamo
        dynamo.config.suppress_errors = True
        trainer.model = torch.compile(trainer.model)
        print(" torch.compile() applied!\n")
    except Exception as e:
        print(f" β οΈ torch.compile() failed: {e}\n")


def main():
    """Run the GRPO fine-tuning pipeline end to end.

    In order: print the run configuration and optimization status, set the
    torch thread count, load the tokenizer and base model, attach LoRA
    adapters, apply the optional accelerators, generate the synthetic
    arithmetic dataset, train with GRPOTrainer using reward_func, then save
    the adapters and tokenizer under OUTPUT_MODEL and push both to the Hub.

    Exceptions from training, saving or pushing are not caught here.
    """
    print("="*70)
    print("π GRPO + RLVR v5 - Ultimate CPU Optimized + Quantized")
    print("="*70)
    print(f"Base Model: {BASE_MODEL}")
    print(f"Output: {OUTPUT_MODEL}")
    print(f"Steps: {MAX_STEPS}")
    print("="*70)

    # Print optimization status
    print("\nπ Optimizations:")
    print(f" 4-bit Quantization: {'β ' if USE_4BIT else 'β'}")
    print(f" LoRA Adapters: β (R={LORA_R})")
    print(f" IPEX: {'β ' if USE_IPEX else 'β'}")
    print(f" torch.compile: {'β ' if USE_COMPILE else 'β'}")
    print(f" BetterTransformer: {'β ' if USE_BETTER_TRANSFORMER else 'β'}")
    print("="*70 + "\n")

    # CPU optimization: use every core reported, or 4 when the count is unknown.
    torch.set_num_threads(os.cpu_count() or 4)
    print(f"π CPU Threads: {torch.get_num_threads()}\n")

    # Load tokenizer
    print("π¦ Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = _load_base_model()

    # Add LoRA adapters
    print("\nπ§ Adding LoRA adapters...")
    lora_config = LoraConfig(
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=LORA_DROPOUT,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    model = _apply_cpu_accelerators(model)

    # Generate training data
    print("\nπ Generating training data...")
    train_samples = generate_arithmetic_samples(NUM_SAMPLES)
    train_dataset = Dataset.from_list(train_samples)
    print(f"β {len(train_dataset)} training samples\n")

    # GRPO Config
    training_args = GRPOConfig(
        output_dir="./outputs",
        max_steps=MAX_STEPS,
        per_device_train_batch_size=BATCH_SIZE,
        num_generations=NUM_GENERATIONS,
        learning_rate=2e-4,
        beta=0.0,
        bf16=False,
        fp16=False,
        gradient_checkpointing=False,
        optim="adamw_torch",
        logging_steps=1,
        save_steps=MAX_STEPS,
        push_to_hub=False,
        report_to="none",
        dataloader_num_workers=0,
        dataloader_pin_memory=False,
    )

    print("π Starting GRPO Training...")
    print("="*70 + "\n")

    # Create trainer
    trainer = GRPOTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        reward_funcs=[reward_func],
    )

    # Apply torch.compile
    if USE_COMPILE:
        _enable_torch_compile(trainer)

    # Train
    trainer.train()
    print("\n" + "="*70)
    print("β Training complete!")
    print("="*70)

    # Save LoRA adapters. `model` (not trainer.model) is used on purpose:
    # trainer.model may have been replaced by the torch.compile wrapper.
    print(f"\nπ¦ Saving LoRA adapters to: {OUTPUT_MODEL}")
    model.save_pretrained(OUTPUT_MODEL)
    tokenizer.save_pretrained(OUTPUT_MODEL)

    # Push to Hub
    print(f"\nπ¦ Pushing to Hub: {OUTPUT_MODEL}")
    model.push_to_hub(OUTPUT_MODEL)
    tokenizer.push_to_hub(OUTPUT_MODEL)
    print(f"β Model pushed to: https://huggingface.co/{OUTPUT_MODEL}")
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()