rlm-arithmetic-training / train_arithmetic_v5_ultimate.py
mindchain's picture
Upload train_arithmetic_v5_ultimate.py with huggingface_hub
bf5bd77 verified
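#!/usr/bin/env python3
"""
GRPO + RLVR Training - v5 (Ultimate CPU Optimized + Quantized)

Trains LoRA adapters on Qwen3-0.6B-Base with GRPO, using a verifiable
correctness reward on synthetic two-digit addition/subtraction problems.
Targets CPU-only Hugging Face Spaces.

Features (optional ones are enabled only when their dependency is available):
- 4-bit NF4 quantization (BitsAndBytes), falling back to FP32 if loading fails
- LoRA adapters (QLoRA when the 4-bit load succeeds) - efficient training
- Intel Extension for PyTorch (IPEX) - CPU optimization, if installed
- torch.compile() JIT compilation
- BetterTransformer (optimized attention), if optimum is installed
- LaTeX-aware answer extraction for the reward
- All available optimizations are combined in one run
"""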
import os
import re
import random
import torch
from datasets import Dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, TaskType
from trl import GRPOConfig, GRPOTrainer
# ============================================================================
# OPTIMIZATION FLAGS
# ============================================================================
# Feature switches. USE_IPEX / USE_BETTER_TRANSFORMER start False and are only
# flipped on below if the corresponding optional package imports cleanly.
USE_IPEX = False
USE_COMPILE = hasattr(torch, 'compile')
USE_BETTER_TRANSFORMER = False
# Enable 4-bit quantization. main() branches on USE_4BIT (CONFIG section)
# rather than on this flag directly.
USE_QUANTIZATION = True

# Optional accelerators: each import is best-effort, so a missing or broken
# package only disables that feature instead of aborting the script.
try:
    import intel_extension_for_pytorch as ipex
    USE_IPEX = True
    print("✅ IPEX available")
except Exception as e:
    print(f"⚠️ IPEX not available: {e}")

try:
    from optimum.bettertransformer import BetterTransformer
    USE_BETTER_TRANSFORMER = True
    print("✅ BetterTransformer available")
except Exception as e:
    print(f"⚠️ BetterTransformer not available: {e}")
# ============================================================================
# CONFIG
# ============================================================================
BASE_MODEL = "Qwen/Qwen3-0.6B-Base"
# Hub repo id; main() also uses it as the local directory the adapters are saved to.
OUTPUT_MODEL = "mindchain/qwen3-0.6b-arithmetic-v5-quantized"
MAX_STEPS = 50
NUM_SAMPLES = 500  # number of synthetic arithmetic prompts to generate
BATCH_SIZE = 4  # per-device train batch size passed to GRPOConfig
# Completions sampled per prompt for GRPO.
# NOTE(review): TRL expects the effective batch size to be divisible by this
# value - confirm against the installed TRL version when changing either.
NUM_GENERATIONS = 4
# LoRA Config
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05
# Quantization Config
# Derived from the USE_QUANTIZATION switch in the OPTIMIZATION FLAGS section so
# the two flags cannot disagree; main() branches on this one.
USE_4BIT = USE_QUANTIZATION
# ============================================================================
# DATA GENERATION
# ============================================================================
def generate_arithmetic_samples(n_samples, rng=None):
    """Generate two-digit addition/subtraction prompts with their exact answers.

    Args:
        n_samples: Number of samples to produce.
        rng: Optional random source exposing ``choice`` and ``randint``
            (e.g. ``random.Random(seed)``) for a reproducible dataset.
            Defaults to the global ``random`` module, as before.

    Returns:
        A list of dicts, each with a ``'prompt'`` of the form
        "Solve: a op b = ?\\nAnswer:" and the ``'answer'`` as a string.
    """
    rng = random if rng is None else rng
    samples = []
    for _ in range(n_samples):
        op = rng.choice(['+', '-'])
        if op == '+':
            a = rng.randint(10, 99)
            b = rng.randint(10, 99)
            answer = a + b
        else:
            # a >= 20 keeps the range [10, a-1] non-empty, and b < a keeps the
            # difference positive.
            a = rng.randint(20, 99)
            b = rng.randint(10, a - 1)
            answer = a - b
        samples.append({
            'prompt': f"Solve: {a} {op} {b} = ?\nAnswer:",
            'answer': str(answer),
        })
    return samples
# ============================================================================
# REWARD FUNCTION (LaTeX-aware)
# ============================================================================
def extract_answer(text):
    """
    Extract the final numeric answer from model output.

    Priority:
    1. Last number inside the last $$...$$ LaTeX block
    2. Number right after the first "Answer:" (case-insensitive)
    3. Last number anywhere in the text (fallback)

    A number is an optional minus sign, digits, and an optional fractional
    part. A sentence-ending period ("... is 132.") is not part of the number.

    Returns:
        The matched number as a string, or "" if no number is found.
    """
    # The fractional part must contain digits, so a trailing "." is excluded.
    number = r'-?\d+(?:\.\d+)?'
    # Try LaTeX blocks first
    latex_blocks = re.findall(r'\$\$(.*?)\$\$', text, re.DOTALL)
    if latex_blocks:
        numbers = re.findall(number, latex_blocks[-1])
        if numbers:
            return numbers[-1]
    # Try "Answer:" pattern
    answer_match = re.search(r'Answer:\s*(' + number + r')', text, re.IGNORECASE)
    if answer_match:
        return answer_match.group(1)
    # Fallback: last number
    numbers = re.findall(number, text)
    if numbers:
        return numbers[-1]
    return ""
def _answers_match(predicted, truth):
    """Return True if two answer strings are identical or parse to equal numbers."""
    if predicted == truth:
        return True
    try:
        return float(predicted) == float(truth)
    except ValueError:
        # Empty or non-numeric extraction: not a match.
        return False


def reward_func(completions, prompts=None, **kwargs):
    """Binary correctness reward using LaTeX-aware answer extraction.

    Args:
        completions: Generated completions, either plain strings or lists of
            messages whose 'content' fields (for dicts) are joined with spaces.
        prompts: Unused; accepted to match the reward-function call signature.
        **kwargs: Extra dataset columns. Reference answers come from the first
            of 'answer', 'ground_truth', 'solution', 'label' that is present
            and not None.

    Returns:
        One float per (completion, answer) pair. The value is 1.0 when the
        extracted answer equals the reference textually or numerically (so
        "15.0" matches "15"), else 0.0. Returns all zeros when no reference
        column is available.
    """
    answers = next(
        (kwargs[key]
         for key in ('answer', 'ground_truth', 'solution', 'label')
         if kwargs.get(key) is not None),
        None,
    )
    if answers is None:
        return [0.0] * len(completions)
    rewards = []
    for i, (completion, truth) in enumerate(zip(completions, answers)):
        if isinstance(completion, list):
            text = " ".join(
                m.get('content', '') if isinstance(m, dict) else str(m)
                for m in completion
            )
        else:
            text = str(completion)
        predicted = extract_answer(text)
        is_correct = _answers_match(predicted, str(truth).strip())
        rewards.append(1.0 if is_correct else 0.0)
        # Log the first two samples of each call as a quick sanity check.
        if i < 2:
            status = "✅" if is_correct else "❌"
            print(f" [{i+1}] {status} Truth={truth} | Pred={predicted}")
    return rewards
# ============================================================================
# MAIN TRAINING
# ============================================================================
def _load_fp32_model():
    """Load BASE_MODEL unquantized with float32 weights."""
    return AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.float32,
    )


def _load_base_model():
    """Load BASE_MODEL, quantized to 4-bit NF4 when USE_4BIT is set.

    If the 4-bit load raises, the error is printed and the model is loaded in
    FP32 instead, so training can still proceed.
    """
    if not USE_4BIT:
        print("\n📦 Loading model in FP32...")
        return _load_fp32_model()

    print("\n📦 Loading model with 4-bit quantization...")
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float32,  # CPU uses float32
        bnb_4bit_use_double_quant=True,
    )
    try:
        model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            quantization_config=quantization_config,
            device_map="auto",
        )
    except Exception as e:
        print(f" ⚠️ 4-bit failed: {e}")
        print(" Falling back to FP32...")
        return _load_fp32_model()
    print(" Model loaded in 4-bit!")
    return model


def _apply_cpu_optimizations(model):
    """Apply IPEX and BetterTransformer when available; each pass is best-effort.

    If either pass raises, the error is printed and the model from before
    that pass is kept. Returns the (possibly replaced) model.
    """
    if USE_IPEX:
        print("\n🔧 Applying IPEX...")
        try:
            # Note: IPEX with PEFT models may need special handling.
            # NOTE(review): ipex.optimize is documented to take an optimizer
            # for models in training mode; without one this may fail and be skipped.
            model = ipex.optimize(model, dtype=torch.float32)
            print(" IPEX applied!")
        except Exception as e:
            print(f" ⚠️ IPEX failed: {e}")
    if USE_BETTER_TRANSFORMER:
        print("\n🔧 Applying BetterTransformer...")
        try:
            model = BetterTransformer.transform(model)
            print(" BetterTransformer applied!")
        except Exception as e:
            print(f" ⚠️ BetterTransformer failed: {e}")
    return model


def main():
    """Run GRPO training of LoRA adapters on BASE_MODEL, then save and push them.

    Steps:
    1. Print the run configuration.
    2. Load the tokenizer and the (optionally 4-bit) model, then attach LoRA adapters.
    3. Apply the optional CPU optimizations.
    4. Generate the synthetic arithmetic dataset and train with GRPOTrainer.
    5. Save the adapters and tokenizer to OUTPUT_MODEL locally.
    6. Push both to the Hugging Face Hub under the same id.
    """
    print("="*70)
    print("🚀 GRPO + RLVR v5 - Ultimate CPU Optimized + Quantized")
    print("="*70)
    print(f"Base Model: {BASE_MODEL}")
    print(f"Output: {OUTPUT_MODEL}")
    print(f"Steps: {MAX_STEPS}")
    print("="*70)
    # Print optimization status
    print("\n📊 Optimizations:")
    print(f" 4-bit Quantization: {'✅' if USE_4BIT else '❌'}")
    print(f" LoRA Adapters: ✅ (R={LORA_R})")
    print(f" IPEX: {'✅' if USE_IPEX else '❌'}")
    print(f" torch.compile: {'✅' if USE_COMPILE else '❌'}")
    print(f" BetterTransformer: {'✅' if USE_BETTER_TRANSFORMER else '❌'}")
    print("="*70 + "\n")

    # CPU optimization: use every available core for intra-op parallelism.
    torch.set_num_threads(os.cpu_count() or 4)
    print(f"📊 CPU Threads: {torch.get_num_threads()}\n")

    # Load tokenizer
    print("📦 Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = _load_base_model()

    # Add LoRA adapters on the attention projections only.
    print("\n🔧 Adding LoRA adapters...")
    lora_config = LoraConfig(
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=LORA_DROPOUT,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    model = _apply_cpu_optimizations(model)

    # Generate training data
    print("\n📊 Generating training data...")
    train_samples = generate_arithmetic_samples(NUM_SAMPLES)
    train_dataset = Dataset.from_list(train_samples)
    print(f"✅ {len(train_dataset)} training samples\n")

    # GRPO Config
    training_args = GRPOConfig(
        output_dir="./outputs",
        max_steps=MAX_STEPS,
        per_device_train_batch_size=BATCH_SIZE,
        num_generations=NUM_GENERATIONS,
        learning_rate=2e-4,
        beta=0.0,
        bf16=False,
        fp16=False,
        gradient_checkpointing=False,
        optim="adamw_torch",
        logging_steps=1,
        save_steps=MAX_STEPS,
        push_to_hub=False,
        report_to="none",
        dataloader_num_workers=0,
        dataloader_pin_memory=False,
        # Let the Trainer compile the model it actually trains. Reassigning
        # trainer.model after construction bypasses the Trainer's own model
        # wrapping (it trains model_wrapped, captured in __init__).
        torch_compile=USE_COMPILE,
    )
    if USE_COMPILE:
        print("🔧 torch.compile() enabled via GRPOConfig(torch_compile=True)\n")

    print("🚀 Starting GRPO Training...")
    print("="*70 + "\n")

    # Create trainer
    trainer = GRPOTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        reward_funcs=[reward_func],
    )

    # Train
    trainer.train()
    print("\n" + "="*70)
    print("✅ Training complete!")
    print("="*70)

    # Save LoRA adapters (OUTPUT_MODEL doubles as the local directory).
    print(f"\n📦 Saving LoRA adapters to: {OUTPUT_MODEL}")
    model.save_pretrained(OUTPUT_MODEL)
    tokenizer.save_pretrained(OUTPUT_MODEL)

    # Push to Hub
    print(f"\n📦 Pushing to Hub: {OUTPUT_MODEL}")
    model.push_to_hub(OUTPUT_MODEL)
    tokenizer.push_to_hub(OUTPUT_MODEL)
    print(f"✅ Model pushed to: https://huggingface.co/{OUTPUT_MODEL}")


if __name__ == "__main__":
    main()