#!/usr/bin/env python3
"""
Optimized training for RTX 5090 with 129GB RAM
Larger batch sizes = faster training!
"""
import sys
import json
import random
from pathlib import Path

# FIX: psutil was previously imported only inside the __main__ guard, so
# `import <this module>; main()` raised NameError at the RAM printout.
import psutil
import torch
import torch.nn as nn
from tqdm import tqdm

sys.path.insert(0, str(Path(__file__).parent.parent))

from src.shorekeeper import SHOREKEEPER
from transformers import AutoTokenizer


def main():
    """Train SHOREKEEPER on a local JSONL corpus with bf16 autocast.

    Expects ./data/7b_150gb/7b_train.jsonl (one JSON object per line,
    each with a "text" field) and writes checkpoints to ./outputs/.
    Runs single-example micro-batches with gradient accumulation to
    reach an effective batch of 32.
    """
    print("=" * 80)
    print("SHOREKEEPER TRAINING - OPTIMIZED FOR 129GB RAM")
    print("=" * 80)

    device = torch.device("cuda")

    # With 129GB RAM, we can use larger batch sizes!
    batch_size = 8            # Double from 4
    gradient_accumulation = 4 # Half from 8
    effective_batch = batch_size * gradient_accumulation  # 32 (same effective)

    print(f"\nGPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print(f"System RAM: {psutil.virtual_memory().total / 1e9:.1f} GB")
    print(f"Batch size: {batch_size}")
    print(f"Gradient accumulation: {gradient_accumulation}")
    print(f"Effective batch size: {effective_batch}")

    # Load model
    print("\n1. Loading SHOREKEEPER model...")
    model = SHOREKEEPER()
    model = model.to(device)
    params = sum(p.numel() for p in model.parameters())
    print(f"   Parameters: {params:,} ({params/1e9:.1f}B)")

    # Load tokenizer
    print("\n2. Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships with no pad token
    tokenizer.model_max_length = 1024

    # Load data
    print("\n3. Loading training data...")
    data_path = Path("./data/7b_150gb/7b_train.jsonl")
    if not data_path.exists():
        print("   āŒ No data found! Run download script first.")
        return

    data = []
    with open(data_path, 'r') as f:
        for line in f:
            data.append(json.loads(line))
    print(f"   Loaded {len(data):,} examples")

    # FIX: every torch.save below targets ./outputs/, which was never
    # created — the first checkpoint would crash hours into training.
    Path("./outputs").mkdir(parents=True, exist_ok=True)

    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=3e-4,
        weight_decay=0.1,
        betas=(0.9, 0.95)
    )

    # NOTE(review): GradScaler exists for float16 underflow; with bfloat16
    # autocast the scale/unscale round-trip is redundant (though harmless).
    scaler = torch.amp.GradScaler('cuda')

    print("\n4. Starting training...")
    print("   Training will take 1-2 weeks")

    epochs = 3
    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")
        random.shuffle(data)

        total_loss = 0
        steps = 0
        accumulated = 0  # micro-batches since the last optimizer step
        optimizer.zero_grad()

        pbar = tqdm(data, desc=f"Training")
        for i, item in enumerate(pbar):
            text = item.get('text', '')
            if not text or len(text) < 50:
                continue  # skip empty / trivially short documents

            inputs = tokenizer(
                text[:2048],
                return_tensors="pt",
                truncation=True,
                max_length=1024,
                padding="max_length"
            )
            input_ids = inputs['input_ids'].to(device)

            with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                logits = model(input_ids)
                # Next-token prediction: shift logits/labels by one position.
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = input_ids[..., 1:].contiguous()
                # NOTE(review): pad_token_id == eos_token_id for GPT-2, so
                # genuine EOS targets are also excluded from the loss.
                loss = nn.functional.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                    ignore_index=tokenizer.pad_token_id
                )

            # FIX: divide by the accumulation count so the accumulated
            # gradient is the *mean* over the effective batch; the previous
            # code summed raw gradients, effectively multiplying the
            # learning rate by gradient_accumulation.
            scaler.scale(loss / gradient_accumulation).backward()

            total_loss += loss.item()
            steps += 1
            accumulated += 1

            # FIX: step on the count of *processed* micro-batches. The old
            # `(i + 1) % gradient_accumulation` counted skipped items too,
            # so some optimizer steps fired after fewer than 4 backwards.
            if accumulated == gradient_accumulation:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                accumulated = 0

            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'avg': f'{total_loss/steps:.4f}'
            })

            if steps % 5000 == 0:
                torch.save(model.state_dict(),
                           f"./outputs/checkpoint_step_{steps}.pt")
                print(f"\n   šŸ’¾ Checkpoint saved")

        # FIX: apply any leftover partial accumulation instead of silently
        # discarding those gradients at the epoch boundary.
        if accumulated:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        # FIX: guard against an epoch in which every example was skipped
        # (previously a ZeroDivisionError).
        avg_loss = total_loss / steps if steps else float('nan')
        print(f"\nEpoch {epoch + 1} complete: Avg Loss = {avg_loss:.4f}")
        torch.save(model.state_dict(), f"./outputs/epoch_{epoch+1}.pt")

    torch.save(model.state_dict(), "./outputs/shorekeeper_7b_final.pt")
    print("\nāœ… Training complete!")


if __name__ == "__main__":
    main()