# NOTE(review): removed stray extraction artifacts ("Spaces:" / "Runtime error" x2)
# that preceded the shebang and would break parsing.
| #!/usr/bin/env python3 | |
| """ | |
| GRPO Training for RLM Skills (Recursive Long-Context Reasoning) | |
| Based on RLM Paper - Training model to find needles in haystacks | |
| """ | |
| import os | |
| import re | |
| import json | |
| import random | |
| import string | |
| from datetime import datetime | |
| import torch | |
| from datasets import Dataset | |
| from trl import GRPOTrainer, GRPOConfig | |
| from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training | |
| from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer | |
# === CONFIG ===
MODEL_NAME = "Qwen/Qwen3-0.6B-Base"
MAX_STEPS = 20
BATCH_SIZE = 2
GROUP_SIZE = 4
LEARNING_RATE = 2e-4
MAX_COMPLETION_LENGTH = 64  # Short answers for needle finding
# LoRA
LORA_R = 16
LORA_ALPHA = 32

# Startup banner summarizing the run configuration.
_RULE = "=" * 70
print(_RULE)
print("🔍 RLM Training - Needle in Haystack")
print(_RULE)
print(f"Model: {MODEL_NAME}")
print(f"Steps: {MAX_STEPS}")
print("Task: Find specific facts in long context")
print(_RULE)
# === GENERATE SYNTHETIC NEEDLE-IN-HAYSTACK DATA ===
def generate_needle_haystack(num_samples=100, context_length=2000, seed=42):
    """Build a synthetic needle-in-a-haystack dataset.

    Each sample is a blob of random lowercase "words" with a single factual
    sentence (the needle) spliced somewhere in the middle half, plus a prompt
    asking the model to report only the hidden value.

    Returns a list of dicts with keys: prompt, needle, needle_prefix.
    """
    random.seed(seed)
    needles = [
        ("The secret code is", "XK7M9"),
        ("The magic number is", "42"),
        ("The password is", "quantum2026"),
        ("The answer is", "17"),
        ("The key value is", "alpha-beta-7"),
        ("The hidden word is", "serendipity"),
        ("The special ID is", "ID-847291"),
        ("The unique code is", "ZEBRA-99"),
        ("The mystery number is", "314159"),
        ("The secret phrase is", "blue moon rising"),
    ]
    samples = []
    for _ in range(num_samples):
        prefix, needle = random.choice(needles)
        # Gibberish filler: ~context_length chars of random 3-10 letter words.
        # (randint is evaluated before choices, matching the seeded RNG order.)
        filler = ' '.join(
            ''.join(random.choices(string.ascii_lowercase, k=random.randint(3, 10)))
            for _ in range(context_length // 5)
        )
        # Splice the needle sentence somewhere in the middle half of the text.
        cut = random.randint(len(filler) // 4, 3 * len(filler) // 4)
        context = filler[:cut] + f" {prefix} {needle}. " + filler[cut:]
        prompt = f"""Find the hidden information in this text.
The text contains a secret piece of information. Find it and report ONLY the value, nothing else.
Text:
{context}
What is the hidden value?"""
        samples.append({
            "prompt": prompt,
            "needle": needle,
            "needle_prefix": prefix,
        })
    return samples
| print("\n📊 Generating Needle-in-Haystack dataset...") | |
| data = generate_needle_haystack(num_samples=100) | |
| print(f"✅ Generated {len(data)} samples") | |
| # Convert to dataset | |
| dataset_dict = { | |
| "prompt": [s["prompt"] for s in data], | |
| "needle": [s["needle"] for s in data], | |
| "needle_prefix": [s["needle_prefix"] for s in data], | |
| } | |
| dataset = Dataset.from_dict(dataset_dict) | |
# === REWARD FUNCTION ===
def extract_needle(text, needle_prefixes):
    """Extract the needle value from a model completion.

    Tries, in order:
      1. "<prefix> <value up to the first period/newline>" — handles
         multi-word needles such as "blue moon rising" (the previous
         single-token regex could never match those, so rewards for the
         phrase needles were always 0).
      2. "<prefix> <single token>" when no sentence terminator follows.
      3. Heuristic fallback: the last non-stopword token of the completion.

    Returns "" for empty/whitespace-only input.
    """
    text = text.strip()
    if not text:
        return ""
    for prefix in needle_prefixes:
        # Full phrase up to a sentence terminator (covers multi-word needles).
        match = re.search(rf"{re.escape(prefix)}\s+([^\n.]+)[.\n]", text, re.IGNORECASE)
        if match:
            return match.group(1).strip().rstrip('.,;:')
        # Single-token fallback: completion may end without punctuation.
        match = re.search(rf"{re.escape(prefix)}\s+(\S+)", text, re.IGNORECASE)
        if match:
            return match.group(1).rstrip('.,;:')
    # Fallback: last word that is not a trivial stopword.
    words = text.split()
    common_words = {'the', 'is', 'a', 'an', 'it', 'to', 'and', 'or', 'in', 'on', 'at'}
    for word in reversed(words):
        clean = word.strip('.,;:').lower()
        if clean not in common_words and len(clean) > 1:
            return word.strip('.,;:')
    return words[-1].strip('.,;:')
def reward_func(completions, prompts, **kwargs):
    """GRPO reward: 1.0 when the completion contains the correct needle.

    Args:
        completions: generated texts; each item is a string or (chat format)
            a list of message dicts with a 'content' key.
        prompts: matching prompts (unused here; required by the TRL
            reward-function signature).
        **kwargs: extra dataset columns passed through by the trainer;
            'needle' holds the ground-truth values.

    Returns:
        A list with one float per completion (1.0 exact match, else 0.0).
    """
    needles = kwargs.get('needle', None)
    if needles is None:
        print("[WARN] No needle found in kwargs")
        return [0.0] * len(completions)
    # All needle prefixes emitted by generate_needle_haystack — keep in sync.
    needle_prefixes = [
        "The secret code is", "The magic number is", "The password is",
        "The answer is", "The key value is", "The hidden word is",
        "The special ID is", "The unique code is", "The mystery number is",
        "The secret phrase is"
    ]
    rewards = []
    for i, (completion, truth) in enumerate(zip(completions, needles)):
        # Handle chat-format completions (list of message dicts).
        if isinstance(completion, list):
            text = " ".join([m.get('content', '') if isinstance(m, dict) else str(m) for m in completion])
        else:
            text = str(completion)
        # Extract prediction
        pred = extract_needle(text, needle_prefixes)
        # Ground truth may arrive as a singleton list depending on collation.
        if isinstance(truth, list):
            truth_val = truth[0] if truth else ""
        else:
            truth_val = str(truth)
        # Compare (case-insensitive, strip whitespace)
        is_correct = pred.lower().strip() == truth_val.lower().strip()
        reward = 1.0 if is_correct else 0.0
        rewards.append(reward)
        if i < 2:  # Debug first 2
            print(f" [{i}] pred='{pred}' truth='{truth_val}' r={reward}")
    # Guard the average: an empty batch previously raised ZeroDivisionError.
    if rewards:
        print(f"[REWARD] Avg: {sum(rewards)/len(rewards):.2f} | {rewards[:4]}")
    return rewards
# === LOAD MODEL ===
print(f"\n📦 Loading model: {MODEL_NAME}")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
# On GPU: 4-bit NF4 quantized load with auto device placement.
# On CPU: plain full-precision load.
_load_kwargs = {"trust_remote_code": True}
if device == "cuda":
    _load_kwargs["quantization_config"] = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )
    _load_kwargs["device_map"] = "auto"
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **_load_kwargs)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print("✅ Model loaded")
# === PREPARE FOR TRAINING ===
if device == "cuda":
    # Required before LoRA when the base model is k-bit quantized.
    model = prepare_model_for_kbit_training(model)
