rlm-arithmetic-training / train_arithmetic_v7_clean.py
mindchain's picture
Upload train_arithmetic_v7_clean.py with huggingface_hub
8b46b16 verified
#!/usr/bin/env python3
"""
GRPO + RLVR Training Script v7 - Clean & Simple
Just the essentials:
- 4-bit Quantization (BitsAndBytes)
- LoRA Adapters (QLoRA)
- Standard PyTorch training
No IPEX, no OpenVINO, no torch.compile - just reliable training.
"""
import os
import random
import re
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, TaskType
from trl import GRPOConfig, GRPOTrainer
from datasets import Dataset
# ============================================================================
# CONFIG
# ============================================================================
# Hugging Face model id that main() loads the tokenizer and weights from.
BASE_MODEL = "Qwen/Qwen3-0.6B-Base"
# Used by main() both as the local save directory and as the Hub repo id.
OUTPUT_MODEL = "mindchain/qwen3-0.6b-arithmetic-v7"
# Passed to GRPOConfig(max_steps=...).
MAX_STEPS = 50
# Number of synthetic arithmetic problems generated for the training set.
NUM_SAMPLES = 500
# Passed to GRPOConfig(per_device_train_batch_size=...).
BATCH_SIZE = 4
# Passed to GRPOConfig(num_generations=...).
NUM_GENERATIONS = 4
# LoRA Config
# Forwarded to LoraConfig as r, lora_alpha and lora_dropout respectively.
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05
# ============================================================================
# DATA GENERATION
# ============================================================================
def generate_arithmetic_samples(n_samples):
    """Build ``n_samples`` random addition/subtraction problems.

    Additions use two operands in 1..50. Subtractions draw the first operand
    from 10..100 and the second from 1..first, so the result is never
    negative.

    Returns a list of dicts with two string fields: ``"prompt"`` (the text
    ``"Calculate: a op b = "``) and ``"answer"`` (the expected result).
    """

    def draw_problem():
        # Draw order (operator, left operand, right operand) matters: it
        # keeps the output reproducible for a given ``random`` seed.
        symbol = random.choice(['+', '-'])
        if symbol == '-':
            left = random.randint(10, 100)
            right = random.randint(1, left)
            result = left - right
        else:
            left = random.randint(1, 50)
            right = random.randint(1, 50)
            result = left + right
        return {
            "prompt": f"Calculate: {left} {symbol} {right} = ",
            "answer": str(result),
        }

    return [draw_problem() for _ in range(n_samples)]
# ============================================================================
# REWARD FUNCTION
# ============================================================================
# Unsigned decimal literal shared by the three patterns below.
_NUMBER = r'\d+(?:\.\d+)?'
# A number that is the sole content of a $$...$$ block (inner spaces allowed).
_LATEX_NUMBER_RE = re.compile(r'\$\$\s*(-?' + _NUMBER + r')\s*\$\$')
# A number that directly follows an "Answer:" label (case-insensitive).
_ANSWER_NUMBER_RE = re.compile(r'Answer:\s*(-?' + _NUMBER + r')', re.IGNORECASE)
# Any number in free text. A leading "-" counts as a sign only when it is not
# glued to a preceding word character or ")", so "10-5" yields "10" and "5"
# (binary minus) while "= -5" yields "-5".
_ANY_NUMBER_RE = re.compile(r'(?:(?<![\w)])-)?' + _NUMBER)


def extract_number(text):
    """Extract the predicted number from ``text`` as a string.

    Candidates are tried in priority order:

    1. a number wrapped in ``$$...$$`` (LaTeX display math),
    2. a number following ``Answer:`` (any letter case),
    3. the last number appearing anywhere in the text.

    Negative numbers keep their sign, so ``"-5"`` is not mistaken for ``"5"``.

    Returns the matched number text (e.g. ``"42"``, ``"-3.5"``), or ``None``
    when the text contains no number.
    """
    # Priority 1: Numbers in $$...$$ blocks (LaTeX)
    latex_match = _LATEX_NUMBER_RE.search(text)
    if latex_match:
        return latex_match.group(1)

    # Priority 2: Numbers after "Answer:"
    answer_match = _ANSWER_NUMBER_RE.search(text)
    if answer_match:
        return answer_match.group(1)

    # Priority 3: Last number in text
    numbers = _ANY_NUMBER_RE.findall(text)
    return numbers[-1] if numbers else None
def _answers_match(predicted, truth):
    """Return True when ``predicted`` and ``truth`` denote the same number.

    Identical text always matches. Otherwise both sides are parsed as
    decimals so that equivalent spellings ("7", "7.0", "07") compare equal;
    text that is not a valid number never matches.
    """
    from decimal import Decimal, InvalidOperation

    predicted_text = str(predicted).strip()
    truth_text = str(truth).strip()
    if predicted_text == truth_text:
        return True
    try:
        return Decimal(predicted_text) == Decimal(truth_text)
    except (InvalidOperation, ValueError):
        return False


