#!/usr/bin/env python3
"""
SHOREKEEPER Universal Training Script

Works on: RTX 3060, RTX 5090, H100, A100, Mac MPS, CPU
Auto-detects hardware and optimizes accordingly.
"""

import sys
import json
import torch
import torch.nn as nn
from pathlib import Path
from tqdm import tqdm
import random
import yaml
import platform
import psutil

sys.path.insert(0, str(Path(__file__).parent.parent))

from src.shorekeeper import SHOREKEEPER
from transformers import AutoTokenizer

# All checkpoints and final models are written under this directory.
# torch.save does not create parent directories, so it is created on demand.
OUTPUT_DIR = Path("./outputs")


def detect_hardware():
    """Auto-detect the best available device and choose matching settings.

    Preference order is CUDA, then Apple MPS, then CPU. On CUDA the per-step
    batch size, gradient-accumulation factor and autocast precision are picked
    from the total memory of GPU 0.

    Returns:
        dict with keys:
            'device': torch.device to train on.
            'batch_size': recommended examples per step.
            'gradient_accumulation': micro-steps per optimizer step.
            'precision': "bfloat16", "float16" or "float32".
            'gpu_memory': total memory of GPU 0 in GB (0 without CUDA).
    """
    print("\n" + "=" * 70)
    print("HARDWARE DETECTION")
    print("=" * 70)

    gpu_mem = 0

    # Check CUDA
    if torch.cuda.is_available():
        device = torch.device("cuda")
        gpu_name = torch.cuda.get_device_name(0)
        gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
        cuda_version = torch.version.cuda
        print(f"✓ CUDA GPU: {gpu_name}")
        print(f" Memory: {gpu_mem:.1f} GB")
        print(f" CUDA Version: {cuda_version}")

        # Optimize batch size based on GPU memory
        if gpu_mem >= 80:  # H100/A100
            recommended_batch = 8
            recommended_accum = 4
            precision = "bfloat16"
        elif gpu_mem >= 32:  # RTX 5090, A6000
            recommended_batch = 4
            recommended_accum = 8
            precision = "bfloat16"
        elif gpu_mem >= 16:  # RTX 4080, 4090
            recommended_batch = 2
            recommended_accum = 8
            precision = "float16"
        elif gpu_mem >= 12:  # RTX 3060, 3070, 3080
            recommended_batch = 1
            recommended_accum = 16
            precision = "float16"
        else:
            recommended_batch = 1
            recommended_accum = 32
            precision = "float16"

    # Check Apple Metal (M1/M2/M3 Macs)
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        print("✓ Apple Metal (M1/M2/M3) detected")
        recommended_batch = 2
        recommended_accum = 4
        # NOTE: UniversalTrainer only autocasts on CUDA, so MPS runs in fp32
        # regardless of this value.
        precision = "float16"
        print(" Note: MPS support is experimental, may need torch nightly")

    # Fallback to CPU
    else:
        device = torch.device("cpu")
        print("⚠ No GPU detected, using CPU (will be very slow)")
        recommended_batch = 1
        recommended_accum = 1
        precision = "float32"

        # Show CPU info
        cpu_count = psutil.cpu_count()
        ram = psutil.virtual_memory().total / 1e9
        print(f" CPU: {cpu_count} cores")
        print(f" RAM: {ram:.1f} GB")

    print("\nRecommended settings:")
    print(f" Batch size: {recommended_batch}")
    print(f" Gradient accumulation: {recommended_accum}")
    print(f" Effective batch size: {recommended_batch * recommended_accum}")
    print(f" Precision: {precision}")

    return {
        'device': device,
        'batch_size': recommended_batch,
        'gradient_accumulation': recommended_accum,
        'precision': precision,
        'gpu_memory': gpu_mem,
    }


def get_model_size(model):
    """Return the total parameter count of *model* in billions."""
    params = sum(p.numel() for p in model.parameters())
    return params / 1e9


