| |
| |
|
|
| """ |
| Kaşgarlı Testi - Turkish Wikipedia Benchmark |
| Hypothesis: Byte-level models learn agglutinative languages more efficiently. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.optim as optim |
| from src.models.agiformer import AGIFORMER |
| from src.data.turkish_wiki import get_turkish_wiki_dataloader |
| import time |
| import json |
| import os |
|
|
# --- Model hyperparameters ---
D_MODEL = 512          # embedding / hidden dimension
N_LAYERS = 6           # number of transformer layers
NUM_HEADS = 8          # attention heads per layer
PATCH_SIZE = 4         # bytes per patch (byte-level model)
WINDOW_SIZE = 128      # attention window size — presumably local attention; confirm against AGIFORMER
THINKING_STEPS = 3     # iterative refinement passes — semantics defined by AGIFORMER

# --- Training hyperparameters ---
BATCH_SIZE = 4
SEQ_LEN = 1024         # bytes per training sequence
MAX_STEPS = 5000       # optimizer steps before training stops
BASE_LR = 3e-4         # peak learning rate after warmup
WARMUP_STEPS = 100     # linear LR warmup duration (in steps)
GRAD_CLIP = 0.5        # max gradient norm for clipping

# Run on GPU when available, otherwise fall back to CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
def train_turkish():
    """
    Train AGIFORMER on Turkish Wikipedia (byte-level language modeling).

    Runs up to MAX_STEPS optimizer steps with linear LR warmup, mixed
    precision on CUDA, gradient-norm clipping, and validation every 200
    steps. Saves the best checkpoint (by validation loss) to
    ``best_model_turkish.pth``, the final weights to
    ``last_model_turkish.pth``, and train/val bits-per-character curves
    to ``metrics_turkish.json``.
    """
    print("=" * 60)
    print("KAŞGARLI TESTİ - Turkish Wikipedia Benchmark")
    print("=" * 60)

    model = AGIFORMER(
        d_model=D_MODEL,
        n_layers=N_LAYERS,
        num_heads=NUM_HEADS,
        patch_size=PATCH_SIZE,
        window_size=WINDOW_SIZE,
        thinking_steps=THINKING_STEPS
    ).to(DEVICE)

    print(f"Model: {sum(p.numel() for p in model.parameters()):,} parameters")
    print(f"Device: {DEVICE}")

    # Local import kept here to preserve the original lazy-load behavior.
    from src.data.clean_turkish_data import get_clean_loader

    train_loader = get_clean_loader(
        data_dir="./data",
        batch_size=BATCH_SIZE,
        seq_len=SEQ_LEN,
        split="train"
    )

    val_loader = get_clean_loader(
        data_dir="./data",
        batch_size=BATCH_SIZE,
        seq_len=SEQ_LEN,
        split="val"
    )

    optimizer = optim.AdamW(model.parameters(), lr=BASE_LR)
    # FIX: disable the grad scaler off-GPU — loss scaling is only meaningful
    # together with CUDA autocast (on CPU it would silently scale for nothing).
    scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == 'cuda'))
    criterion = nn.CrossEntropyLoss()

    model.train()
    step = 0
    best_val_loss = float('inf')

    # FIX: also record the step at which each validation ran ("val_steps"),
    # so val_bpc can be aligned with the training curve when plotting.
    metrics = {"train_bpc": [], "val_bpc": [], "steps": [], "val_steps": []}

    # Hoisted nats -> bits conversion constant (was re-allocated every step).
    ln2 = torch.log(torch.tensor(2.0)).item()

    start_time = time.time()

    for epoch in range(100):
        for batch_idx, (input_seq, target_seq) in enumerate(train_loader):
            if step >= MAX_STEPS:
                break

            input_seq = input_seq.to(DEVICE)
            target_seq = target_seq.to(DEVICE)

            # Linear LR warmup over the first WARMUP_STEPS steps.
            if step < WARMUP_STEPS:
                lr = BASE_LR * (step + 1) / WARMUP_STEPS
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

            optimizer.zero_grad()

            with torch.cuda.amp.autocast(enabled=(DEVICE == 'cuda')):
                logits = model(input_seq, target_bytes=target_seq)

                # logits: (batch, patches, patch_size, vocab). Flatten to
                # (batch*patches*patch_size, vocab) for token-level CE.
                # NOTE(review): assumes patches * patch_size == SEQ_LEN so
                # the flattened targets line up — confirm against AGIFORMER.
                vocab_size = logits.shape[-1]
                loss = criterion(
                    logits.contiguous().view(-1, vocab_size),
                    target_seq.contiguous().view(-1)
                )

            # Skip divergent batches instead of poisoning the weights.
            if torch.isnan(loss):
                print(f"⚠️ NaN detected at step {step}! Skipping batch...")
                continue

            # Bits per character (byte): cross-entropy in nats / ln(2).
            bpc = loss.item() / ln2

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)  # unscale before clipping real grads
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            scaler.step(optimizer)
            scaler.update()

            current_lr = optimizer.param_groups[0]['lr']
            if step % 10 == 0:
                print(f"Step {step}: Loss = {loss.item():.4f} | BPC = {bpc:.4f} | LR = {current_lr:.2e}")
                metrics["train_bpc"].append(bpc)
                metrics["steps"].append(step)

            # Periodic validation + best-checkpoint saving.
            if step % 200 == 0 and step > 0:
                val_loss, val_bpc = validate(model, val_loader, criterion)
                print(f"-- VALIDATION: Loss = {val_loss:.4f} | BPC = {val_bpc:.4f} --")

                metrics["val_bpc"].append(val_bpc)
                metrics["val_steps"].append(step)

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    torch.save(model.state_dict(), "best_model_turkish.pth")
                    print("Saved best model (Turkish).")

                model.train()  # validate() leaves the model in eval mode

            step += 1

        if step >= MAX_STEPS:
            break

    print("Saving last model state...")
    torch.save(model.state_dict(), "last_model_turkish.pth")
    print("Saved last_model_turkish.pth")

    with open("metrics_turkish.json", "w") as f:
        json.dump(metrics, f, indent=2)

    elapsed = time.time() - start_time
    print(f"Training finished in {elapsed:.2f}s")
    # FIX: guard against the case where validation never ran (best_val_loss
    # still +inf) instead of printing a meaningless number.
    if best_val_loss != float('inf'):
        print(f"Final validation BPC: {best_val_loss / ln2:.4f}")
    else:
        print("Final validation BPC: n/a (no validation step was reached)")
|
|
def validate(model, val_loader, criterion):
    """
    Evaluate ``model`` on up to 50 batches from ``val_loader``.

    Args:
        model: network invoked as ``model(input, target_bytes=target)``
            returning logits whose last dimension is the vocabulary size.
        val_loader: iterable yielding ``(input_seq, target_seq)`` batches.
        criterion: loss over flattened ``(tokens, vocab)`` logits vs targets.

    Returns:
        ``(avg_loss, bpc)`` — mean cross-entropy in nats and the equivalent
        bits-per-character. Returns ``(inf, inf)`` if the loader is empty.

    Note:
        Leaves the model in eval mode; the caller must restore train mode.
    """
    model.eval()
    total_loss = 0.0
    count = 0

    with torch.no_grad():
        for input_seq, target_seq in val_loader:
            input_seq = input_seq.to(DEVICE)
            target_seq = target_seq.to(DEVICE)

            logits = model(input_seq, target_bytes=target_seq)

            # Flatten (batch, patches, patch_size, vocab) -> (tokens, vocab).
            vocab_size = logits.shape[-1]
            loss = criterion(
                logits.contiguous().view(-1, vocab_size),
                target_seq.contiguous().view(-1)
            )

            total_loss += loss.item()
            count += 1

            # Cap validation cost at 50 batches.
            if count >= 50:
                break

    # FIX: avoid ZeroDivisionError when the loader yields no batches.
    if count == 0:
        return float('inf'), float('inf')

    avg_loss = total_loss / count
    bpc = avg_loss / torch.log(torch.tensor(2.0)).item()

    return avg_loss, bpc
|
|
# Script entry point: launch the Turkish Wikipedia training run.
if __name__ == "__main__":
    train_turkish()
|
|