#!/usr/bin/env python3
"""
Clean SHOREKEEPER training on STEM data only
"""
import sys
import json
import random
from pathlib import Path

import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import AutoTokenizer

sys.path.insert(0, str(Path(__file__).parent.parent))
from src.shorekeeper import SHOREKEEPER


def main():
    print("=" * 70)
    print("SHOREKEEPER - STEM TRAINING")
    print("=" * 70)

    # Check device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nDevice: {device}")

    # Load model (fresh, from scratch)
    print("\n1. Loading SHOREKEEPER model...")
    model = SHOREKEEPER()
    model = model.to(device)
    model.train()
    print(f"   Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Load tokenizer
    print("\n2. Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token
    print("   āœ“ GPT-2 tokenizer")

    # Load STEM data
    print("\n3. Loading STEM training data...")
    data_path = Path("./data/stem/stem_train.jsonl")
    if not data_path.exists():
        print("   āŒ No STEM data found!")
        print("   Run: python3 scripts/01_download_stem_data.py")
        return

    data = []
    with open(data_path, "r") as f:
        for line in f:
            data.append(json.loads(line))
    print(f"   Loaded {len(data):,} examples")

    # Training config. Examples are processed one at a time (micro-batch
    # of 1); gradient accumulation provides the effective batch size.
    gradient_accumulation = 8
    learning_rate = 3e-4
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.1)

    print("\n4. Training configuration:")
    print(f"   Examples: {len(data):,}")
    print(f"   Learning rate: {learning_rate}")
    print(f"   Gradient accumulation: {gradient_accumulation}")
    print(f"   Effective batch size: {gradient_accumulation}")

    # Make sure the checkpoint directory exists before training starts
    output_dir = Path("./outputs")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Training loop
    epochs = 5
    print(f"\n5. Training for {epochs} epochs...")

    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")

        # Shuffle data each epoch
        random.shuffle(data)

        total_loss = 0.0
        steps = 0
        optimizer.zero_grad()

        pbar = tqdm(data, desc="Training")
        for i, item in enumerate(pbar):
            # Format prompt and response as a single training sequence
            text = f"{item['prompt']}\n{item['response']}"

            # Tokenize (single unpadded sequence, so no pad masking needed)
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            input_ids = inputs["input_ids"].to(device)

            # Forward pass
            logits = model(input_ids)

            # Next-token prediction loss: predict token t+1 from tokens <= t
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = input_ids[..., 1:].contiguous()
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

            # Backward: scale by the accumulation count so the accumulated
            # gradient matches the average over the effective batch
            (loss / gradient_accumulation).backward()

            total_loss += loss.item()
            steps += 1

            # Update weights every `gradient_accumulation` examples
            if (i + 1) % gradient_accumulation == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()

            # Update progress bar
            pbar.set_postfix({"loss": f"{loss.item():.4f}", "avg": f"{total_loss / steps:.4f}"})

        # Flush any partial accumulation left at the end of the epoch
        if steps % gradient_accumulation != 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

        avg_loss = total_loss / steps
        print(f"   Epoch {epoch + 1} complete: Avg Loss = {avg_loss:.4f}")

        # Save checkpoint
        torch.save(model.state_dict(), output_dir / f"shorekeeper_stem_epoch_{epoch + 1}.pt")
        print(f"   Saved: outputs/shorekeeper_stem_epoch_{epoch + 1}.pt")

    # Final save
    torch.save(model.state_dict(), output_dir / "shorekeeper_stem_final.pt")
    print("\nāœ… Training complete!")
    print("   Final model: outputs/shorekeeper_stem_final.pt")


if __name__ == "__main__":
    main()
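
# ---------------------------------------------------------------------------
# Usage sketch (commented out, not part of the training run): load the final
# checkpoint and do a quick greedy-decoding sanity check. This assumes only
# what the training loop above already relies on: SHOREKEEPER() builds the
# model, model(input_ids) returns logits of shape (batch, seq_len, vocab),
# and `tokenizer` is the GPT-2 tokenizer constructed in main(). The prompt
# string is illustrative.
#
#   model = SHOREKEEPER()
#   model.load_state_dict(torch.load("./outputs/shorekeeper_stem_final.pt",
#                                    map_location="cpu"))
#   model.eval()
#   ids = tokenizer("What is 2 + 2?\n", return_tensors="pt")["input_ids"]
#   with torch.no_grad():
#       for _ in range(32):  # generate up to 32 tokens greedily
#           next_id = model(ids)[0, -1].argmax()
#           ids = torch.cat([ids, next_id.view(1, 1)], dim=1)
#   print(tokenizer.decode(ids[0]))
# ---------------------------------------------------------------------------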