#!/usr/bin/env python3
"""
Clean SHOREKEEPER training on STEM data only
"""
import sys
import json
import torch
import torch.nn as nn
from pathlib import Path
from tqdm import tqdm
import random
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.shorekeeper import SHOREKEEPER
from transformers import AutoTokenizer
def main():
    """Train SHOREKEEPER from scratch on the STEM JSONL dataset.

    Reads ./data/stem/stem_train.jsonl (one JSON object per line with
    'prompt' and 'response' keys), trains for a fixed number of epochs
    with gradient accumulation, and writes per-epoch plus final
    checkpoints under ./outputs/.
    """
    print("=" * 70)
    print("SHOREKEEPER - STEM TRAINING")
    print("=" * 70)

    # Prefer GPU when available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nDevice: {device}")

    # Fresh (untrained) model.
    print("\n1. Loading SHOREKEEPER model...")
    model = SHOREKEEPER()
    model = model.to(device)
    model.train()  # make dropout/norm layers behave in training mode explicitly
    print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # GPT-2 ships without a pad token; reuse EOS as pad.
    print("\n2. Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    print(" ✓ GPT-2 tokenizer")

    print("\n3. Loading STEM training data...")
    data_path = Path("./data/stem/stem_train.jsonl")
    if not data_path.exists():
        print(" ❌ No STEM data found!")
        print(" Run: python3 scripts/01_download_stem_data.py")
        return
    with open(data_path, "r", encoding="utf-8") as f:
        data = [json.loads(line) for line in f]
    print(f" Loaded {len(data):,} examples")

    # Training config.
    # NOTE(review): batch_size is only reported; the loop below processes one
    # example per forward pass — confirm whether true batching was intended.
    batch_size = 2
    gradient_accumulation = 8
    learning_rate = 3e-4
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.1)

    print("\n4. Training configuration:")
    print(f" Examples: {len(data):,}")
    print(f" Learning rate: {learning_rate}")
    print(f" Batch size: {batch_size}")
    print(f" Gradient accumulation: {gradient_accumulation}")
    print(f" Effective batch size: {batch_size * gradient_accumulation}")

    # torch.save does not create parent directories — make sure ./outputs exists
    # before training rather than crashing after the first epoch completes.
    out_dir = Path("./outputs")
    out_dir.mkdir(parents=True, exist_ok=True)

    epochs = 5
    print(f"\n5. Training for {epochs} epochs...")
    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")
        random.shuffle(data)
        total_loss = 0.0
        steps = 0
        optimizer.zero_grad()
        pbar = tqdm(data, desc="Training")
        for i, item in enumerate(pbar):
            text = f"{item['prompt']}\n{item['response']}"
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            input_ids = inputs['input_ids'].to(device)

            logits = model(input_ids)

            # Next-token prediction: shift logits/labels by one position.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = input_ids[..., 1:].contiguous()
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                # NOTE(review): pad_token_id == eos_token_id here, so any EOS
                # positions are also excluded from the loss — confirm intended.
                ignore_index=tokenizer.pad_token_id,
            )

            # Scale the backward pass by the accumulation factor so the
            # accumulated gradient is the MEAN over the effective batch.
            # (Unscaled accumulation sums the per-example gradients, which
            # inflates the effective step size by gradient_accumulation x.)
            (loss / gradient_accumulation).backward()
            total_loss += loss.item()
            steps += 1

            if (i + 1) % gradient_accumulation == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()

            pbar.set_postfix({'loss': f'{loss.item():.4f}', 'avg': f'{total_loss/steps:.4f}'})

        # Flush leftover gradients when len(data) is not a multiple of
        # gradient_accumulation (they were previously discarded by the
        # zero_grad at the start of the next epoch).
        if len(data) % gradient_accumulation != 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

        avg_loss = total_loss / steps if steps else float("nan")
        print(f" Epoch {epoch + 1} complete: Avg Loss = {avg_loss:.4f}")

        # Per-epoch checkpoint.
        torch.save(model.state_dict(), out_dir / f"shorekeeper_stem_epoch_{epoch+1}.pt")
        print(f" Saved: outputs/shorekeeper_stem_epoch_{epoch+1}.pt")

    # Final save.
    torch.save(model.state_dict(), out_dir / "shorekeeper_stem_final.pt")
    print("\n✅ Training complete!")
    print(" Final model: outputs/shorekeeper_stem_final.pt")


if __name__ == "__main__":
    main()