| |
| """ |
| SHOREKEEPER-4B Training Pipeline |
| Runs on any CUDA device (RTX 3060, H100, etc.) |
| """ |
|
|
| import sys |
| import json |
| import torch |
| import torch.nn as nn |
| from pathlib import Path |
| from tqdm import tqdm |
|
|
| sys.path.insert(0, str(Path(__file__).parent.parent)) |
|
|
| from src.shorekeeper import MemoryEfficientSHOREKEEPER |
| from transformers import AutoTokenizer |
|
|
class SHOREKEEPERTrainer:
    """Simple causal-LM training loop for SHOREKEEPER.

    Each dataset example is a dict with 'prompt' and 'response' keys. The two
    are joined with a newline and trained with next-token cross-entropy.
    Gradients are accumulated over ``gradient_accumulation`` micro-batches
    before each optimizer update.

    Config keys (all optional):
        learning_rate (default 1e-4): AdamW learning rate.
        epochs (default 3): passes over the dataset.
        batch_size (default 2): examples per micro-batch (clamped to >= 1).
        gradient_accumulation (default 4): micro-batches per optimizer update
            (clamped to >= 1).
        scheduler_t_max (default 1000): CosineAnnealingLR ``T_max``, counted in
            optimizer updates.
        checkpoint_every (default 100): save a checkpoint every N micro-batches.
        max_length (default 512): tokenizer truncation length.
    """

    def __init__(self, model, tokenizer, config):
        self.model = model
        self.tokenizer = tokenizer
        # Inputs are moved to whatever device the model's parameters live on.
        self.device = next(model.parameters()).device

        self.learning_rate = config.get('learning_rate', 1e-4)
        self.epochs = config.get('epochs', 3)
        # Clamp to >= 1: zero would break batching and the accumulation check.
        self.batch_size = max(1, int(config.get('batch_size', 2)))
        self.gradient_accumulation = max(1, int(config.get('gradient_accumulation', 4)))
        self.max_length = config.get('max_length', 512)
        self.checkpoint_every = max(1, int(config.get('checkpoint_every', 100)))

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.learning_rate,
            weight_decay=0.01
        )

        # scheduler.step() runs once per optimizer update, so T_max is in updates.
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=config.get('scheduler_t_max', 1000),
            eta_min=1e-6
        )

        self.step = 0       # micro-batches successfully processed, across all epochs
        self._pending = 0   # micro-batches whose gradients await an optimizer update

    def train_step(self, batch):
        """Run forward/backward on one micro-batch and update weights when due.

        Args:
            batch: dict with key 'text' mapping to a list of strings.

        Returns:
            The unscaled mean next-token loss of this micro-batch, as a float.
        """
        self.model.train()

        inputs = self.tokenizer(
            batch['text'],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length
        )
        input_ids = inputs['input_ids'].to(self.device)

        # Exclude padding from the loss via the attention mask rather than by
        # pad id: the pad id may be 0 (falsy) or shared with EOS.
        labels = input_ids.clone()
        attention_mask = inputs.get('attention_mask')
        if attention_mask is not None:
            labels[attention_mask.to(self.device) == 0] = -100
        elif getattr(self.tokenizer, 'pad_token_id', None) is not None:
            labels[labels == self.tokenizer.pad_token_id] = -100

        logits = self.model(input_ids)

        # Position t predicts token t+1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        loss = nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100
        )

        # Scale so the accumulated gradient is an average, not a sum, over
        # the micro-batches of one update.
        (loss / self.gradient_accumulation).backward()
        self._pending += 1
        if self._pending >= self.gradient_accumulation:
            self._apply_gradients()

        self.step += 1

        return loss.item()

    def _apply_gradients(self):
        """Clip gradients, take one optimizer/scheduler step, and clear gradients."""
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        self.scheduler.step()
        self.optimizer.zero_grad()
        self._pending = 0

    @staticmethod
    def _format_example(item):
        """Join an example's prompt and response with a newline; None if unusable."""
        if not isinstance(item, dict):
            return None
        prompt = item.get('prompt', '')
        response = item.get('response', '')
        if not prompt or not response:
            return None
        return f"{prompt}\n{response}"

    def _iter_batches(self, items):
        """Yield lists of formatted texts with at most ``batch_size`` entries each."""
        texts = []
        for item in items:
            text = self._format_example(item)
            if text is None:
                continue
            texts.append(text)
            if len(texts) >= self.batch_size:
                yield texts
                texts = []
        if texts:
            yield texts

    def train(self, dataset, output_dir="./outputs"):
        """Train for ``self.epochs`` epochs and save checkpoints under ``output_dir``.

        Saves a periodic checkpoint every ``checkpoint_every`` micro-batches,
        one checkpoint per epoch, and the final state dict.

        Args:
            dataset: sized iterable of example dicts with 'prompt' and
                'response' keys. Unusable examples are skipped.
            output_dir: checkpoint directory, created if missing.

        Returns:
            The trained model (the same object passed to ``__init__``).
        """
        print(f"\n{'='*60}")
        print("Starting Training")
        print(f"{'='*60}")
        print(f"Device: {self.device}")
        print(f"Training samples: {len(dataset)}")
        print(f"Batch size: {self.batch_size}")
        print(f"Learning rate: {self.learning_rate}")
        print(f"Epochs: {self.epochs}")
        print(f"{'='*60}\n")

        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        for epoch in range(self.epochs):
            print(f"\nEpoch {epoch + 1}/{self.epochs}")
            print("-" * 40)

            total_loss = 0.0
            steps = 0
            errors = 0

            # tqdm wraps the raw examples, so the bar tracks examples consumed.
            pbar = tqdm(dataset, desc="Training")

            for texts in self._iter_batches(pbar):
                try:
                    loss = self.train_step({'text': texts})
                except Exception as e:
                    # Deliberate best-effort: skip a failing batch and keep
                    # training, but count failures so they are not hidden.
                    errors += 1
                    if errors <= 5:
                        print(f"\n Error on step {steps}: {e}")
                    continue

                total_loss += loss
                steps += 1
                pbar.set_postfix({'loss': f'{loss:.4f}'})

                # Use the global step so later epochs don't overwrite the
                # periodic checkpoints of earlier ones.
                if self.step % self.checkpoint_every == 0:
                    checkpoint_path = output_dir / f"checkpoint_step_{self.step}.pt"
                    torch.save({
                        'step': self.step,
                        'model_state': self.model.state_dict(),
                        'optimizer_state': self.optimizer.state_dict(),
                        'loss': loss
                    }, checkpoint_path)
                    print(f"\n Saved checkpoint: {checkpoint_path}")

            # Apply gradients left in a partial accumulation window, so the
            # epoch's final micro-batches are not dropped.
            if self._pending:
                self._apply_gradients()

            if errors:
                print(f"\n Skipped {errors} batch(es) due to errors")

            avg_loss = total_loss / steps if steps > 0 else 0
            print(f"\nEpoch {epoch + 1} complete: Avg Loss = {avg_loss:.4f}")

            epoch_path = output_dir / f"epoch_{epoch + 1}.pt"
            torch.save({
                'epoch': epoch + 1,
                'model_state': self.model.state_dict(),
                'optimizer_state': self.optimizer.state_dict(),
                'avg_loss': avg_loss
            }, epoch_path)
            print(f"Saved epoch checkpoint: {epoch_path}")

        final_path = output_dir / "shorekeeper-4b-final.pt"
        torch.save(self.model.state_dict(), final_path)
        print(f"\n{'='*60}")
        print(f"✅ Training complete! Final model saved to: {final_path}")
        print(f"{'='*60}")

        return self.model
