"""
Optimized training for RTX 5090 with 129GB RAM
Larger batch sizes = faster training!
"""
|
|
import json
import random
import sys
from pathlib import Path

import psutil
import torch
import torch.nn as nn
from tqdm import tqdm

# Make the repository root importable before the local imports below.
sys.path.insert(0, str(Path(__file__).parent.parent))

from src.shorekeeper import SHOREKEEPER
from transformers import AutoTokenizer
|
|
def _apply_optimizer_step(scaler, optimizer, model):
    """Unscale, clip, and apply one optimizer step, then reset gradients."""
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()


def main():
    """Train SHOREKEEPER with bf16 autocast and gradient accumulation.

    Expects ./data/7b_150gb/7b_train.jsonl (one JSON object with a 'text'
    field per line) and writes checkpoints to ./outputs/. Requires a CUDA
    device; crashes on the first torch.cuda call otherwise.
    """
    print("=" * 80)
    print("SHOREKEEPER TRAINING - OPTIMIZED FOR 129GB RAM")
    print("=" * 80)

    device = torch.device("cuda")

    # NOTE(review): batch_size is only reported below — the loop actually
    # feeds one example at a time, so the true effective batch is
    # gradient_accumulation, not batch_size * gradient_accumulation.
    # Confirm whether real batching was intended here.
    batch_size = 8
    gradient_accumulation = 4
    effective_batch = batch_size * gradient_accumulation

    print(f"\nGPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print(f"System RAM: {psutil.virtual_memory().total / 1e9:.1f} GB")
    print(f"Batch size: {batch_size}")
    print(f"Gradient accumulation: {gradient_accumulation}")
    print(f"Effective batch size: {effective_batch}")

    print("\n1. Loading SHOREKEEPER model...")
    model = SHOREKEEPER()
    model = model.to(device)

    params = sum(p.numel() for p in model.parameters())
    print(f" Parameters: {params:,} ({params/1e9:.1f}B)")

    print("\n2. Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # GPT-2 ships without a pad token; reuse EOS so padding="max_length" works.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.model_max_length = 1024

    print("\n3. Loading training data...")
    data_path = Path("./data/7b_150gb/7b_train.jsonl")

    if not data_path.exists():
        print(" ❌ No data found! Run download script first.")
        return

    data = []
    with open(data_path, 'r') as f:
        for line in f:
            data.append(json.loads(line))

    print(f" Loaded {len(data):,} examples")

    # Create the checkpoint directory up front so torch.save cannot fail
    # hours into training with a missing-directory error.
    output_dir = Path("./outputs")
    output_dir.mkdir(parents=True, exist_ok=True)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=3e-4,
        weight_decay=0.1,
        betas=(0.9, 0.95)
    )

    # NOTE(review): GradScaler targets fp16; with bfloat16 autocast the
    # scaling is mathematically a pass-through. Kept to preserve behavior,
    # but it could be dropped (or constructed with enabled=False).
    scaler = torch.amp.GradScaler('cuda')

    print("\n4. Starting training...")
    print(" Training will take 1-2 weeks")

    epochs = 3
    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")

        random.shuffle(data)
        total_loss = 0.0
        steps = 0
        accumulated = 0  # backward passes since the last optimizer step
        optimizer.zero_grad()

        pbar = tqdm(data, desc="Training")

        for item in pbar:
            text = item.get('text', '')
            if not text or len(text) < 50:
                continue

            inputs = tokenizer(
                text[:2048],
                return_tensors="pt",
                truncation=True,
                max_length=1024,
                padding="max_length"
            )
            input_ids = inputs['input_ids'].to(device)

            with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                logits = model(input_ids)
                # Next-token prediction: shift logits/labels by one position.
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = input_ids[..., 1:].contiguous()
                # NOTE(review): pad_token_id == eos_token_id here, so genuine
                # EOS positions are also excluded from the loss — confirm
                # that is intended.
                loss = nn.functional.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                    ignore_index=tokenizer.pad_token_id
                )

            # Fix: scale the loss by 1/gradient_accumulation so the summed
            # gradients equal one step over the effective batch; previously
            # gradients (and thus the effective LR) were 4x too large.
            scaler.scale(loss / gradient_accumulation).backward()

            total_loss += loss.item()  # report the unscaled per-example loss
            steps += 1
            accumulated += 1

            # Fix: count actual backward passes rather than the raw dataset
            # index — skipped short texts previously desynchronized the
            # accumulation schedule.
            if accumulated == gradient_accumulation:
                _apply_optimizer_step(scaler, optimizer, model)
                accumulated = 0

            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'avg': f'{total_loss/steps:.4f}'
            })

            if steps % 5000 == 0:
                torch.save(model.state_dict(), output_dir / f"checkpoint_step_{steps}.pt")
                print(f"\n 💾 Checkpoint saved")

        # Fix: flush leftover accumulated gradients instead of discarding
        # them at epoch end.
        if accumulated:
            _apply_optimizer_step(scaler, optimizer, model)

        # Guard against an epoch in which every example was skipped.
        avg_loss = total_loss / steps if steps else float('nan')
        print(f"\nEpoch {epoch + 1} complete: Avg Loss = {avg_loss:.4f}")
        torch.save(model.state_dict(), output_dir / f"epoch_{epoch+1}.pt")

    torch.save(model.state_dict(), output_dir / "shorekeeper_7b_final.pt")
    print("\n✅ Training complete!")
|
|
if __name__ == "__main__":
    # psutil is imported at the top of the file (not here), so main() also
    # works when this module is imported rather than run as a script.
    main()
|
|