"""
SHOREKEEPER Universal Training Script
Works on: RTX 3060, RTX 5090, H100, A100, Mac MPS, CPU
Auto-detects hardware and optimizes accordingly
"""
|
|
import json
import platform
import random
import sys
from pathlib import Path

import psutil
import torch
import torch.nn as nn
import yaml
from tqdm import tqdm

# Make the project root importable before pulling in local modules.
sys.path.insert(0, str(Path(__file__).parent.parent))

from src.shorekeeper import SHOREKEEPER
from transformers import AutoTokenizer
|
|
def detect_hardware():
    """Auto-detect best available device and optimize settings"""

    print("\n" + "=" * 70)
    print("HARDWARE DETECTION")
    print("=" * 70)

    if torch.cuda.is_available():
        device = torch.device("cuda")
        props = torch.cuda.get_device_properties(0)
        gpu_mem = props.total_memory / 1e9
        print(f"✓ CUDA GPU: {torch.cuda.get_device_name(0)}")
        print(f" Memory: {gpu_mem:.1f} GB")
        print(f" CUDA Version: {torch.version.cuda}")

        # Descending VRAM-threshold table: (min GB, batch, accum, precision).
        # First row that fits wins; anything below 12 GB falls through to the
        # most conservative settings.
        tiers = (
            (80, 8, 4, "bfloat16"),
            (32, 4, 8, "bfloat16"),
            (16, 2, 8, "float16"),
            (12, 1, 16, "float16"),
        )
        for min_mem, batch, accum, prec in tiers:
            if gpu_mem >= min_mem:
                recommended_batch, recommended_accum, precision = batch, accum, prec
                break
        else:
            recommended_batch, recommended_accum, precision = 1, 32, "float16"

    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        print("✓ Apple Metal (M1/M2/M3) detected")
        recommended_batch, recommended_accum, precision = 2, 4, "float16"
        print(" Note: MPS support is experimental, may need torch nightly")

    else:
        device = torch.device("cpu")
        print("⚠ No GPU detected, using CPU (will be very slow)")
        recommended_batch, recommended_accum, precision = 1, 1, "float32"

    # Host-side stats are informational only; they do not affect the settings.
    print(f" CPU: {psutil.cpu_count()} cores")
    print(f" RAM: {psutil.virtual_memory().total / 1e9:.1f} GB")

    print(f"\nRecommended settings:")
    print(f" Batch size: {recommended_batch}")
    print(f" Gradient accumulation: {recommended_accum}")
    print(f" Effective batch size: {recommended_batch * recommended_accum}")
    print(f" Precision: {precision}")

    return {
        'device': device,
        'batch_size': recommended_batch,
        'gradient_accumulation': recommended_accum,
        'precision': precision,
        'gpu_memory': gpu_mem if torch.cuda.is_available() else 0
    }
|
|
def get_model_size(model):
    """Calculate model size in billions of parameters"""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total / 1e9
