# Source: mindchain/rlm-training-test — train_rlm.py
# Commit 77cddac (verified): "Add push_to_hub=True to save model"
#!/usr/bin/env python3
"""
GRPO Training for RLM Skills (Recursive Long-Context Reasoning)
Based on RLM Paper - Training model to find needles in haystacks
"""
import os
import re
import json
import random
import string
from datetime import datetime
import torch
from datasets import Dataset
from trl import GRPOTrainer, GRPOConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer
# === CONFIG ===
MODEL_NAME = "Qwen/Qwen3-0.6B-Base"
MAX_STEPS = 20
BATCH_SIZE = 2
GROUP_SIZE = 4
LEARNING_RATE = 2e-4
MAX_COMPLETION_LENGTH = 64 # Short answers for needle finding
# LoRA
LORA_R = 16
LORA_ALPHA = 32
print("="*70)
print("🔍 RLM Training - Needle in Haystack")
print("="*70)
print(f"Model: {MODEL_NAME}")
print(f"Steps: {MAX_STEPS}")
print("Task: Find specific facts in long context")
print("="*70)
# === GENERATE SYNTHETIC NEEDLE-IN-HAYSTACK DATA ===
def generate_needle_haystack(num_samples=100, context_length=2000, seed=42):
"""
Generate synthetic needle-in-haystack data.
Long context with a hidden fact (needle) to find.
"""
random.seed(seed)
needles = [
("The secret code is", "XK7M9"),
("The magic number is", "42"),
("The password is", "quantum2026"),
("The answer is", "17"),
("The key value is", "alpha-beta-7"),
("The hidden word is", "serendipity"),
("The special ID is", "ID-847291"),
("The unique code is", "ZEBRA-99"),
("The mystery number is", "314159"),
("The secret phrase is", "blue moon rising"),
]
samples = []
for i in range(num_samples):
# Pick a random needle
prefix, needle = random.choice(needles)
# Generate haystack (random text)
words = []
for _ in range(context_length // 5): # Approx words
word_len = random.randint(3, 10)
word = ''.join(random.choices(string.ascii_lowercase, k=word_len))
words.append(word)
haystack = ' '.join(words)
# Insert needle at random position
insert_pos = random.randint(len(haystack) // 4, 3 * len(haystack) // 4)
context = haystack[:insert_pos] + f" {prefix} {needle}. " + haystack[insert_pos:]
# Create prompt
prompt = f"""Find the hidden information in this text.
The text contains a secret piece of information. Find it and report ONLY the value, nothing else.
Text:
{context}
What is the hidden value?"""
samples.append({
"prompt": prompt,
"needle": needle,
"needle_prefix": prefix,
})
return samples
print("\n📊 Generating Needle-in-Haystack dataset...")
data = generate_needle_haystack(num_samples=100)
print(f"✅ Generated {len(data)} samples")
# Convert to dataset
dataset_dict = {
"prompt": [s["prompt"] for s in data],
"needle": [s["needle"] for s in data],
"needle_prefix": [s["needle_prefix"] for s in data],
}
dataset = Dataset.from_dict(dataset_dict)
# === REWARD FUNCTION ===
def extract_needle(text, needle_prefixes):
"""Extract the needle value from model output"""
text = text.strip()
# Try to find pattern like "is X" or just the value
for prefix in needle_prefixes:
pattern = rf"{re.escape(prefix)}\s+(\S+)"
match = re.search(pattern, text, re.IGNORECASE)
if match:
return match.group(1).rstrip('.,;:')
# Fallback: last non-common word
words = text.split()
common_words = {'the', 'is', 'a', 'an', 'it', 'to', 'and', 'or', 'in', 'on', 'at'}
for word in reversed(words):
clean = word.strip('.,;:').lower()
if clean not in common_words and len(clean) > 1:
return word.strip('.,;:')
return text.split()[-1].strip('.,;:') if text else ""
def reward_func(completions, prompts, **kwargs):
"""
RLM Reward Function - rewards finding the correct needle
"""
needles = kwargs.get('needle', None)
if needles is None:
print("[WARN] No needle found in kwargs")
return [0.0] * len(completions)
# All possible needle prefixes
needle_prefixes = [
"The secret code is", "The magic number is", "The password is",
"The answer is", "The key value is", "The hidden word is",
"The special ID is", "The unique code is", "The mystery number is",
"The secret phrase is"
]
rewards = []
for i, (completion, truth) in enumerate(zip(completions, needles)):
# Handle list format
if isinstance(completion, list):
text = " ".join([m.get('content', '') if isinstance(m, dict) else str(m) for m in completion])
else:
text = str(completion)
# Extract prediction
pred = extract_needle(text, needle_prefixes)
# Ground truth
if isinstance(truth, list):
truth_val = truth[0] if truth else ""
else:
truth_val = str(truth)
# Compare (case-insensitive, strip whitespace)
is_correct = pred.lower().strip() == truth_val.lower().strip()
reward = 1.0 if is_correct else 0.0
rewards.append(reward)
if i < 2: # Debug first 2
print(f" [{i}] pred='{pred}' truth='{truth_val}' r={reward}")
print(f"[REWARD] Avg: {sum(rewards)/len(rewards):.2f} | {rewards[:4]}")
return rewards
# === LOAD MODEL ===
print(f"\n📦 Loading model: {MODEL_NAME}")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
if device == "cuda":
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
quantization_config=quantization_config,
device_map="auto",
trust_remote_code=True,
)
else:
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
print("✅ Model loaded")
# === PREPARE FOR TRAINING ===
if device == "cuda":
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(
r=LORA_R,
lora_alpha=LORA_ALPHA,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# === GRPO CONFIG ===
training_args = GRPOConfig(
output_dir="./output",
max_steps=MAX_STEPS,
per_device_train_batch_size=BATCH_SIZE,
gradient_accumulation_steps=2,
num_generations=GROUP_SIZE,
max_completion_length=MAX_COMPLETION_LENGTH,
learning_rate=LEARNING_RATE,
beta=0.0, # No KL for this task
bf16=device == "cuda",
fp16=False,
gradient_checkpointing=True,
optim="adamw_torch_fused" if device == "cuda" else "adamw_torch",
logging_steps=1,
save_steps=10,
dataloader_num_workers=0,
remove_unused_columns=False,
# PUSH TO HUB!
push_to_hub=True,
hub_model_id="mindchain/qwen3-0.6b-rlm-needle",
)
# === TRAINING LOG ===
training_log = {
"model": MODEL_NAME,
"task": "needle_in_haystack",
"steps": MAX_STEPS,
"start_time": datetime.now().isoformat(),
"rewards": [],
"losses": [],
}
# === TRAIN ===
print(f"\n🚀 Starting RLM training for {MAX_STEPS} steps...")
print("-"*70)
trainer = GRPOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
reward_funcs=reward_func,
processing_class=tokenizer,
)
try:
trainer.train()
training_log["end_time"] = datetime.now().isoformat()
training_log["status"] = "completed"
with open("training_log.json", "w") as f:
json.dump(training_log, f, indent=2)
print("\n" + "="*70)
print("✅ RLM Training Complete!")
print("="*70)
print(f"Log saved to: training_log.json")
except Exception as e:
print(f"\n❌ Training failed: {e}")
training_log["status"] = "failed"
training_log["error"] = str(e)
with open("training_log.json", "w") as f:
json.dump(training_log, f, indent=2)
raise