| |
| """ |
| GRPO Training - The Reasoning Magic |
| Uses the trained model from stage 1 |
| """ |
|
|
import json
import re
import sys
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
|
|
# Make the project root importable so `src.*` resolves when run as a script.
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
| from src.shorekeeper import SHOREKEEPER, MemoryEfficientSHOREKEEPER |
| from transformers import AutoTokenizer |
|
|
class GRPOTrainer:
    """Group Relative Policy Optimization trainer.

    For each prompt, samples a small *group* of responses, scores each with a
    heuristic reward, and reinforces only the responses whose reward beats the
    group mean (advantage > 0) via an advantage-weighted cross-entropy loss.
    """

    def __init__(self, model, tokenizer, config):
        """Set up the trainer.

        Args:
            model: Causal LM exposing `parameters()`, `generate(...)`, and a
                forward call returning logits for `input_ids`.
            tokenizer: HF-style tokenizer with `eos_token_id`/`pad_token_id`.
            config: dict with optional 'group_size' (default 2) and
                'learning_rate' (default 1e-6).
        """
        self.model = model
        self.tokenizer = tokenizer
        # Follow the model's own placement rather than taking a device arg.
        self.device = next(model.parameters()).device

        self.group_size = config.get('group_size', 2)
        self.lr = config.get('learning_rate', 1e-6)

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.lr,
            weight_decay=0.01
        )

        # Number of completed optimizer updates (incremented in train_step).
        self.step = 0

    def compute_reward(self, response, ground_truth):
        """Score a generated response with a simple shaped reward.

        Components: +0.5 for emitting the '|special_token|' marker, +2.0 if
        the last number in the response equals `ground_truth`, +0.2 for more
        than 10 words, +0.3 for a unique-word ratio above 0.5.

        Returns:
            float reward in [0.0, 3.0].
        """
        reward = 0.0

        # Bonus for emitting the reasoning delimiter token.
        if '|special_token|' in response:
            reward += 0.5

        # Correctness: the last number mentioned must equal the answer.
        # NOTE(review): r'\d+' ignores signs and decimal points, so negative
        # or fractional ground truths can never match — confirm answers are
        # non-negative integers.
        numbers = re.findall(r'\d+', response)
        if numbers and numbers[-1] == str(ground_truth).strip():
            reward += 2.0

        words = response.split()

        # Small bonus for a non-trivial response length.
        if len(words) > 10:
            reward += 0.2

        # Discourage degenerate repetition by rewarding lexical diversity.
        unique_ratio = len(set(words)) / max(len(words), 1)
        if unique_ratio > 0.5:
            reward += 0.3

        return reward

    def generate_response(self, prompt, max_length=128):
        """Sample one response for `prompt` (leaves the model in eval mode).

        Returns:
            Decoded response string; on failure, the string "Error: <exc>"
            (best-effort so a single bad sample never kills training).
        """
        self.model.eval()

        try:
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=256)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model.generate(
                    inputs['input_ids'],
                    max_new_tokens=max_length,
                    temperature=0.8,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id
                )

            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        except Exception as e:
            return f"Error: {e}"

    def train_step(self, prompt, ground_truth):
        """Run one GRPO update for a single prompt.

        Returns:
            dict with 'loss', 'avg_reward', 'best_reward', 'valid_steps'
            (valid_steps == 0 means no response beat the mean, no update).
        """
        # --- Sample a group of responses and score them ---
        responses = [self.generate_response(prompt) for _ in range(self.group_size)]
        rewards = [self.compute_reward(r, ground_truth) for r in responses]

        mean_reward = sum(rewards) / len(rewards)
        advantages = [r - mean_reward for r in rewards]

        # generate_response leaves the model in eval mode; switch back to
        # train mode before the gradient-carrying forward passes (bug fix:
        # the loss was previously computed with eval-mode behavior).
        self.model.train()

        total_loss = 0
        valid_steps = 0

        for response, advantage in zip(responses, advantages):
            # Only reinforce responses that beat the group average.
            if advantage <= 0:
                continue

            text = f"{prompt}\n{response}"
            inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            logits = self.model(inputs['input_ids'])

            # Standard next-token shift for causal LM loss.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = inputs['input_ids'][..., 1:].contiguous()

            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=self.tokenizer.pad_token_id
            )

            # Advantage-weighted NLL: higher-advantage responses get a
            # stronger likelihood push.
            total_loss = total_loss + loss * advantage
            valid_steps += 1

        avg_reward = sum(rewards) / len(rewards)
        best_reward = max(rewards)

        if valid_steps > 0:
            # (Removed the fragile `total_loss != 0` tensor-truthiness test:
            # valid_steps > 0 already guarantees total_loss is a tensor.)
            total_loss = total_loss / valid_steps
            self.optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            self.step += 1  # bug fix: update counter was never advanced
            return {
                'loss': total_loss.item(),
                'avg_reward': avg_reward,
                'best_reward': best_reward,
                'valid_steps': valid_steps
            }

        return {
            'loss': 0,
            'avg_reward': avg_reward,
            'best_reward': best_reward,
            'valid_steps': 0
        }

    def train(self, dataset, num_epochs=1):
        """Full training loop over `dataset` (iterable of dicts).

        Each item needs a 'prompt' and an 'answer' (or 'ground_truth') key;
        items missing either are skipped.

        Returns:
            The (in-place updated) model.
        """
        print(f"\nTraining on device: {self.device}")

        for epoch in range(num_epochs):
            print(f"\n{'='*50}")
            print(f"Epoch {epoch + 1}/{num_epochs}")
            print(f"{'='*50}")

            total_loss = 0
            total_reward = 0
            steps = 0
            valid_steps = 0

            pbar = tqdm(dataset, desc="GRPO Training")

            for i, item in enumerate(pbar):
                prompt = item.get('prompt', '')
                answer = item.get('answer', item.get('ground_truth', ''))

                if not prompt or not answer:
                    continue

                try:
                    stats = self.train_step(prompt, str(answer))

                    # Loss is only meaningful when an update actually happened.
                    if stats['valid_steps'] > 0:
                        total_loss += stats['loss']
                        valid_steps += 1

                    total_reward += stats['avg_reward']
                    steps += 1

                    pbar.set_postfix({
                        'loss': f'{stats["loss"]:.4f}',
                        'reward': f'{stats["avg_reward"]:.2f}'
                    })

                except Exception as e:
                    # Best-effort: report only the first few failures to
                    # avoid flooding the progress bar.
                    if i < 10:
                        print(f"\n Error: {e}")
                    continue

            if steps > 0:
                avg_loss = total_loss / valid_steps if valid_steps > 0 else 0
                avg_reward = total_reward / steps
                print(f"\n Epoch complete: Avg Loss={avg_loss:.4f}, Avg Reward={avg_reward:.2f}")

        return self.model