|
|
def load_data(data_path, limit=None):
    """Load training examples from a JSONL file.

    Blank lines, lines that are not valid JSON, and JSON values that are not
    objects are skipped. The number of skipped lines is reported.

    Args:
        data_path: path (str or Path) to a UTF-8 encoded JSONL file.
        limit: maximum number of examples to return; None (or 0) loads all.

    Returns:
        A list of example dicts. Empty if the file does not exist.
    """
    data = []
    data_path = Path(data_path)

    if not data_path.exists():
        print(f"Data file not found: {data_path}")
        return data

    skipped = 0
    # JSONL is UTF-8; don't depend on the platform's default encoding.
    with open(data_path, 'r', encoding='utf-8') as f:
        for line in f:
            # Count kept examples, not raw lines, so malformed or blank lines
            # don't eat into the requested number of examples.
            if limit and len(data) >= limit:
                break
            line = line.strip()
            if not line:
                continue
            try:
                item = json.loads(line)
            except json.JSONDecodeError:
                skipped += 1
                continue
            # The trainer reads examples with .get(), so only objects are usable.
            if not isinstance(item, dict):
                skipped += 1
                continue
            data.append(item)

    if skipped:
        print(f"Skipped {skipped} malformed line(s) in {data_path}")
    print(f"Loaded {len(data)} examples from {data_path}")
    return data