class UniversalTrainer:
    """Causal-LM trainer whose precision and accumulation adapt to the hardware.

    Each call to ``train_step`` processes one text. The optimizer steps once
    every ``gradient_accumulation`` calls.

    NOTE: ``batch_size`` from the hardware config is stored and reported, but
    examples are not batched together in ``train_step``.
    """

    def __init__(self, model, tokenizer, hardware_config):
        """Set up optimizer, LR schedule and mixed-precision state.

        Args:
            model: module returning logits for ``model(input_ids)``.
            tokenizer: Hugging Face-style tokenizer with ``pad_token_id``.
            hardware_config: dict as returned by ``detect_hardware``.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = torch.device(hardware_config['device'])
        self.batch_size = hardware_config['batch_size']
        # Guard against 0/negative, which would break the modulo in train_step.
        self.gradient_accumulation = max(1, int(hardware_config['gradient_accumulation']))
        self.precision = hardware_config['precision']

        # Learning rate scales with model size
        model_size = get_model_size(model)
        if model_size < 1:
            base_lr = 5e-4
        elif model_size < 4:
            base_lr = 3e-4
        elif model_size < 8:
            base_lr = 2e-4
        else:
            base_lr = 1e-4
        self.learning_rate = base_lr

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.learning_rate,
            weight_decay=0.1,
            betas=(0.9, 0.95),
        )
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=5000, T_mult=2
        )

        self.step = 0
        self.total_loss = 0

        # Autocast only on a CUDA device. GradScaler is needed only for
        # float16: bfloat16 has float32's exponent range, so its gradients do
        # not underflow the way fp16 gradients can.
        self._amp_dtype = None
        if self.device.type == "cuda":
            self._amp_dtype = {
                "bfloat16": torch.bfloat16,
                "float16": torch.float16,
            }.get(self.precision)
        self.scaler = (
            torch.amp.GradScaler('cuda') if self._amp_dtype is torch.float16 else None
        )

        print("\nTraining configuration:")
        print(f" Device: {self.device}")
        print(f" Learning rate: {self.learning_rate}")
        print(f" Batch size: {self.batch_size}")
        print(f" Gradient accumulation: {self.gradient_accumulation}")
        print(f" Precision: {self.precision}")

    def train_step(self, text):
        """Forward/backward one text; step the optimizer every N calls.

        Args:
            text: raw training text (tokenized and truncated to 512 tokens).

        Returns:
            The un-scaled next-token loss for this text, as a float.
        """
        self.model.train()

        # Tokenize
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding="max_length",
        )
        input_ids = inputs['input_ids'].to(self.device)

        # Mixed precision forward pass
        if self._amp_dtype is not None:
            with torch.autocast(device_type='cuda', dtype=self._amp_dtype):
                logits = self.model(input_ids)
                loss = self._compute_loss(logits, input_ids)
        else:
            logits = self.model(input_ids)
            loss = self._compute_loss(logits, input_ids)

        # Average over the accumulation window so the summed gradient matches
        # one large batch; otherwise it is `gradient_accumulation` times larger.
        scaled_loss = loss / self.gradient_accumulation
        if self.scaler is not None:
            self.scaler.scale(scaled_loss).backward()
        else:
            scaled_loss.backward()

        # Gradient accumulation and optimizer step
        if (self.step + 1) % self.gradient_accumulation == 0:
            if self.scaler is not None:
                # Unscale first so clipping sees the true gradient norm.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()

        self.step += 1
        return loss.item()

    def _compute_loss(self, logits, input_ids):
        """Next-token cross-entropy, ignoring padding positions.

        Logits at position t are scored against the token at t + 1.
        """
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = input_ids[..., 1:].contiguous()
        pad_id = self.tokenizer.pad_token_id
        return nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            # -100 is cross_entropy's default ignore_index, used when the
            # tokenizer defines no pad token.
            ignore_index=pad_id if pad_id is not None else -100,
        )

    @staticmethod
    def _extract_text(item):
        """Return the training text of one record.

        Uses 'text' if present and non-empty, else "prompt\\nresponse".
        Non-dict records yield "" so they are skipped by the caller.
        """
        if not isinstance(item, dict):
            return ""
        text = item.get('text', '')
        if not text:
            text = f"{item.get('prompt', '')}\n{item.get('response', '')}"
        return text

    @staticmethod
    def _save_checkpoint(payload, path):
        """Best-effort torch.save; report I/O failures instead of aborting.

        Returns:
            True if the checkpoint was written, False otherwise.
        """
        try:
            torch.save(payload, path)
        except OSError as e:
            print(f"\n ⚠ Could not save checkpoint to {path}: {e}")
            return False
        return True

    def train(self, data, num_epochs=1, save_every=5000):
        """Full training loop over *data* (shuffled in place each epoch).

        Args:
            data: list of record dicts with 'text' or 'prompt'/'response'.
            num_epochs: number of passes over *data*.
            save_every: save a step checkpoint every this many successful
                examples within an epoch.
        """
        print(f"\n{'=' * 70}")
        print("STARTING TRAINING")
        print(f"{'=' * 70}")
        print(f"Examples: {len(data):,}")
        print(f"Epochs: {num_epochs}")
        print(f"Save checkpoint every {save_every} steps")

        OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch + 1}/{num_epochs}")
            print("-" * 40)

            # Shuffle data
            random.shuffle(data)

            total_loss = 0
            steps = 0
            errors = 0
            self.optimizer.zero_grad()

            pbar = tqdm(data, desc="Training")
            for item in pbar:
                text = self._extract_text(item)
                if not text or len(text) < 10:
                    continue

                # Per-example failures (bad text, tokenizer/model errors) are
                # skipped so one bad record does not stop a long run.
                try:
                    loss = self.train_step(text[:2048])  # Limit length
                except Exception as e:
                    errors += 1
                    if errors <= 10:  # Only print the first few errors
                        print(f"\n ⚠ Error: {e}")
                    continue

                total_loss += loss
                steps += 1

                # Update progress bar
                avg_loss = total_loss / steps
                pbar.set_postfix({
                    'loss': f'{loss:.4f}',
                    'avg': f'{avg_loss:.4f}',
                })

                # Save checkpoint
                if steps % save_every == 0:
                    checkpoint = {
                        'step': self.step,
                        'epoch': epoch + 1,
                        'model_state': self.model.state_dict(),
                        'optimizer_state': self.optimizer.state_dict(),
                        'scheduler_state': self.scheduler.state_dict(),
                        'loss': loss,
                        'avg_loss': avg_loss,
                    }
                    if self._save_checkpoint(
                        checkpoint, f"./outputs/checkpoint_step_{self.step}.pt"
                    ):
                        print(f"\n 💾 Checkpoint saved at step {self.step}")

            avg_loss = total_loss / steps if steps > 0 else 0
            print(f"\nEpoch {epoch + 1} complete: Avg Loss = {avg_loss:.4f}")

            # Save epoch checkpoint
            if self._save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model_state': self.model.state_dict(),
                    'optimizer_state': self.optimizer.state_dict(),
                    'scheduler_state': self.scheduler.state_dict(),
                    'avg_loss': avg_loss,
                },
                f"./outputs/epoch_{epoch + 1}.pt",
            ):
                print(" 💾 Saved epoch checkpoint")


def load_training_data(data_path, max_examples=None):
    """Load JSON-object records from a JSONL file.

    Blank lines, malformed JSON, and records that are not JSON objects are
    skipped.

    Args:
        data_path: path to a ``.jsonl`` file (str or Path).
        max_examples: if truthy, stop once this many records have been loaded.

    Returns:
        List of dicts; empty if the file does not exist.
    """
    data_path = Path(data_path)
    if not data_path.exists():
        return []

    data = []
    # JSON text is UTF-8 by specification; don't depend on the locale.
    with open(data_path, 'r', encoding='utf-8') as f:
        for line in f:
            if max_examples and len(data) >= max_examples:
                break
            line = line.strip()
            if not line:
                continue
            try:
                item = json.loads(line)
            except json.JSONDecodeError:
                continue
            if isinstance(item, dict):
                data.append(item)
    return data


def _prompt_int(prompt, default=None):
    """Read a positive integer from stdin.

    Returns *default* on blank, non-numeric, or non-positive input.
    """
    raw = input(prompt).strip()
    if not raw:
        return default
    try:
        value = int(raw)
    except ValueError:
        print(f" Invalid number {raw!r}, using default")
        return default
    return value if value > 0 else default


def main():
    """Interactive entry point: detect hardware, load model/data, then train."""
    print("=" * 70)
    print("SHOREKEEPER UNIVERSAL TRAINING")
    print("=" * 70)

    # Detect hardware
    hw_config = detect_hardware()
    device = hw_config['device']

    # Check model config
    config_path = "configs/model.yaml"
    if Path("configs/model_15b.yaml").exists():
        print("\n📁 Found 15B config, using that")
        config_path = "configs/model_15b.yaml"

    # Load model
    print("\n1. Loading SHOREKEEPER model...")
    model = SHOREKEEPER(config_path=config_path)
    model = model.to(device)
    model_size = get_model_size(model)
    print(f" Model size: {model_size:.1f}B parameters")
    print(f" Memory usage estimate: {model_size * 4:.1f} GB (fp32)")

    # Load tokenizer
    print("\n2. Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.model_max_length = 512
    print(" ✓ GPT-2 tokenizer")

    # Load data: use the first candidate path that yields any examples.
    print("\n3. Loading training data...")
    data_paths = [
        "./data/15b_data/15b_train.jsonl",
        "./data/stem/stem_train.jsonl",
        "./data/processed/train_large.jsonl",
        "./data/processed/train.jsonl",
    ]
    data = []
    for path in data_paths:
        if Path(path).exists():
            data = load_training_data(path)
            if data:
                print(f" ✓ Loaded {len(data):,} examples from {path}")
                break

    if not data:
        print("\n❌ No training data found!")
        print("\nPlease run one of these first:")
        print(" python3 scripts/01_download_stem_data.py")
        print(" python3 scripts/01_download_15b_data.py")
        return

    # Ask user for training mode
    print("\n" + "=" * 70)
    print("TRAINING OPTIONS")
    print("=" * 70)
    print("1. Quick test (10% of data, 1 epoch)")
    print("2. Standard training (all data, 3 epochs)")
    print("3. Full training (all data, 10 epochs)")
    print("4. Custom (enter your own settings)")

    choice = input("\nChoose option (1-4): ").strip()
    if choice == "1":
        data = data[:max(1000, len(data) // 10)]
        epochs = 1
    elif choice == "2":
        epochs = 3
    elif choice == "3":
        epochs = 10
    elif choice == "4":
        epochs = _prompt_int("Number of epochs: ", default=1)
        limit = _prompt_int("Limit examples (press Enter for all): ", default=None)
        if limit:
            data = data[:limit]
    else:
        epochs = 1

    # Create trainer
    trainer = UniversalTrainer(model, tokenizer, hw_config)
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # Start training
    print(f"\n4. Starting training on {len(data):,} examples for {epochs} epochs...")
    print(" Press Ctrl+C to stop and save checkpoint\n")

    try:
        trainer.train(data, num_epochs=epochs)
    except KeyboardInterrupt:
        print("\n\n⚠ Training interrupted by user")
        print("Saving current model...")
        torch.save(model.state_dict(), "./outputs/shorekeeper_interrupted.pt")
        print("Model saved to: ./outputs/shorekeeper_interrupted.pt")
    except Exception as e:
        print(f"\n❌ Training error: {e}")
        import traceback
        traceback.print_exc()

    # Final save
    final_path = "./outputs/shorekeeper_final.pt"
    torch.save(model.state_dict(), final_path)
    print(f"\n✅ Model saved to: {final_path}")

    print("\n" + "=" * 70)
    print("NEXT STEPS")
    print("=" * 70)
    print("1. Test your model:")
    print(" python3 scripts/07_run_shorekeeper.py")
    print("\n2. Convert to 4-bit for inference:")
    print(" python3 scripts/06_convert_to_4bit.py")
    print("\n3. Run GRPO reasoning training:")
    print(" python3 scripts/05_grpo_train.py")


if __name__ == "__main__":
    main()