#!/usr/bin/env python3
"""
GRPO Training for RLM Skills (Recursive Long-Context Reasoning)
Based on RLM Paper - Training model to find needles in haystacks
"""
import os
import re
import json
import random
import string
from datetime import datetime

import torch
from datasets import Dataset
from trl import GRPOTrainer, GRPOConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer

# === CONFIG ===
MODEL_NAME = "Qwen/Qwen3-0.6B-Base"
MAX_STEPS = 20
BATCH_SIZE = 2
GROUP_SIZE = 4
LEARNING_RATE = 2e-4
MAX_COMPLETION_LENGTH = 64  # Short answers for needle finding

# LoRA
LORA_R = 16
LORA_ALPHA = 32

# (prefix, value) pairs used both by the data generator and by the reward
# function. Keeping them in ONE table guarantees the reward's prefix list
# can never drift out of sync with the generated data.
# NOTE: some values are multi-word ("blue moon rising") — the extractor
# and reward below must handle more than a single token.
NEEDLES = [
    ("The secret code is", "XK7M9"),
    ("The magic number is", "42"),
    ("The password is", "quantum2026"),
    ("The answer is", "17"),
    ("The key value is", "alpha-beta-7"),
    ("The hidden word is", "serendipity"),
    ("The special ID is", "ID-847291"),
    ("The unique code is", "ZEBRA-99"),
    ("The mystery number is", "314159"),
    ("The secret phrase is", "blue moon rising"),
]

print("="*70)
print("šŸ” RLM Training - Needle in Haystack")
print("="*70)
print(f"Model: {MODEL_NAME}")
print(f"Steps: {MAX_STEPS}")
print("Task: Find specific facts in long context")
print("="*70)


# === GENERATE SYNTHETIC NEEDLE-IN-HAYSTACK DATA ===
def generate_needle_haystack(num_samples=100, context_length=2000, seed=42):
    """
    Generate synthetic needle-in-haystack data.
    Long context with a hidden fact (needle) to find.

    Args:
        num_samples: number of prompt/needle samples to create.
        context_length: approximate character length of the haystack.
        seed: RNG seed so the dataset is reproducible across runs.

    Returns:
        List of dicts with keys "prompt", "needle", "needle_prefix".
    """
    random.seed(seed)

    samples = []
    for _ in range(num_samples):
        # Pick a random needle
        prefix, needle = random.choice(NEEDLES)

        # Generate haystack (random lowercase gibberish words)
        words = []
        for _ in range(context_length // 5):  # Approx words
            word_len = random.randint(3, 10)
            word = ''.join(random.choices(string.ascii_lowercase, k=word_len))
            words.append(word)
        haystack = ' '.join(words)

        # Insert needle at a random position within the middle half of
        # the text, so it is never trivially at the start or end.
        insert_pos = random.randint(len(haystack) // 4, 3 * len(haystack) // 4)
        context = haystack[:insert_pos] + f" {prefix} {needle}. " + haystack[insert_pos:]

        # Create prompt
        prompt = f"""Find the hidden information in this text. The text contains a secret piece of information. Find it and report ONLY the value, nothing else.

Text: {context}

What is the hidden value?"""
        samples.append({
            "prompt": prompt,
            "needle": needle,
            "needle_prefix": prefix,
        })
    return samples


print("\nšŸ“Š Generating Needle-in-Haystack dataset...")
data = generate_needle_haystack(num_samples=100)
print(f"āœ… Generated {len(data)} samples")

# Convert to dataset (extra columns are forwarded to the reward function
# by GRPOTrainer via **kwargs, hence remove_unused_columns=False below)
dataset_dict = {
    "prompt": [s["prompt"] for s in data],
    "needle": [s["needle"] for s in data],
    "needle_prefix": [s["needle_prefix"] for s in data],
}
dataset = Dataset.from_dict(dataset_dict)


# === REWARD FUNCTION ===
def extract_needle(text, needle_prefixes):
    """Extract the needle value from model output.

    Tries "<prefix> <value>" patterns first. The value may span several
    words (e.g. "blue moon rising"), so capture up to the end of the
    sentence/line instead of a single token. Falls back to the last
    non-trivial word, and to "" for empty/whitespace-only text.
    """
    text = text.strip()
    # Try to find pattern like "<prefix> X" — lazy capture up to the
    # first period/newline (or end of string) so multi-word needles are
    # not truncated to their first token.
    for prefix in needle_prefixes:
        pattern = rf"{re.escape(prefix)}\s+(.+?)\s*(?:[.\n]|$)"
        match = re.search(pattern, text, re.IGNORECASE)
        if match:
            return match.group(1).rstrip('.,;:')
    # Fallback: last non-common word
    words = text.split()
    common_words = {'the', 'is', 'a', 'an', 'it', 'to', 'and', 'or', 'in', 'on', 'at'}
    for word in reversed(words):
        clean = word.strip('.,;:').lower()
        if clean not in common_words and len(clean) > 1:
            return word.strip('.,;:')
    return text.split()[-1].strip('.,;:') if text else ""


def reward_func(completions, prompts, **kwargs):
    """
    RLM Reward Function - rewards finding the correct needle.

    Returns one reward per completion: 1.0 when the completion reports
    the ground-truth needle value, else 0.0. A completion also scores
    1.0 when the full needle value appears verbatim in its text — this
    is required for multi-word needles ("blue moon rising"), which a
    single-word extraction fallback can never reproduce exactly.
    """
    # GRPOTrainer forwards extra dataset columns (here: "needle") as
    # lists through **kwargs, aligned with `completions`.
    needles = kwargs.get('needle', None)
    if needles is None:
        print("[WARN] No needle found in kwargs")
        return [0.0] * len(completions)

    # Derived from the single NEEDLES table — always in sync with the data.
    needle_prefixes = [prefix for prefix, _ in NEEDLES]

    rewards = []
    for i, (completion, truth) in enumerate(zip(completions, needles)):
        # Handle conversational (list-of-messages) completion format
        if isinstance(completion, list):
            text = " ".join([m.get('content', '') if isinstance(m, dict) else str(m) for m in completion])
        else:
            text = str(completion)

        # Extract prediction
        pred = extract_needle(text, needle_prefixes)

        # Ground truth (some collators wrap scalars in a list)
        if isinstance(truth, list):
            truth_val = truth[0] if truth else ""
        else:
            truth_val = str(truth)

        # Compare (case-insensitive, strip whitespace). Additionally
        # accept a verbatim occurrence of the needle in the completion so
        # multi-word values can be rewarded.
        pred_norm = pred.lower().strip()
        truth_norm = truth_val.lower().strip()
        is_correct = pred_norm == truth_norm or (bool(truth_norm) and truth_norm in text.lower())
        reward = 1.0 if is_correct else 0.0
        rewards.append(reward)

        if i < 2:  # Debug first 2
            print(f" [{i}] pred='{pred}' truth='{truth_val}' r={reward}")

    print(f"[REWARD] Avg: {sum(rewards)/len(rewards):.2f} | {rewards[:4]}")
    return rewards


# === LOAD MODEL ===
print(f"\nšŸ“¦ Loading model: {MODEL_NAME}")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

if device == "cuda":
    # QLoRA-style 4-bit quantization; full precision on CPU fallback.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=quantization_config,
        device_map="auto",
        trust_remote_code=True,
    )
else:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
    )

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("āœ… Model loaded")

