#!/usr/bin/env python3
"""
GRPO Training - The Reasoning Magic
Uses the trained model from stage 1
"""
import sys
import json
import re
import torch
import torch.nn.functional as F
from pathlib import Path
from tqdm import tqdm

sys.path.insert(0, str(Path(__file__).parent.parent))

from src.shorekeeper import SHOREKEEPER, MemoryEfficientSHOREKEEPER
from transformers import AutoTokenizer


class GRPOTrainer:
    """Group Relative Policy Optimization Trainer"""

    def __init__(self, model, tokenizer, config):
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device
        self.group_size = config.get('group_size', 2)
        self.lr = config.get('learning_rate', 1e-6)
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.lr,
            weight_decay=0.01
        )
        self.step = 0

    def compute_reward(self, response, ground_truth):
        """Calculate reward for a response"""
        reward = 0.0

        # Format reward - check for reasoning tokens
        if '|special_token|' in response:
            reward += 0.5

        # Accuracy reward - the last number in the response must match
        numbers = re.findall(r'\d+', response)
        if numbers and numbers[-1] == str(ground_truth).strip():
            reward += 2.0

        # Length reward - not too short
        if len(response.split()) > 10:
            reward += 0.2

        # Repetition check - reward responses with mostly unique words
        words = response.split()
        unique_ratio = len(set(words)) / max(len(words), 1)
        if unique_ratio > 0.5:
            reward += 0.3

        return reward

    def generate_response(self, prompt, max_length=128):
        """Generate a response from the model"""
        self.model.eval()
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt",
                                    truncation=True, max_length=256)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model.generate(
                    inputs['input_ids'],
                    max_new_tokens=max_length,
                    temperature=0.8,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id
                )

            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        except Exception as e:
            return f"Error: {e}"

    def train_step(self, prompt, ground_truth):
        """Single GRPO step"""
        self.model.train()

        # Generate a group of responses and score each one
        responses = []
        rewards = []
        for _ in range(self.group_size):
            response = self.generate_response(prompt)
            responses.append(response)
            rewards.append(self.compute_reward(response, ground_truth))

        # Advantages are relative to the group mean - no learned critic needed
        mean_reward = sum(rewards) / len(rewards)
        advantages = [r - mean_reward for r in rewards]

        # Train only on responses with positive advantage
        total_loss = 0
        valid_steps = 0

        for response, advantage in zip(responses, advantages):
            if advantage <= 0:
                continue

            # Train on prompt + response as a single sequence
            text = f"{prompt}\n{response}"
            inputs = self.tokenizer(text, return_tensors="pt",
                                    truncation=True, max_length=512)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # Forward pass
            logits = self.model(inputs['input_ids'])

            # Next-token language modeling loss
            # (pad_token is eos_token here, so EOS positions are ignored too)
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = inputs['input_ids'][..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=self.tokenizer.pad_token_id
            )

            # Weight by advantage
            total_loss = total_loss + loss * advantage
            valid_steps += 1

        if valid_steps > 0:
            total_loss = total_loss / valid_steps
            self.optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()

            return {
                'loss': total_loss.item(),
                'avg_reward': sum(rewards) / len(rewards),
                'best_reward': max(rewards),
                'valid_steps': valid_steps
            }

        return {
            'loss': 0,
            'avg_reward': sum(rewards) / len(rewards),
            'best_reward': max(rewards),
            'valid_steps': 0
        }
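    # Worked example of the group-relative advantage above (numbers are
    # illustrative only): with group_size=2 and rewards [2.5, 0.5], the group
    # mean is 1.5, so the advantages are [+1.0, -1.0]. Only the first response
    # is trained on, its LM loss scaled by +1.0. This group-mean baseline is
    # the core idea of GRPO: it replaces the learned value function (critic)
    # that PPO would otherwise require.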
    def train(self, dataset, num_epochs=1):
        """Full training loop"""
        print(f"\nTraining on device: {self.device}")

        for epoch in range(num_epochs):
            print(f"\n{'='*50}")
            print(f"Epoch {epoch + 1}/{num_epochs}")
            print(f"{'='*50}")

            total_loss = 0
            total_reward = 0
            steps = 0
            valid_steps = 0

            pbar = tqdm(dataset, desc="GRPO Training")
            for i, item in enumerate(pbar):
                prompt = item.get('prompt', '')
                answer = item.get('answer', item.get('ground_truth', ''))

                if not prompt or not answer:
                    continue

                try:
                    stats = self.train_step(prompt, str(answer))

                    if stats['valid_steps'] > 0:
                        total_loss += stats['loss']
                        valid_steps += 1

                    total_reward += stats['avg_reward']
                    steps += 1

                    pbar.set_postfix({
                        'loss': f'{stats["loss"]:.4f}',
                        'reward': f'{stats["avg_reward"]:.2f}'
                    })
                except Exception as e:
                    # Only report the first few errors to keep the log readable
                    if i < 10:
                        print(f"\n  Error: {e}")
                    continue

            if steps > 0:
                avg_loss = total_loss / valid_steps if valid_steps > 0 else 0
                avg_reward = total_reward / steps
                print(f"\n  Epoch complete: Avg Loss={avg_loss:.4f}, "
                      f"Avg Reward={avg_reward:.2f}")

        return self.model


def load_training_data(data_path, limit=None):
    """Load training data for GRPO"""
    data = []
    data_path = Path(data_path)

    if not data_path.exists():
        print(f"Data file not found: {data_path}")
        return data

    with open(data_path, 'r') as f:
        for i, line in enumerate(f):
            if limit and i >= limit:
                break
            try:
                item = json.loads(line)
                data.append({
                    'prompt': item.get('prompt', ''),
                    'answer': item.get('ground_truth', item.get('response', ''))
                })
            except json.JSONDecodeError:
                continue

    return data
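# Expected JSONL format for the training file (field names inferred from the
# keys read above; the example values are made up):
#
#   {"prompt": "What is 12 * 7?", "ground_truth": "84"}
#
# 'ground_truth' is preferred; 'response' is used as a fallback answer source.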

def main():
    print("=" * 60)
    print("SHOREKEEPER GRPO Training")
    print("The Reasoning Magic")
    print("=" * 60)

    # Check device (fall back to CPU so `device` is always defined)
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"\n✓ CUDA: {torch.cuda.get_device_name(0)}")
        print(f"  Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        device = torch.device("cpu")
        print("\n✓ No CUDA available, using CPU")

    # Load trained model (full precision for training)
    print("\n1. Loading trained SHOREKEEPER model...")
    model_path = Path("./outputs/shorekeeper-4b-final.pt")

    if not model_path.exists():
        print(f"\n❌ Model not found at {model_path}")
        print("   Run training first: python3 scripts/04_train.py")
        return

    # Use the full model (not MemoryEfficientSHOREKEEPER) for training
    model = SHOREKEEPER()
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    model.train()
    print(f"   ✓ Model loaded from {model_path}")

    # Load tokenizer
    print("\n2. Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    print("   ✓ Using GPT-2 tokenizer")

    # Load training data
    print("\n3. Loading training data...")
    data_path = Path("./data/processed/train.jsonl")

    if not data_path.exists():
        print(f"\n❌ No data at {data_path}")
        return

    print("   Options:")
    print("   [1] Quick test (20 examples)")
    print("   [2] Small training (100 examples, 3 epochs)")
    choice = input("\nChoose option (1/2): ").strip()

    if choice == "1":
        limit = 20
        epochs = 1
    else:
        limit = 100
        epochs = 3

    data = load_training_data(data_path, limit=limit)
    print(f"\n   Loaded {len(data)} examples")
    print(f"   Training for {epochs} epochs")

    # GRPO config
    config = {
        'group_size': 2,
        'learning_rate': 1e-6
    }

    print("\n4. Initializing GRPO Trainer...")
    trainer = GRPOTrainer(model, tokenizer, config)

    print("\n5. Starting GRPO training...")
    print("   (This teaches the model to reason)\n")

    try:
        trainer.train(data, num_epochs=epochs)
    except KeyboardInterrupt:
        print("\n  Interrupted")
    except Exception as e:
        print(f"\n  Error: {e}")
        import traceback
        traceback.print_exc()

    # Save model (runs even after an interrupt, so partial progress is kept)
    print("\n6. Saving model...")
    output_dir = Path("./outputs/grpo")
    output_dir.mkdir(parents=True, exist_ok=True)

    torch.save(model.state_dict(), output_dir / "shorekeeper-4b-grpo.pt")
    print(f"   ✓ Saved to {output_dir / 'shorekeeper-4b-grpo.pt'}")

    print("\n" + "=" * 60)
    print("✅ GRPO Complete!")
    print("=" * 60)
    print("\nNow run SHOREKEEPER:")
    print("  python3 scripts/07_run_shorekeeper.py")


if __name__ == "__main__":
    main()
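# Loading the saved GRPO checkpoint later (a minimal sketch; assumes the same
# SHOREKEEPER architecture from src/shorekeeper.py used above):
#
#   model = SHOREKEEPER()
#   state = torch.load("outputs/grpo/shorekeeper-4b-grpo.pt", map_location="cpu")
#   model.load_state_dict(state)
#   model.eval()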