|
|
class UniversalTrainer:
    """Trainer that adapts to any hardware.

    Learning rate is chosen from the model's parameter count, mixed precision
    follows the detected hardware config, and gradients are accumulated over
    ``gradient_accumulation`` micro-steps before each optimizer update.
    """

    def __init__(self, model, tokenizer, hardware_config):
        """
        Args:
            model: torch.nn.Module mapping input_ids -> logits
                (batch, seq, vocab) — assumed shape, TODO confirm against
                SHOREKEEPER's forward().
            tokenizer: HF-style tokenizer; must be callable and expose
                ``pad_token_id``. Only used in train_step/_compute_loss.
            hardware_config: dict from detect_hardware() with keys
                'device', 'batch_size', 'gradient_accumulation', 'precision'.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = hardware_config['device']
        self.batch_size = hardware_config['batch_size']
        self.gradient_accumulation = hardware_config['gradient_accumulation']
        self.precision = hardware_config['precision']

        # Scale the base learning rate down as the model grows (sizes in
        # billions of parameters).
        model_size = get_model_size(model)
        if model_size < 1:
            base_lr = 5e-4
        elif model_size < 4:
            base_lr = 3e-4
        elif model_size < 8:
            base_lr = 2e-4
        else:
            base_lr = 1e-4

        self.learning_rate = base_lr

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.learning_rate,
            weight_decay=0.1,
            betas=(0.9, 0.95)
        )

        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=5000, T_mult=2
        )

        self.step = 0
        self.total_loss = 0

        # FIX: a GradScaler is only needed for float16 autocast. bfloat16 has
        # the same exponent range as float32 and needs no loss scaling, and
        # float32 runs without autocast entirely; the old code created a
        # scaler whenever CUDA was merely present.
        use_scaler = self.precision == "float16" and torch.cuda.is_available()
        self.scaler = torch.amp.GradScaler('cuda') if use_scaler else None

        print(f"\nTraining configuration:")
        print(f" Device: {self.device}")
        print(f" Learning rate: {self.learning_rate}")
        print(f" Batch size: {self.batch_size}")
        print(f" Gradient accumulation: {self.gradient_accumulation}")
        print(f" Precision: {self.precision}")

    def train_step(self, text):
        """Run one micro-step: forward, loss, backward; update weights every
        ``gradient_accumulation`` calls.

        Returns the un-normalized loss for this micro-batch as a float.
        """
        self.model.train()

        # NOTE(review): the attention mask returned by the tokenizer is
        # discarded, so pad positions still attend — presumably acceptable
        # for this model; verify against SHOREKEEPER's forward().
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding="max_length"
        )

        input_ids = inputs['input_ids'].to(self.device)

        if self.precision == "bfloat16" and torch.cuda.is_available():
            with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                logits = self.model(input_ids)
                loss = self._compute_loss(logits, input_ids)
        elif self.precision == "float16" and torch.cuda.is_available():
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                logits = self.model(input_ids)
                loss = self._compute_loss(logits, input_ids)
        else:
            logits = self.model(input_ids)
            loss = self._compute_loss(logits, input_ids)

        # FIX: normalize by the accumulation factor so the summed gradient
        # matches the average over the effective batch. Without this, the
        # gradient magnitude (and hence the effective LR) silently scaled
        # with gradient_accumulation.
        scaled_loss = loss / self.gradient_accumulation

        if self.scaler:
            self.scaler.scale(scaled_loss).backward()
        else:
            scaled_loss.backward()

        # Optimizer/scheduler step only at accumulation boundaries.
        if (self.step + 1) % self.gradient_accumulation == 0:
            if self.scaler:
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.optimizer.step()

            self.scheduler.step()
            self.optimizer.zero_grad()

        self.step += 1
        return loss.item()

    def _compute_loss(self, logits, input_ids):
        """Compute next-token cross-entropy: logits[t] predicts token t+1."""
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = input_ids[..., 1:].contiguous()

        # NOTE(review): pad_token is set to eos_token in main(), so this
        # ignore_index also masks genuine EOS targets — confirm intended.
        return nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=self.tokenizer.pad_token_id
        )

    def train(self, data, num_epochs=1, save_every=5000):
        """Full training loop over ``data`` (list of dicts with 'text' or
        'prompt'/'response'), saving checkpoints to ./outputs/."""
        print(f"\n{'='*70}")
        print(f"STARTING TRAINING")
        print(f"{'='*70}")
        print(f"Examples: {len(data):,}")
        print(f"Epochs: {num_epochs}")
        print(f"Save checkpoint every {save_every} steps")

        # FIX: checkpoints below are written into ./outputs — create it up
        # front instead of crashing on the first torch.save().
        Path("./outputs").mkdir(parents=True, exist_ok=True)

        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch + 1}/{num_epochs}")
            print("-" * 40)

            random.shuffle(data)

            total_loss = 0
            steps = 0
            self.optimizer.zero_grad()

            pbar = tqdm(data, desc=f"Training")

            for i, item in enumerate(pbar):
                # Accept either a raw 'text' field or a prompt/response pair.
                text = item.get('text', '')
                if not text:
                    text = f"{item.get('prompt', '')}\n{item.get('response', '')}"

                if not text or len(text) < 10:
                    continue

                try:
                    # Cap raw text length before tokenization (tokenizer
                    # truncates to 512 tokens anyway).
                    loss = self.train_step(text[:2048])
                    total_loss += loss
                    steps += 1

                    avg_loss = total_loss / steps
                    pbar.set_postfix({
                        'loss': f'{loss:.4f}',
                        'avg': f'{avg_loss:.4f}'
                    })

                    if steps % save_every == 0:
                        checkpoint = {
                            'step': self.step,
                            'epoch': epoch + 1,
                            'model_state': self.model.state_dict(),
                            'optimizer_state': self.optimizer.state_dict(),
                            'loss': loss,
                            'avg_loss': avg_loss
                        }
                        torch.save(checkpoint, f"./outputs/checkpoint_step_{self.step}.pt")
                        print(f"\n 💾 Checkpoint saved at step {self.step}")

                except Exception as e:
                    # Best-effort: skip bad examples; only surface errors
                    # early so the progress bar is not flooded.
                    if steps < 10:
                        print(f"\n ⚠ Error: {e}")
                    continue

            avg_loss = total_loss / steps if steps > 0 else 0
            print(f"\nEpoch {epoch + 1} complete: Avg Loss = {avg_loss:.4f}")

            torch.save({
                'epoch': epoch + 1,
                'model_state': self.model.state_dict(),
                'optimizer_state': self.optimizer.state_dict(),
                'avg_loss': avg_loss
            }, f"./outputs/epoch_{epoch + 1}.pt")
            print(f" 💾 Saved epoch checkpoint")
|
|
def load_training_data(data_path, max_examples=None):
    """Load training examples from a JSONL file.

    Args:
        data_path: path to a .jsonl file (one JSON object per line).
        max_examples: optional cap on the number of lines read; ``None``
            means read everything.

    Returns:
        List of parsed objects; ``[]`` if the file does not exist.
        Lines that are not valid JSON are silently skipped.
    """
    data_path = Path(data_path)
    if not data_path.exists():
        return []

    data = []
    with open(data_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            # FIX: `max_examples and ...` treated a limit of 0 as "no limit";
            # compare against None explicitly.
            if max_examples is not None and i >= max_examples:
                break
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError:
                # FIX: was a bare `except:` which also swallowed
                # KeyboardInterrupt/SystemExit; skip only malformed JSON.
                continue

    return data
|
|
def main():
    """Interactive entry point: detect hardware, load model/tokenizer/data,
    ask for a training preset, train, and save the final weights."""
    print("=" * 70)
    print("SHOREKEEPER UNIVERSAL TRAINING")
    # FIX: this line was `print="=" * 70)` — a SyntaxError that made the
    # entire script unrunnable.
    print("=" * 70)

    hw_config = detect_hardware()
    device = hw_config['device']

    # Prefer the 15B config when it exists.
    config_path = "configs/model.yaml"
    if Path("configs/model_15b.yaml").exists():
        print("\n📁 Found 15B config, using that")
        config_path = "configs/model_15b.yaml"

    print("\n1. Loading SHOREKEEPER model...")
    model = SHOREKEEPER(config_path=config_path)
    model = model.to(device)

    model_size = get_model_size(model)
    print(f" Model size: {model_size:.1f}B parameters")
    print(f" Memory usage estimate: {model_size * 4:.1f} GB (fp32)")

    print("\n2. Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # GPT-2 ships without a pad token; reuse EOS so padding works.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.model_max_length = 512
    print(" ✓ GPT-2 tokenizer")

    print("\n3. Loading training data...")

    # First existing, non-empty dataset wins.
    data_paths = [
        "./data/15b_data/15b_train.jsonl",
        "./data/stem/stem_train.jsonl",
        "./data/processed/train_large.jsonl",
        "./data/processed/train.jsonl"
    ]

    data = []
    for path in data_paths:
        if Path(path).exists():
            data = load_training_data(path)
            if data:
                print(f" ✓ Loaded {len(data):,} examples from {path}")
                break

    if not data:
        print("\n❌ No training data found!")
        print("\nPlease run one of these first:")
        print(" python3 scripts/01_download_stem_data.py")
        print(" python3 scripts/01_download_15b_data.py")
        return

    print("\n" + "=" * 70)
    print("TRAINING OPTIONS")
    print("=" * 70)
    print(f"1. Quick test (10% of data, 1 epoch)")
    print(f"2. Standard training (all data, 3 epochs)")
    print(f"3. Full training (all data, 10 epochs)")
    print(f"4. Custom (enter your own settings)")

    choice = input("\nChoose option (1-4): ").strip()

    if choice == "1":
        # At least 1000 examples even for tiny datasets.
        data = data[:max(1000, len(data) // 10)]
        epochs = 1
    elif choice == "2":
        epochs = 3
    elif choice == "3":
        epochs = 10
    elif choice == "4":
        epochs = int(input("Number of epochs: ").strip())
        limit = input("Limit examples (press Enter for all): ").strip()
        if limit:
            data = data[:int(limit)]
    else:
        epochs = 1

    # FIX: everything below saves into ./outputs — create it up front so
    # checkpoints and the final/interrupted saves cannot fail on a missing
    # directory.
    Path("./outputs").mkdir(parents=True, exist_ok=True)

    trainer = UniversalTrainer(model, tokenizer, hw_config)

    print(f"\n4. Starting training on {len(data):,} examples for {epochs} epochs...")
    print(" Press Ctrl+C to stop and save checkpoint\n")

    try:
        trainer.train(data, num_epochs=epochs)
    except KeyboardInterrupt:
        print("\n\n⚠ Training interrupted by user")
        print("Saving current model...")
        torch.save(model.state_dict(), "./outputs/shorekeeper_interrupted.pt")
        print("Model saved to: ./outputs/shorekeeper_interrupted.pt")
    except Exception as e:
        print(f"\n❌ Training error: {e}")
        import traceback
        traceback.print_exc()

    # Always attempt a final save, even after an error/interrupt above.
    final_path = "./outputs/shorekeeper_final.pt"
    torch.save(model.state_dict(), final_path)
    print(f"\n✅ Model saved to: {final_path}")

    print("\n" + "=" * 70)
    print("NEXT STEPS")
    print("=" * 70)
    print("1. Test your model:")
    print(" python3 scripts/07_run_shorekeeper.py")
    print("\n2. Convert to 4-bit for inference:")
    print(" python3 scripts/06_convert_to_4bit.py")
    print("\n3. Run GRPO reasoning training:")
    print(" python3 scripts/05_grpo_train.py")
|
|
# Script entry point: run training only when executed directly.
if __name__ == "__main__":
    main()
|
|