# === PREPARE FOR TRAINING ===
if device == "cuda":
    # Required before attaching LoRA adapters to a k-bit quantized model
    model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # attention only
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# === GRPO CONFIG ===
training_args = GRPOConfig(
    output_dir="./output",
    max_steps=MAX_STEPS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=2,
    num_generations=GROUP_SIZE,
    max_completion_length=MAX_COMPLETION_LENGTH,
    learning_rate=LEARNING_RATE,
    beta=0.0,  # No KL for this task
    bf16=device == "cuda",
    fp16=False,
    gradient_checkpointing=True,
    optim="adamw_torch_fused" if device == "cuda" else "adamw_torch",
    logging_steps=1,
    save_steps=10,
    dataloader_num_workers=0,
    # Keep the "needle" column so reward_func receives it via **kwargs
    remove_unused_columns=False,
    # PUSH TO HUB!
    push_to_hub=True,
    hub_model_id="mindchain/qwen3-0.6b-rlm-needle",
)

# === TRAINING LOG ===
training_log = {
    "model": MODEL_NAME,
    "task": "needle_in_haystack",
    "steps": MAX_STEPS,
    "start_time": datetime.now().isoformat(),
    "rewards": [],
    "losses": [],
}

# === TRAIN ===
print(f"\nšŸš€ Starting RLM training for {MAX_STEPS} steps...")
print("-"*70)

trainer = GRPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    reward_funcs=reward_func,
    processing_class=tokenizer,
)

try:
    trainer.train()

    training_log["end_time"] = datetime.now().isoformat()
    training_log["status"] = "completed"
    with open("training_log.json", "w") as f:
        json.dump(training_log, f, indent=2)

    print("\n" + "="*70)
    print("āœ… RLM Training Complete!")
    print("="*70)
    print("Log saved to: training_log.json")
except Exception as e:
    # Record the failure in the log, then re-raise so the run exits nonzero.
    print(f"\nāŒ Training failed: {e}")
    training_log["status"] = "failed"
    training_log["error"] = str(e)
    with open("training_log.json", "w") as f:
        json.dump(training_log, f, indent=2)
    raise