def reward_func(completions, prompts, **kwargs):
    """Score each completion 1.0 if its extracted number equals the truth.

    Args:
        completions: One entry per generated sample. Each entry is either
            plain text or a list of messages (dicts with a ``"content"``
            field, or arbitrary objects that are stringified).
        prompts: Accepted for interface compatibility; not used.
        **kwargs: The expected answers are read from the first of
            ``ground_truth``, ``answer`` or ``solution`` that is present and
            not ``None``. It is expected to be a sequence aligned with
            ``completions``.

    Returns:
        A list of floats (1.0 or 0.0) with exactly one entry per completion.
        All zeros when no expected answers were supplied.
    """
    # Get ground truth
    ground_truth = None
    for key in ("ground_truth", "answer", "solution"):
        if kwargs.get(key) is not None:
            ground_truth = kwargs[key]
            break
    if ground_truth is None:
        return [0.0] * len(completions)

    truths = list(ground_truth)
    rewards = []
    for index, completion in enumerate(completions):
        # A completion without a matching truth still gets a (zero) reward so
        # the result always lines up one-to-one with ``completions``.
        truth = truths[index] if index < len(truths) else None
        if truth is None:
            rewards.append(0.0)
            continue

        # Handle list format (conversational)
        if isinstance(completion, list):
            text = " ".join(
                m.get('content', '') if isinstance(m, dict) else str(m)
                for m in completion
            )
        else:
            text = str(completion)

        # Extract predicted number and compare it numerically
        predicted = extract_number(text)
        is_correct = predicted is not None and _answers_match(predicted, truth)
        rewards.append(1.0 if is_correct else 0.0)
    return rewards
# ============================================================================
# MAIN
# ============================================================================
def main():
    """Run the full GRPO fine-tuning job for the arithmetic task.

    In order: print the run banner, load the tokenizer and the 4-bit
    quantized base model, wrap the model with LoRA adapters, build the
    synthetic arithmetic dataset, train with ``GRPOTrainer`` using
    ``reward_func``, then save the adapters and tokenizer under
    ``OUTPUT_MODEL`` and push both to the Hub.

    Reads the ``HF_TOKEN`` environment variable for the Hub upload.
    """
    separator = "=" * 70
    print(separator)
    print("🚀 GRPO + RLVR v7 - Clean & Simple")
    print(separator)
    print(f"Base Model: {BASE_MODEL}")
    print(f"Output: {OUTPUT_MODEL}")
    print(f"Steps: {MAX_STEPS}")
    print(separator)

    # Check CPU threads
    print(f"\n📊 CPU Threads: {os.cpu_count()}")

    # Show optimizations
    print("\n📊 Optimizations:")
    print(" 4-bit Quantization: ✅")
    print(f" LoRA Adapters: ✅ (R={LORA_R})")
    print(" IPEX: ❌ (skipped for stability)")
    print(" OpenVINO: ❌ (skipped for stability)")
    print(" torch.compile: ❌ (skipped for stability)")
    print(separator)

    # Load tokenizer
    print("\n📦 Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    # Reuse the end-of-sequence token for padding.
    tokenizer.pad_token = tokenizer.eos_token

    # Load model with quantization
    # NOTE(review): bitsandbytes 4-bit loading and the paged 8-bit optimizer
    # used below are usually CUDA features, while other comments in this
    # function talk about CPU -- confirm the intended hardware.
    print("\n📦 Loading model with 4-bit quantization...")
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )
    # NOTE(review): trust_remote_code=True allows code shipped in the model
    # repo to run locally -- confirm it is actually needed for BASE_MODEL.
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=quantization_config,
        device_map="auto",
        trust_remote_code=True,
    )

    # Add LoRA adapters
    print("\n📦 Adding LoRA adapters...")
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        bias="none",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # Generate training data
    print(f"\n📊 Generating {NUM_SAMPLES} training samples...")
    samples = generate_arithmetic_samples(NUM_SAMPLES)

    # Create dataset; the "ground_truth" column is the one reward_func looks
    # up first when scoring completions.
    dataset = Dataset.from_list(
        [
            {
                "prompt": sample["prompt"],
                "ground_truth": sample["answer"],
            }
            for sample in samples
        ]
    )

    # GRPO Config
    training_args = GRPOConfig(
        output_dir="./results",
        num_train_epochs=1,
        # A positive max_steps takes precedence over num_train_epochs
        # (see the TrainingArguments documentation).
        max_steps=MAX_STEPS,
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=2,
        num_generations=NUM_GENERATIONS,
        learning_rate=5e-5,
        bf16=False,  # CPU doesn't support BF16
        fp16=False,  # 4-bit quantization is enough
        gradient_checkpointing=True,
        optim="paged_adamw_8bit",
        logging_steps=1,
        save_steps=25,
        save_total_limit=2,
        report_to="none",
        # Keep the extra "ground_truth" column in the dataset.
        remove_unused_columns=False,
    )

    # Create trainer
    print("\n📦 Creating GRPO trainer...")
    trainer = GRPOTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
        reward_funcs=[reward_func],
    )

    # Train
    print("\n🚀 Starting GRPO Training...")
    trainer.train()

    # Save LoRA adapters
    print("\n📦 Saving LoRA adapters...")
    model.save_pretrained(OUTPUT_MODEL)
    tokenizer.save_pretrained(OUTPUT_MODEL)

    # Push to Hub
    print(f"\n📦 Pushing to Hub: {OUTPUT_MODEL}")
    hub_token = os.environ.get("HF_TOKEN")  # None when the variable is unset
    model.push_to_hub(OUTPUT_MODEL, token=hub_token)
    tokenizer.push_to_hub(OUTPUT_MODEL, token=hub_token)

    print("\n✅ Training complete!")
    print(f"Output: {OUTPUT_MODEL}")


if __name__ == "__main__":
    main()