|
|
def _choose_training_preset():
    """Prompt for a training preset and return (limit, epochs, learning_rate).

    Re-prompts on unrecognized input instead of silently falling through to
    the multi-hour full run. Returns None if stdin is closed (EOF).
    """
    presets = {
        "1": (50, 1, 1e-4),
        "2": (200, 3, 5e-5),
        "3": (500, 5, 3e-5),
        "4": (None, 10, 1e-5),
    }

    print("\n Training options:")
    print(" [1] Quick test (50 examples, 1 epoch) - ~2 minutes")
    print(" [2] Small training (200 examples, 3 epochs) - ~10 minutes")
    print(" [3] Medium training (500 examples, 5 epochs) - ~30 minutes")
    print(" [4] Full training (all data, 10 epochs) - several hours")

    while True:
        try:
            choice = input("\nChoose option (1/2/3/4): ").strip()
        except EOFError:
            return None
        if choice in presets:
            return presets[choice]
        print(f" Invalid option {choice!r}; please enter 1, 2, 3 or 4.")


def main():
    """Interactive entry point: choose a preset, load data and model, then train."""
    print("=" * 60)
    print("SHOREKEEPER-4B Training Pipeline")
    print("=" * 60)

    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"\n✓ CUDA available: {torch.cuda.get_device_name(0)}")
        print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        device = torch.device("cpu")
        print("\n⚠ No GPU detected, using CPU (will be slow)")

    # Do the cheap checks (tokenizer, data) first, so a missing dependency
    # fails fast instead of after the 4B model has been loaded onto the device.
    print("\n1. Loading tokenizer...")
    try:
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
    except Exception as e:  # download/cache failures surface as varied types
        print(" ⚠ Could not load GPT-2 tokenizer")
        print(f"   {e}")
        return
    # GPT-2 ships without a pad token; padding=True in the trainer needs one.
    tokenizer.pad_token = tokenizer.eos_token
    print(" ✓ Using GPT-2 tokenizer")

    print("\n2. Loading training data...")
    data_path = Path("./data/processed/train.jsonl")

    if not data_path.exists():
        print(f"\n❌ No training data found at {data_path}")
        print(" Run: python3 scripts/01_download_data.py")
        print(" Then: python3 scripts/02_prepare_data.py")
        return

    preset = _choose_training_preset()
    if preset is None:
        print("\n❌ No training option selected; exiting.")
        return
    limit, epochs, learning_rate = preset

    data = load_data(data_path, limit=limit)

    if not data:
        print("\n❌ No training data available!")
        return

    print(f"\n Training with {len(data)} examples, {epochs} epochs")
    print(f" Learning rate: {learning_rate}")

    print("\n3. Loading SHOREKEEPER model...")
    model = MemoryEfficientSHOREKEEPER(use_4bit=False)
    model = model.to(device)
    print(f" Model loaded on {device}")

    config = {
        'learning_rate': learning_rate,
        'epochs': epochs,
        'batch_size': 2,
        'gradient_accumulation': 4
    }

    print("\n4. Initializing trainer...")
    trainer = SHOREKEEPERTrainer(model, tokenizer, config)

    print("\n5. Starting training...")
    print(" Press Ctrl+C to stop early\n")

    try:
        trainer.train(data, output_dir="./outputs")
    except KeyboardInterrupt:
        print("\n\n⚠ Training interrupted by user")
        print("Saving current model...")
        # The interrupt may arrive before train() has created the directory.
        Path("./outputs").mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), "./outputs/shorekeeper-interrupted.pt")
        print("Model saved to: ./outputs/shorekeeper-interrupted.pt")
    except Exception as e:
        print(f"\n❌ Training failed: {e}")
        import traceback
        traceback.print_exc()
        # Don't suggest follow-up pipeline stages when there is no model.
        return

    print("\n" + "=" * 60)
    print("Next steps:")
    print(" 1. Run GRPO training: python3 scripts/05_grpo_train.py")
    print(" 2. Convert to 4-bit: python3 scripts/06_convert_to_4bit.py")
    print(" 3. Run SHOREKEEPER: python3 scripts/07_run_shorekeeper.py")
    print("=" * 60)
|
|
# Script entry point: run the interactive training pipeline only when this
# file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
|
|