|
|
def load_training_data(data_path, limit=None):
    """Load GRPO training examples from a JSONL file.

    Args:
        data_path: Path (or str) to a .jsonl file; each line is a JSON object
            with 'prompt' and either 'ground_truth' or 'response'.
        limit: Optional maximum number of lines to read (None = all).

    Returns:
        List of {'prompt': str, 'answer': str} dicts; empty list if the file
        does not exist. Malformed or non-object lines are skipped.
    """
    data = []
    data_path = Path(data_path)

    if not data_path.exists():
        print(f"Data file not found: {data_path}")
        return data

    with open(data_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            if limit is not None and i >= limit:
                break
            try:
                item = json.loads(line)
            except json.JSONDecodeError:
                # Skip malformed lines instead of aborting the whole load.
                # Bug fix: the bare `except:` also swallowed KeyboardInterrupt.
                continue
            # Guard against valid-JSON scalar lines (e.g. "5"), which have
            # no .get() and previously crashed into the bare except.
            if not isinstance(item, dict):
                continue
            data.append({
                'prompt': item.get('prompt', ''),
                'answer': item.get('ground_truth', item.get('response', ''))
            })

    return data
|
|
def main():
    """Interactive entry point: load model, tokenizer, and data, run GRPO.

    Expects a stage-1 checkpoint at ./outputs/shorekeeper-4b-final.pt and
    training data at ./data/processed/train.jsonl; saves the GRPO-tuned
    weights to ./outputs/grpo/shorekeeper-4b-grpo.pt.
    """
    print("=" * 60)
    print("SHOREKEEPER GRPO Training")
    print("The Reasoning Magic")
    print("=" * 60)

    # Pick the training device. Bug fix: `device` was only assigned inside
    # the CUDA branch, so CPU-only machines raised NameError below.
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"\n✓ CUDA: {torch.cuda.get_device_name(0)}")
        print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        device = torch.device("cpu")
        print("\n⚠ CUDA not available — training on CPU")

    # --- Stage-1 checkpoint ---
    print("\n1. Loading trained SHOREKEEPER model...")
    model_path = Path("./outputs/shorekeeper-4b-final.pt")

    if not model_path.exists():
        print(f"\n❌ Model not found at {model_path}")
        print(" Run training first: python3 scripts/04_train.py")
        return

    model = SHOREKEEPER()
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    model.train()
    print(f" ✓ Model loaded from {model_path}")

    # --- Tokenizer ---
    print("\n2. Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # GPT-2 has no pad token; reuse EOS so padding/ignore_index is defined.
    tokenizer.pad_token = tokenizer.eos_token
    print(" ✓ Using GPT-2 tokenizer")

    # --- Data ---
    print("\n3. Loading training data...")
    data_path = Path("./data/processed/train.jsonl")

    if not data_path.exists():
        print(f"\n❌ No data at {data_path}")
        return

    print(" Options:")
    print(" [1] Quick test (20 examples)")
    print(" [2] Small training (100 examples, 3 epochs)")

    choice = input("\nChoose option (1/2): ").strip()

    if choice == "1":
        limit = 20
        epochs = 1
    else:
        limit = 100
        epochs = 3

    data = load_training_data(data_path, limit=limit)
    print(f"\n Loaded {len(data)} examples")
    print(f" Training for {epochs} epochs")

    config = {
        'group_size': 2,
        'learning_rate': 1e-6
    }

    print("\n4. Initializing GRPO Trainer...")
    trainer = GRPOTrainer(model, tokenizer, config)

    print("\n5. Starting GRPO training...")
    print(" (This teaches the model to reason)\n")

    try:
        trainer.train(data, num_epochs=epochs)
    except KeyboardInterrupt:
        print("\n Interrupted")
    except Exception as e:
        print(f"\n Error: {e}")
        import traceback
        traceback.print_exc()

    # Save even on interrupt/error so partial progress is kept
    # (trainer.train updates `model` in place).
    print("\n6. Saving model...")
    output_dir = Path("./outputs/grpo")
    output_dir.mkdir(parents=True, exist_ok=True)

    torch.save(model.state_dict(), output_dir / "shorekeeper-4b-grpo.pt")
    print(f" ✓ Saved to {output_dir / 'shorekeeper-4b-grpo.pt'}")

    print("\n" + "=" * 60)
    print("✅ GRPO Complete!")
    print("=" * 60)
    print("\nNow run SHOREKEEPER:")
    print(" python3 scripts/07_run_shorekeeper.py")
|
|
# Allow importing this module (e.g. for tests) without starting training.
if __name__ == "__main__":
    main()
|
|