# SHOREKEEPER / scripts / 05_grpo_train.py
# Commit: "Restructure to src/ layout with attention, per-layer MoE, and working chat" (geoore, 73400c8)
#!/usr/bin/env python3
"""
GRPO Training - The Reasoning Magic
Uses the trained model from stage 1
"""
import json
import re
import sys
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoTokenizer

# Make the project root importable before loading the local package.
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.shorekeeper import SHOREKEEPER, MemoryEfficientSHOREKEEPER
class GRPOTrainer:
    """Group Relative Policy Optimization (GRPO) trainer.

    For each prompt, samples a small group of responses, scores each with a
    heuristic reward, and reinforces only the responses whose reward beats
    the group mean (positive advantage) via an advantage-weighted LM loss.
    """

    def __init__(self, model, tokenizer, config):
        """Set up optimizer state.

        Args:
            model: causal LM exposing ``generate`` and returning logits
                from ``model(input_ids)``.
            tokenizer: HF-style tokenizer; pad/eos token ids must be set.
            config: dict with optional 'group_size' (default 2) and
                'learning_rate' (default 1e-6).
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device
        self.group_size = config.get('group_size', 2)
        self.lr = config.get('learning_rate', 1e-6)
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.lr,
            weight_decay=0.01
        )
        self.step = 0

    def compute_reward(self, response, ground_truth):
        """Heuristically score *response* against *ground_truth*.

        Reward components: +0.5 for the reasoning marker, +2.0 when the
        last integer in the response matches the answer, +0.2 for length
        above 10 words, +0.3 for a unique-word ratio above 0.5.
        """
        reward = 0.0
        # Format reward: presence of the reasoning marker earns partial credit.
        if '|special_token|' in response:
            reward += 0.5
        # Correctness reward: compare the last integer in the response.
        # NOTE(review): \d+ misses negative/decimal answers — confirm the
        # dataset only has non-negative integer ground truths.
        numbers = re.findall(r'\d+', response)
        if numbers and numbers[-1] == str(ground_truth).strip():
            reward += 2.0
        # Length reward: discourage degenerate one-liners.
        if len(response.split()) > 10:
            reward += 0.2
        # Diversity reward: only granted when repetition is low.
        words = response.split()
        unique_ratio = len(set(words)) / max(len(words), 1)
        if unique_ratio > 0.5:
            reward += 0.3
        return reward

    def generate_response(self, prompt, max_length=128):
        """Sample one response for *prompt*.

        Best-effort: on any failure returns an ``"Error: ..."`` string so a
        single bad generation scores ~0 reward instead of aborting training.
        """
        self.model.eval()
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=256)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model.generate(
                    inputs['input_ids'],
                    max_new_tokens=max_length,
                    temperature=0.8,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id
                )
            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        except Exception as e:
            return f"Error: {e}"

    def train_step(self, prompt, ground_truth):
        """Run one GRPO update for a single prompt.

        Returns:
            dict with 'loss', 'avg_reward', 'best_reward' and 'valid_steps'
            (number of responses with positive advantage actually trained on).
        """
        self.model.train()
        # Sample the group and score every response.
        responses = [self.generate_response(prompt) for _ in range(self.group_size)]
        rewards = [self.compute_reward(r, ground_truth) for r in responses]
        # Advantage = reward relative to the group mean.
        mean_reward = sum(rewards) / len(rewards)
        advantages = [r - mean_reward for r in rewards]

        total_loss = None  # tensor accumulator; None until first valid response
        valid_steps = 0
        for response, advantage in zip(responses, advantages):
            # Only reinforce responses that beat the group average.
            if advantage <= 0:
                continue
            text = f"{prompt}\n{response}"
            inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            logits = self.model(inputs['input_ids'])
            # Standard next-token LM loss over prompt + response.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = inputs['input_ids'][..., 1:].contiguous()
            # NOTE(review): main() sets pad_token = eos_token, so this also
            # excludes EOS positions from the loss — confirm that is intended.
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=self.tokenizer.pad_token_id
            )
            weighted = loss * advantage
            total_loss = weighted if total_loss is None else total_loss + weighted
            valid_steps += 1

        stats = {
            'loss': 0,
            'avg_reward': mean_reward,
            'best_reward': max(rewards),
            'valid_steps': valid_steps,
        }
        if valid_steps > 0:
            # Average the advantage-weighted losses, then one optimizer step.
            # (Explicit counter check replaces the original `total_loss != 0`,
            # which compared a 0-dim tensor.)
            total_loss = total_loss / valid_steps
            self.optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            stats['loss'] = total_loss.item()
        return stats

    def train(self, dataset, num_epochs=1):
        """Run GRPO over *dataset* for *num_epochs*; returns the model."""
        print(f"\nTraining on device: {self.device}")
        for epoch in range(num_epochs):
            print(f"\n{'='*50}")
            print(f"Epoch {epoch + 1}/{num_epochs}")
            print(f"{'='*50}")
            total_loss = 0
            total_reward = 0
            steps = 0
            valid_steps = 0
            pbar = tqdm(dataset, desc="GRPO Training")
            for i, item in enumerate(pbar):
                prompt = item.get('prompt', '')
                answer = item.get('answer', item.get('ground_truth', ''))
                if not prompt or not answer:
                    continue
                try:
                    stats = self.train_step(prompt, str(answer))
                    if stats['valid_steps'] > 0:
                        total_loss += stats['loss']
                        valid_steps += 1
                    total_reward += stats['avg_reward']
                    steps += 1
                    pbar.set_postfix({
                        'loss': f'{stats["loss"]:.4f}',
                        'reward': f'{stats["avg_reward"]:.2f}'
                    })
                except Exception as e:
                    # Surface only the first few failures to avoid log spam.
                    if i < 10:
                        print(f"\n Error: {e}")
                    continue
            if steps > 0:
                avg_loss = total_loss / valid_steps if valid_steps > 0 else 0
                avg_reward = total_reward / steps
                print(f"\n Epoch complete: Avg Loss={avg_loss:.4f}, Avg Reward={avg_reward:.2f}")
        return self.model
def load_training_data(data_path, limit=None):
    """Load GRPO training examples from a JSONL file.

    Args:
        data_path: path to a .jsonl file; each line is a JSON object with
            'prompt' and either 'ground_truth' or 'response'.
        limit: optional maximum number of lines to read (``limit=0`` now
            correctly loads nothing; the original treated 0 as "no limit").

    Returns:
        List of ``{'prompt': str, 'answer': str}`` dicts; empty when the
        file is missing.
    """
    data = []
    data_path = Path(data_path)
    if not data_path.exists():
        print(f"Data file not found: {data_path}")
        return data
    with open(data_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            if limit is not None and i >= limit:
                break
            # Narrowed from a bare `except:` — skip only malformed JSON
            # lines rather than swallowing every exception.
            try:
                item = json.loads(line)
            except json.JSONDecodeError:
                continue
            if not isinstance(item, dict):
                continue
            data.append({
                'prompt': item.get('prompt', ''),
                'answer': item.get('ground_truth', item.get('response', ''))
            })
    return data
def main():
    """CLI entry point: load the stage-1 model, run GRPO, save the result.

    Interactive: prompts the user to choose a quick test (20 examples) or
    a small run (100 examples, 3 epochs).
    """
    print("=" * 60)
    print("SHOREKEEPER GRPO Training")
    print("The Reasoning Magic")
    print("=" * 60)
    # Pick the training device. The original only assigned `device` inside
    # the CUDA branch, which raised NameError on CPU-only machines.
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"\n✓ CUDA: {torch.cuda.get_device_name(0)}")
        print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        device = torch.device("cpu")
        print("\n⚠ CUDA not available — training on CPU")
    # Load trained model (full precision for training).
    print("\n1. Loading trained SHOREKEEPER model...")
    model_path = Path("./outputs/shorekeeper-4b-final.pt")
    if not model_path.exists():
        print(f"\n❌ Model not found at {model_path}")
        print(" Run training first: python3 scripts/04_train.py")
        return
    model = SHOREKEEPER()  # Use full model (not memory efficient) for training
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    model.train()
    print(f" ✓ Model loaded from {model_path}")
    # Load tokenizer.
    print("\n2. Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # GPT-2 has no pad token; alias it to EOS for batching/generation.
    tokenizer.pad_token = tokenizer.eos_token
    print(" ✓ Using GPT-2 tokenizer")
    # Load training data.
    print("\n3. Loading training data...")
    data_path = Path("./data/processed/train.jsonl")
    if not data_path.exists():
        print(f"\n❌ No data at {data_path}")
        return
    print(" Options:")
    print(" [1] Quick test (20 examples)")
    print(" [2] Small training (100 examples, 3 epochs)")
    choice = input("\nChoose option (1/2): ").strip()
    if choice == "1":
        limit = 20
        epochs = 1
    else:
        limit = 100
        epochs = 3
    data = load_training_data(data_path, limit=limit)
    print(f"\n Loaded {len(data)} examples")
    print(f" Training for {epochs} epochs")
    # GRPO config.
    config = {
        'group_size': 2,
        'learning_rate': 1e-6
    }
    print("\n4. Initializing GRPO Trainer...")
    trainer = GRPOTrainer(model, tokenizer, config)
    print("\n5. Starting GRPO training...")
    print(" (This teaches the model to reason)\n")
    try:
        # train() mutates `model` in place, so the result binding is unneeded.
        trainer.train(data, num_epochs=epochs)
    except KeyboardInterrupt:
        print("\n Interrupted")
    except Exception as e:
        print(f"\n Error: {e}")
        import traceback
        traceback.print_exc()
    # Save model (even after an interrupt, partial progress is kept).
    print("\n6. Saving model...")
    output_dir = Path("./outputs/grpo")
    output_dir.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), output_dir / "shorekeeper-4b-grpo.pt")
    print(f" ✓ Saved to {output_dir / 'shorekeeper-4b-grpo.pt'}")
    print("\n" + "=" * 60)
    print("✅ GRPO Complete!")
    print("=" * 60)
    print("\nNow run SHOREKEEPER:")
    print(" python3 scripts/07_run_shorekeeper.py")
# Run the interactive CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()