File size: 4,148 Bytes
73400c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#!/usr/bin/env python3
"""
Clean SHOREKEEPER training on STEM data only
"""

import sys
import json
import torch
import torch.nn as nn
from pathlib import Path
from tqdm import tqdm
import random

sys.path.insert(0, str(Path(__file__).parent.parent))

from src.shorekeeper import SHOREKEEPER
from transformers import AutoTokenizer

def main():
    """Train SHOREKEEPER from scratch on the STEM JSONL dataset.

    Loads ./data/stem/stem_train.jsonl (one JSON object per line with
    'prompt' and 'response' keys), runs 5 epochs of next-token causal-LM
    training with gradient accumulation, and writes per-epoch plus final
    checkpoints under ./outputs/.
    """
    print("=" * 70)
    print("SHOREKEEPER - STEM TRAINING")
    print("=" * 70)

    # Check device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nDevice: {device}")

    # Load model (fresh from scratch)
    print("\n1. Loading SHOREKEEPER model...")
    model = SHOREKEEPER()
    model = model.to(device)
    model.train()  # make dropout/normalization behavior explicit
    print(f"   Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Load tokenizer
    print("\n2. Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token
    print("   ✓ GPT-2 tokenizer")

    # Load STEM data
    print("\n3. Loading STEM training data...")
    data_path = Path("./data/stem/stem_train.jsonl")

    if not data_path.exists():
        print("   ❌ No STEM data found!")
        print("   Run: python3 scripts/01_download_stem_data.py")
        return

    data = []
    with open(data_path, 'r') as f:
        for line in f:
            data.append(json.loads(line))

    # BUGFIX: an existing-but-empty file previously caused ZeroDivisionError
    # at the end-of-epoch average; bail out early instead.
    if not data:
        print("   ❌ STEM data file is empty!")
        return

    print(f"   Loaded {len(data):,} examples")

    # Training config
    # NOTE(review): batch_size is reported below but the loop feeds one
    # example per forward pass; the virtual batch is gradient_accumulation
    # micro-steps. Kept for report compatibility — TODO confirm intent.
    batch_size = 2
    gradient_accumulation = 8
    learning_rate = 3e-4

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.1)

    # BUGFIX: torch.save below fails with FileNotFoundError if ./outputs
    # does not already exist.
    Path("./outputs").mkdir(parents=True, exist_ok=True)

    print("\n4. Training configuration:")
    print(f"   Examples: {len(data):,}")
    print(f"   Learning rate: {learning_rate}")
    print(f"   Batch size: {batch_size}")
    print(f"   Gradient accumulation: {gradient_accumulation}")
    print(f"   Effective batch size: {batch_size * gradient_accumulation}")

    # Training loop
    epochs = 5
    print(f"\n5. Training for {epochs} epochs...")

    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")

        # Shuffle data each epoch so accumulation groups differ
        random.shuffle(data)

        total_loss = 0.0
        steps = 0
        optimizer.zero_grad()

        pbar = tqdm(data, desc="Training")

        for i, item in enumerate(pbar):
            # Format prompt + response into a single training sequence
            text = f"{item['prompt']}\n{item['response']}"

            # Tokenize (single sequence: truncation only, no padding occurs)
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            input_ids = inputs['input_ids'].to(device)

            # Forward
            logits = model(input_ids)

            # Next-token prediction: logits at position t predict token t+1
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = input_ids[..., 1:].contiguous()
            # BUGFIX: ignore_index=tokenizer.pad_token_id was dropped — since
            # pad == eos for GPT-2 it masked every EOS target even though no
            # padding token is ever produced here, so the model could never
            # learn to emit end-of-sequence.
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

            # BUGFIX: scale by the accumulation factor so the accumulated
            # gradient averages (not sums) over the virtual batch.
            (loss / gradient_accumulation).backward()

            total_loss += loss.item()
            steps += 1

            # Update weights once per full accumulation window
            if (i + 1) % gradient_accumulation == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()

            # Update progress bar
            pbar.set_postfix({'loss': f'{loss.item():.4f}', 'avg': f'{total_loss/steps:.4f}'})

        # BUGFIX: flush a trailing partial accumulation window so stale
        # gradients don't leak into the next epoch when
        # len(data) % gradient_accumulation != 0.
        if len(data) % gradient_accumulation != 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

        avg_loss = total_loss / steps
        print(f"   Epoch {epoch + 1} complete: Avg Loss = {avg_loss:.4f}")

        # Save checkpoint
        torch.save(model.state_dict(), f"./outputs/shorekeeper_stem_epoch_{epoch+1}.pt")
        print(f"   Saved: outputs/shorekeeper_stem_epoch_{epoch+1}.pt")

    # Final save
    torch.save(model.state_dict(), "./outputs/shorekeeper_stem_final.pt")
    print("\n✅ Training complete!")
    print("   Final model: outputs/shorekeeper_stem_final.pt")

# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()