SHOREKEEPER / scripts /04_train_5090_optimized.py
geoore's picture
Restructure to src/ layout with attention, per-layer MoE, and working chat
73400c8
#!/usr/bin/env python3
"""
Optimized SHOREKEEPER training script for an RTX 5090 host with 129 GB of system RAM.
Larger batch sizes give faster training.
"""
import sys
import json
import torch
import torch.nn as nn
from pathlib import Path
from tqdm import tqdm
import random
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.shorekeeper import SHOREKEEPER
from transformers import AutoTokenizer
def _load_texts(data_path, min_chars=50, max_chars=2048):
    """Read a JSONL corpus and return the usable training texts.

    Only the ``text`` field of each record is kept, already cut to
    ``max_chars`` characters, so the parsed JSON objects are not retained
    in memory.

    Args:
        data_path: Path of the JSONL file (one JSON object per line).
        min_chars: Records whose text is shorter than this are dropped.
        max_chars: Each kept text is truncated to this many characters.

    Returns:
        A ``(texts, skipped)`` tuple: the list of kept strings and the
        number of non-blank lines that were malformed, lacked a string
        ``text`` field, or were too short.
    """
    texts = []
    skipped = 0
    with open(data_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines (e.g. a trailing newline)
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                skipped += 1
                continue
            text = record.get("text", "") if isinstance(record, dict) else ""
            if not isinstance(text, str) or len(text) < min_chars:
                skipped += 1
                continue
            texts.append(text[:max_chars])
    return texts, skipped


def _optimizer_step(model, optimizer, scaler, grad_multiplier=1.0):
    """Unscale, optionally rescale, clip, and apply the accumulated gradients.

    Args:
        model: Module whose parameters hold the accumulated gradients.
        optimizer: Optimizer to step.
        scaler: The ``GradScaler`` that scaled the losses.
        grad_multiplier: Factor applied to the unscaled gradients before
            clipping; used to re-average a partial accumulation window.
    """
    scaler.unscale_(optimizer)
    if grad_multiplier != 1.0:
        for param in model.parameters():
            if param.grad is not None:
                param.grad.mul_(grad_multiplier)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()


def main():
    """Train SHOREKEEPER on the JSONL corpus and write checkpoints.

    Fails fast (prints a message and returns) when the training file is
    missing, no CUDA device is available, or no usable example is found.
    Otherwise trains for three epochs with micro-batches of ``batch_size``
    texts and ``gradient_accumulation`` micro-batches per optimizer step,
    saving periodic, per-epoch and final state dicts under ``./outputs``.
    """
    print("=" * 80)
    print("SHOREKEEPER TRAINING - OPTIMIZED FOR 129GB RAM")
    print("=" * 80)

    # Check cheap preconditions before building the model.
    data_path = Path("./data/7b_150gb/7b_train.jsonl")
    if not data_path.exists():
        print(" ❌ No data found! Run download script first.")
        return
    if not torch.cuda.is_available():
        print(" ❌ No CUDA device available; this script requires a GPU.")
        return

    # psutil is only needed for the informational RAM line, so import it
    # here; main() then also works when this module is imported.
    try:
        import psutil
    except ImportError:
        psutil = None

    device = torch.device("cuda")
    batch_size = 8
    gradient_accumulation = 4
    effective_batch = batch_size * gradient_accumulation  # 32
    max_length = 1024
    checkpoint_every = 5000  # in micro-batches, counted across epochs
    epochs = 3

    print(f"\nGPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    if psutil is not None:
        print(f"System RAM: {psutil.virtual_memory().total / 1e9:.1f} GB")
    print(f"Batch size: {batch_size}")
    print(f"Gradient accumulation: {gradient_accumulation}")
    print(f"Effective batch size: {effective_batch}")

    # Load model
    print("\n1. Loading SHOREKEEPER model...")
    model = SHOREKEEPER().to(device)
    model.train()
    params = sum(p.numel() for p in model.parameters())
    print(f" Parameters: {params:,} ({params/1e9:.1f}B)")

    # Load tokenizer
    print("\n2. Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.model_max_length = max_length

    # Load data
    print("\n3. Loading training data...")
    texts, skipped = _load_texts(data_path)
    print(f" Loaded {len(texts):,} examples")
    if skipped:
        print(f" Skipped {skipped:,} malformed or too-short records")
    if not texts:
        print(" ❌ No usable training examples found.")
        return

    # Create the output directory up front so the first save cannot fail
    # on a missing folder after thousands of steps.
    output_dir = Path("./outputs")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=3e-4,
        weight_decay=0.1,
        betas=(0.9, 0.95),
    )
    scaler = torch.amp.GradScaler('cuda')

    print("\n4. Starting training...")
    print(" Training will take 1-2 weeks")

    global_step = 0  # micro-batches processed over the whole run
    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")
        random.shuffle(texts)
        total_loss = 0.0
        steps = 0
        pending = 0  # micro-batches accumulated since the last optimizer step
        optimizer.zero_grad()

        pbar = tqdm(range(0, len(texts), batch_size), desc="Training")
        for start in pbar:
            batch = texts[start:start + batch_size]
            # Fixed-length padding keeps the input shape at
            # (len(batch), max_length), as in the single-example version.
            inputs = tokenizer(
                batch,
                return_tensors="pt",
                truncation=True,
                max_length=max_length,
                padding="max_length",
            )
            input_ids = inputs['input_ids'].to(device)
            attention_mask = inputs['attention_mask'].to(device)
            # pad_token is eos_token, so mask padding through the attention
            # mask; ignoring pad_token_id would also drop real EOS targets.
            labels = input_ids.masked_fill(attention_mask == 0, -100)

            with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                logits = model(input_ids)

            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)).float(),
                shift_labels.view(-1),
                ignore_index=-100,
            )
            # A batch with no valid target yields NaN; skip it rather than
            # poisoning the accumulated gradients and the running average.
            if not torch.isfinite(loss):
                continue

            # Divide so the accumulated gradient is the mean over the window.
            scaler.scale(loss / gradient_accumulation).backward()
            loss_value = loss.item()
            total_loss += loss_value
            steps += 1
            global_step += 1
            pending += 1

            if pending == gradient_accumulation:
                _optimizer_step(model, optimizer, scaler)
                pending = 0

            pbar.set_postfix({
                'loss': f'{loss_value:.4f}',
                'avg': f'{total_loss/steps:.4f}'
            })

            if global_step % checkpoint_every == 0:
                torch.save(
                    model.state_dict(),
                    output_dir / f"checkpoint_step_{global_step}.pt",
                )
                print(f"\n 💾 Checkpoint saved")

        # Apply the trailing partial window instead of discarding it,
        # re-averaged over the micro-batches it actually contains.
        if pending:
            _optimizer_step(
                model, optimizer, scaler,
                grad_multiplier=gradient_accumulation / pending,
            )

        avg_loss = total_loss / steps if steps else float("nan")
        print(f"\nEpoch {epoch + 1} complete: Avg Loss = {avg_loss:.4f}")
        torch.save(model.state_dict(), output_dir / f"epoch_{epoch + 1}.pt")

    torch.save(model.state_dict(), output_dir / "shorekeeper_7b_final.pt")
    print("\n✅ Training complete!")
if __name__ == "__main__":
    # psutil is imported at script start, so a missing dependency fails
    # before any training work begins.
    import psutil

    main()