| |
| """ |
| Clean SHOREKEEPER training on STEM data only |
| """ |
|
|
| import sys |
| import json |
| import torch |
| import torch.nn as nn |
| from pathlib import Path |
| from tqdm import tqdm |
| import random |
|
|
| sys.path.insert(0, str(Path(__file__).parent.parent)) |
|
|
| from src.shorekeeper import SHOREKEEPER |
| from transformers import AutoTokenizer |
|
|
def main():
    """Train SHOREKEEPER on STEM data and checkpoint each epoch.

    Reads JSONL examples (``prompt`` / ``response`` fields) from
    ``./data/stem/stem_train.jsonl``, runs a causal-LM loop with
    gradient accumulation, and writes per-epoch plus final
    ``state_dict`` checkpoints to ``./outputs/``.
    """
    print("=" * 70)
    print("SHOREKEEPER - STEM TRAINING")
    print("=" * 70)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nDevice: {device}")

    print("\n1. Loading SHOREKEEPER model...")
    model = SHOREKEEPER()
    model = model.to(device)
    model.train()  # be explicit: dropout etc. must be in training mode
    print(f"   Parameters: {sum(p.numel() for p in model.parameters()):,}")

    print("\n2. Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token
    print("   ✓ GPT-2 tokenizer")

    print("\n3. Loading STEM training data...")
    data_path = Path("./data/stem/stem_train.jsonl")

    if not data_path.exists():
        print("   ❌ No STEM data found!")
        print("   Run: python3 scripts/01_download_stem_data.py")
        return

    data = []
    with open(data_path, 'r') as f:
        for line in f:
            data.append(json.loads(line))

    print(f"   Loaded {len(data):,} examples")

    # FIX: guard against an empty data file (would ZeroDivisionError later).
    if not data:
        print("   ❌ Training file is empty!")
        return

    # FIX: create the checkpoint directory up front instead of crashing
    # at the first torch.save() when ./outputs does not exist.
    output_dir = Path("./outputs")
    output_dir.mkdir(parents=True, exist_ok=True)

    # FIX: the loop below feeds exactly one example per forward pass, so
    # report batch_size = 1 (the old value of 2 made the printed
    # "effective batch size" wrong).
    batch_size = 1
    gradient_accumulation = 8
    learning_rate = 3e-4

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.1)

    print("\n4. Training configuration:")
    print(f"   Examples: {len(data):,}")
    print(f"   Learning rate: {learning_rate}")
    print(f"   Batch size: {batch_size}")
    print(f"   Gradient accumulation: {gradient_accumulation}")
    print(f"   Effective batch size: {batch_size * gradient_accumulation}")

    epochs = 5
    print(f"\n5. Training for {epochs} epochs...")

    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")

        random.shuffle(data)

        total_loss = 0
        steps = 0
        optimizer.zero_grad()

        pbar = tqdm(data, desc="Training")

        for i, item in enumerate(pbar):
            text = f"{item['prompt']}\n{item['response']}"

            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            input_ids = inputs['input_ids'].to(device)

            logits = model(input_ids)

            # Standard causal-LM shift: predict token t+1 from tokens <= t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = input_ids[..., 1:].contiguous()
            # FIX: no ignore_index here. Each step tokenizes a single
            # sequence without padding, so there are no pad positions to
            # mask — and because pad_token == eos_token, the old
            # ignore_index=pad_token_id silently excluded real EOS tokens
            # from the loss.
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

            # FIX: scale by the accumulation factor so the accumulated
            # gradient equals the mean over the effective batch (the old
            # code effectively multiplied the learning rate by 8).
            (loss / gradient_accumulation).backward()

            total_loss += loss.item()  # report the unscaled per-example loss
            steps += 1

            if (i + 1) % gradient_accumulation == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()

            pbar.set_postfix({'loss': f'{loss.item():.4f}', 'avg': f'{total_loss/steps:.4f}'})

        # FIX: flush the leftover partial accumulation window; the old code
        # dropped these gradients at the next epoch's zero_grad().
        if len(data) % gradient_accumulation != 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

        avg_loss = total_loss / steps
        print(f"   Epoch {epoch + 1} complete: Avg Loss = {avg_loss:.4f}")

        torch.save(model.state_dict(), output_dir / f"shorekeeper_stem_epoch_{epoch+1}.pt")
        print(f"   Saved: outputs/shorekeeper_stem_epoch_{epoch+1}.pt")

    torch.save(model.state_dict(), output_dir / "shorekeeper_stem_final.pt")
    print("\n✅ Training complete!")
    print("   Final model: outputs/shorekeeper_stem_final.pt")
|
|
# Script entry guard: run training only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
|