tefoteknik committed on
Commit
d7149cd
·
verified ·
1 Parent(s): 3560b46

Update AGIFORMER with Turkish benchmark

Browse files
Files changed (1) hide show
  1. train_turkish.py +194 -0
train_turkish.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Developer: inkbytefo
2
+ ## Modified: 2025-11-22
3
+
4
+ """
5
+ Kaşgarlı Testi - Turkish Wikipedia Benchmark
6
+ Hypothesis: Byte-level models learn agglutinative languages more efficiently.
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.optim as optim
13
+ from src.models.agiformer import AGIFORMER
14
+ from src.data.turkish_wiki import get_turkish_wiki_dataloader
15
+ import time
16
+ import json
17
+ import os
18
+
19
# ---------------------------------------------------------------------------
# Configuration (IDENTICAL to English training, so the two runs are comparable)
# ---------------------------------------------------------------------------

# Architecture hyper-parameters; each is forwarded by keyword to AGIFORMER(...).
D_MODEL = 512
N_LAYERS = 6
NUM_HEADS = 8
PATCH_SIZE = 4
WINDOW_SIZE = 128
THINKING_STEPS = 3

# Data / optimisation settings.
BATCH_SIZE = 4        # passed to the dataloader factory as batch_size
SEQ_LEN = 1024        # passed to the dataloader factory as seq_len
MAX_STEPS = 5000      # training stops once this many optimizer steps are done
BASE_LR = 3e-4        # AdamW learning rate reached at the end of warmup
WARMUP_STEPS = 100    # number of steps over which the LR ramps up linearly
GRAD_CLIP = 0.5       # max gradient norm given to clip_grad_norm_

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
35
+
36
def train_turkish():
    """
    Train AGIFORMER on Turkish Wikipedia.
    Logs metrics for comparison with English baseline.

    The run performs at most MAX_STEPS optimizer steps (and at most 100
    passes over the training loader), with a linear learning-rate warmup
    over the first WARMUP_STEPS steps and gradient-norm clipping at
    GRAD_CLIP.

    Files written to the current working directory:
        best_model_turkish.pth  -- state_dict with the lowest validation loss
        last_model_turkish.pth  -- state_dict at the end of training
        metrics_turkish.json    -- {"train_bpc", "steps", "val_bpc", "val_steps"}
            "steps" indexes "train_bpc"; "val_steps" indexes "val_bpc".

    Returns:
        None
    """
    print("=" * 60)
    print("KAŞGARLI TESTİ - Turkish Wikipedia Benchmark")
    print("=" * 60)

    # Model (same architecture as the English run)
    model = AGIFORMER(
        d_model=D_MODEL,
        n_layers=N_LAYERS,
        num_heads=NUM_HEADS,
        patch_size=PATCH_SIZE,
        window_size=WINDOW_SIZE,
        thinking_steps=THINKING_STEPS
    ).to(DEVICE)

    print(f"Model: {sum(p.numel() for p in model.parameters()):,} parameters")
    print(f"Device: {DEVICE}")

    # Data
    train_loader = get_turkish_wiki_dataloader(
        batch_size=BATCH_SIZE,
        seq_len=SEQ_LEN,
        split="train"
    )

    val_loader = get_turkish_wiki_dataloader(
        batch_size=BATCH_SIZE,
        seq_len=SEQ_LEN,
        split="val"
    )

    # Optimizer
    optimizer = optim.AdamW(model.parameters(), lr=BASE_LR)
    criterion = nn.CrossEntropyLoss()

    # ln(2), computed once: dividing a natural-log cross-entropy by it
    # converts nats to bits.
    ln2 = torch.log(torch.tensor(2.0)).item()

    log_every = 10
    validate_every = 200

    # Training loop state
    model.train()
    step = 0
    best_val_loss = float('inf')

    # Metrics log. "val_steps" records the step of every "val_bpc" entry so
    # validation points can be placed on the same axis as the training curve.
    metrics = {"train_bpc": [], "val_bpc": [], "steps": [], "val_steps": []}

    start_time = time.time()

    for _ in range(100):  # Enough epochs to reach MAX_STEPS
        for input_seq, target_seq in train_loader:
            if step >= MAX_STEPS:
                break

            input_seq = input_seq.to(DEVICE)
            target_seq = target_seq.to(DEVICE)

            # Learning rate warmup: ramps linearly up to BASE_LR, which is
            # reached on the last warmup step and then left untouched.
            if step < WARMUP_STEPS:
                lr = BASE_LR * (step + 1) / WARMUP_STEPS
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

            # Forward
            logits = model(input_seq, target_bytes=target_seq)

            # Loss: flatten every leading dim so each row is one prediction.
            # The 4-way unpack also enforces that logits is 4-D.
            _, _, _, vocab_size = logits.shape
            loss = criterion(
                logits.contiguous().view(-1, vocab_size),
                target_seq.contiguous().view(-1)
            )

            # Skip the batch on NaN *or* Inf: back-propagating either would
            # poison the weights. `step` is intentionally not advanced, so a
            # skipped batch does not count toward MAX_STEPS.
            if not torch.isfinite(loss):
                print(f"⚠️ NaN/Inf detected at step {step}! Skipping batch...")
                continue

            loss_value = loss.item()
            bpc = loss_value / ln2

            # Backward
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            optimizer.step()

            # Log
            current_lr = optimizer.param_groups[0]['lr']
            if step % log_every == 0:
                print(f"Step {step}: Loss = {loss_value:.4f} | BPC = {bpc:.4f} | LR = {current_lr:.2e}")
                metrics["train_bpc"].append(bpc)
                metrics["steps"].append(step)

            # Validation
            if step % validate_every == 0 and step > 0:
                val_loss, val_bpc = validate(model, val_loader, criterion)
                print(f"-- VALIDATION: Loss = {val_loss:.4f} | BPC = {val_bpc:.4f} --")

                metrics["val_bpc"].append(val_bpc)
                metrics["val_steps"].append(step)

                # Save best
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    torch.save(model.state_dict(), "best_model_turkish.pth")
                    print("Saved best model (Turkish).")

                # validate() switches the model to eval mode; switch back.
                model.train()

            step += 1

        if step >= MAX_STEPS:
            break

    # Save final
    print("Saving last model state...")
    torch.save(model.state_dict(), "last_model_turkish.pth")
    print("Saved last_model_turkish.pth")

    # Save metrics
    with open("metrics_turkish.json", "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    elapsed = time.time() - start_time
    print(f"Training finished in {elapsed:.2f}s")
    # The tracked value is the best (lowest) validation loss, not the last
    # one, and it is still +inf when validation never produced a finite loss.
    if best_val_loss < float('inf'):
        print(f"Best validation BPC: {best_val_loss / ln2:.4f}")
    else:
        print("No finite validation loss was recorded.")
162
+
163
def validate(model, val_loader, criterion, max_batches=50):
    """
    Validation loop: average the loss over the first batches of a loader.

    Args:
        model: Module called as ``model(input_seq, target_bytes=target_seq)``.
            Its output must be a 4-D tensor whose last dimension is the
            class dimension handed to ``criterion``.
        val_loader: Iterable yielding ``(input_seq, target_seq)`` tensor pairs.
        criterion: Callable ``criterion(logits_2d, targets_1d)`` returning a
            scalar loss tensor.
        max_batches: Maximum number of batches to evaluate (must be >= 1).
            Defaults to 50, the limit that used to be hard-coded.

    Returns:
        tuple[float, float]: ``(avg_loss, bpc)`` where ``avg_loss`` is the
        unweighted mean of the per-batch losses and ``bpc`` is that mean
        divided by ln(2). Both are ``float('inf')`` when ``val_loader``
        yields no batch, so a caller comparing against its best loss simply
        sees "no improvement" instead of a ZeroDivisionError.

    Raises:
        ValueError: If ``max_batches`` is smaller than 1.

    Note:
        The model is left in eval mode; the caller is responsible for
        calling ``model.train()`` before resuming training.
    """
    if max_batches < 1:
        raise ValueError(f"max_batches must be >= 1, got {max_batches}")

    model.eval()

    # Move batches to wherever the model's weights live; fall back to the
    # module-level DEVICE only for a model that has no parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else DEVICE

    total_loss = 0.0
    count = 0

    with torch.no_grad():
        for input_seq, target_seq in val_loader:
            input_seq = input_seq.to(device)
            target_seq = target_seq.to(device)

            logits = model(input_seq, target_bytes=target_seq)

            # The 4-way unpack enforces the expected 4-D logits shape.
            _, _, _, vocab_size = logits.shape
            loss = criterion(
                logits.reshape(-1, vocab_size),
                target_seq.reshape(-1)
            )

            total_loss += loss.item()
            count += 1

            if count >= max_batches:  # Limit validation batches
                break

    if count == 0:
        # Empty loader: nothing was measured.
        return float('inf'), float('inf')

    avg_loss = total_loss / count
    bpc = avg_loss / torch.log(torch.tensor(2.0)).item()

    return avg_loss, bpc
192
+
193
# Script entry point: start the Turkish benchmark only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    train_turkish()