# Attach LoRA adapters to the attention projections only.
peft_cfg = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
model = get_peft_model(model, peft_cfg)
model.print_trainable_parameters()
# === GRPO CONFIG ===
training_args = GRPOConfig(
    output_dir="./output",
    max_steps=MAX_STEPS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=2,
    num_generations=GROUP_SIZE,  # completions sampled per prompt (GRPO group size)
    max_completion_length=MAX_COMPLETION_LENGTH,
    learning_rate=LEARNING_RATE,
    beta=0.0,  # No KL for this task
    bf16=device == "cuda",  # bf16 only on GPU; CPU runs full precision
    fp16=False,
    gradient_checkpointing=True,
    optim="adamw_torch_fused" if device == "cuda" else "adamw_torch",
    logging_steps=1,
    save_steps=10,
    dataloader_num_workers=0,
    remove_unused_columns=False,  # keep 'needle'/'needle_prefix' columns for reward_func
    # PUSH TO HUB!
    push_to_hub=True,
    hub_model_id="mindchain/qwen3-0.6b-rlm-needle",
)
# === TRAINING LOG ===
training_log = {
    "model": MODEL_NAME,
    "task": "needle_in_haystack",
    "steps": MAX_STEPS,
    "start_time": datetime.now().isoformat(),
    "rewards": [],
    "losses": [],
}

def _save_log():
    # Persist the run log (shared by the success and failure paths).
    with open("training_log.json", "w") as f:
        json.dump(training_log, f, indent=2)

# === TRAIN ===
print(f"\n🚀 Starting RLM training for {MAX_STEPS} steps...")
print("-"*70)
trainer = GRPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    reward_funcs=reward_func,
    processing_class=tokenizer,
)
try:
    trainer.train()
    training_log["end_time"] = datetime.now().isoformat()
    training_log["status"] = "completed"
    _save_log()
    banner = "=" * 70
    print("\n" + banner)
    print("✅ RLM Training Complete!")
    print(banner)
    print(f"Log saved to: training_log.json")
except Exception as e:
    print(f"\n❌ Training failed: {e}")
    training_log["status"] = "failed"
    training_log["error"] = str(e)
    _save_log()
    raise