# Source: SHOREKEEPER / scripts/04_train_stem.py
# Commit 73400c8 — "Restructure to src/ layout with attention, per-layer MoE, and working chat" (geoore)
#!/usr/bin/env python3
"""
Clean SHOREKEEPER training on STEM data only
"""
import sys
import json
import torch
import torch.nn as nn
from pathlib import Path
from tqdm import tqdm
import random
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.shorekeeper import SHOREKEEPER
from transformers import AutoTokenizer
def main():
    """Train SHOREKEEPER from scratch on the STEM dataset.

    Reads ./data/stem/stem_train.jsonl (one JSON object per line with
    'prompt' and 'response' keys), trains for a fixed number of epochs
    with gradient accumulation, saves a checkpoint to ./outputs/ after
    every epoch, and a final state dict when training completes.
    """
    print("=" * 70)
    print("SHOREKEEPER - STEM TRAINING")
    print("=" * 70)

    # Prefer GPU when available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nDevice: {device}")

    # Fresh, randomly-initialised model (training from scratch).
    print("\n1. Loading SHOREKEEPER model...")
    model = SHOREKEEPER()
    model = model.to(device)
    model.train()  # enable training-mode behaviour (dropout, etc.)
    print(f"   Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # GPT-2 tokenizer ships no pad token, so alias it to EOS.
    print("\n2. Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    print("   ✓ GPT-2 tokenizer")

    # Load the JSONL training set; bail out with instructions if missing.
    print("\n3. Loading STEM training data...")
    data_path = Path("./data/stem/stem_train.jsonl")
    if not data_path.exists():
        print("   ❌ No STEM data found!")
        print("   Run: python3 scripts/01_download_stem_data.py")
        return
    with open(data_path, 'r') as f:
        data = [json.loads(line) for line in f]
    print(f"   Loaded {len(data):,} examples")

    # Training config.
    # NOTE(review): examples are actually fed one at a time (no padding or
    # batch collation below); `batch_size` only documents the intended setup.
    batch_size = 2
    gradient_accumulation = 8
    learning_rate = 3e-4
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.1)
    print("\n4. Training configuration:")
    print(f"   Examples: {len(data):,}")
    print(f"   Learning rate: {learning_rate}")
    print(f"   Batch size: {batch_size}")
    print(f"   Gradient accumulation: {gradient_accumulation}")
    print(f"   Effective batch size: {batch_size * gradient_accumulation}")

    # Checkpoint directory must exist before torch.save is called.
    out_dir = Path("./outputs")
    out_dir.mkdir(parents=True, exist_ok=True)

    epochs = 5
    print(f"\n5. Training for {epochs} epochs...")
    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")
        # Reshuffle each epoch so accumulation windows differ.
        random.shuffle(data)
        total_loss = 0.0
        steps = 0
        optimizer.zero_grad()
        pbar = tqdm(data, desc="Training")
        for i, item in enumerate(pbar):
            text = f"{item['prompt']}\n{item['response']}"
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            input_ids = inputs['input_ids'].to(device)

            logits = model(input_ids)

            # Next-token prediction: logits at position t predict token t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = input_ids[..., 1:].contiguous()
            # NOTE(review): pad_token_id == eos_token_id here, so EOS
            # positions are also excluded from the loss — confirm intended.
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=tokenizer.pad_token_id
            )

            # Scale so accumulated gradients AVERAGE (not sum) over the
            # accumulation window; log the unscaled loss for readability.
            (loss / gradient_accumulation).backward()
            total_loss += loss.item()
            steps += 1

            # Step once per full accumulation window, with grad clipping.
            if (i + 1) % gradient_accumulation == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()

            pbar.set_postfix({'loss': f'{loss.item():.4f}', 'avg': f'{total_loss/steps:.4f}'})

        # Flush gradients left over when len(data) is not a multiple of the
        # accumulation window (otherwise they'd leak into the next epoch).
        if len(data) % gradient_accumulation != 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

        avg_loss = total_loss / max(steps, 1)  # guard against an empty dataset
        print(f"   Epoch {epoch + 1} complete: Avg Loss = {avg_loss:.4f}")

        # Per-epoch checkpoint.
        torch.save(model.state_dict(), out_dir / f"shorekeeper_stem_epoch_{epoch+1}.pt")
        print(f"   Saved: outputs/shorekeeper_stem_epoch_{epoch+1}.pt")

    # Final save.
    torch.save(model.state_dict(), out_dir / "shorekeeper_stem_final.pt")
    print("\n✅ Training complete!")
    print("   Final model: outputs/shorekeeper_stem_final.pt")
# Script entry point — run training only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()