|
|
import os
|
|
|
import json
|
|
|
from typing import Dict, List, Tuple
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from torch.optim import Adam
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
from src.data import make_loaders, Batch
|
|
|
from src.model import CRNN
|
|
|
|
|
|
|
|
|
def load_vocab(path: str = "models/vocab.json") -> Dict[str, int]:
    """Read a character -> index vocabulary mapping from a JSON file."""
    with open(path, "r", encoding="utf-8") as handle:
        vocab = json.load(handle)
    return vocab
|
|
|
|
|
|
|
|
|
def invert_vocab(stoi: Dict[str, int]) -> Dict[int, str]:
    """Flip a char -> index mapping into the index -> char direction."""
    return dict(zip(stoi.values(), stoi.keys()))
|
|
|
|
|
|
|
|
|
def greedy_decode(log_probs: torch.Tensor, itos: Dict[int, str]) -> List[str]:
    """Greedy CTC decoding of a batch of per-timestep log-probabilities.

    log_probs: [T, B, C] tensor. For each sequence in the batch, take the
    argmax class at every timestep, collapse consecutive repeats, and drop
    the blank label (index 0). Indices missing from ``itos`` are skipped.
    Returns one decoded string per batch element.
    """
    best_path = log_probs.argmax(dim=-1)  # [T, B]
    num_steps, batch_size = best_path.shape
    decoded: List[str] = []
    for col in range(batch_size):
        labels = best_path[:, col].tolist()
        kept: List[int] = []
        last = None
        for label in labels:
            # Keep a label only when it starts a new run and is not blank.
            if label != last and label != 0:
                kept.append(label)
            last = label
        decoded.append("".join(itos[lab] for lab in kept if lab in itos))
    return decoded
|
|
|
|
|
|
|
|
|
def run_epoch(model: nn.Module,
              loader,
              optimizer,
              criterion,
              device: torch.device,
              train: bool,
              itos=None) -> Tuple[float, List[Tuple[str, str]]]:
    """Run one full pass over ``loader`` in train or eval mode.

    Args:
        model: network mapping images -> (log_probs [T, B, C], input_lengths).
        loader: iterable of ``Batch`` objects (images, targets, target_lengths, texts).
        optimizer: stepped only when ``train`` is True (ignored otherwise).
        criterion: CTC-style loss called as
            criterion(log_probs, targets, input_lengths, target_lengths).
        device: device batch tensors are moved to.
        train: True for a gradient-update pass, False for evaluation.
        itos: optional index -> char map; when given and ``train`` is False,
            up to 3 (ground-truth, prediction) pairs are collected for preview.

    Returns:
        (mean loss over batches, preview list of (gt, pred) string pairs).
    """
    model.train(train)
    total_loss = 0.0
    n_batches = 0
    samples_preview: List[Tuple[str, str]] = []

    # Fix: disable autograd during validation. The original ran the eval
    # forward pass with grad enabled, building a graph that is never
    # backpropagated through — wasted memory and compute.
    with torch.set_grad_enabled(train):
        for batch in tqdm(loader, desc=("train" if train else "val")):
            batch: Batch
            images = batch.images.to(device)
            targets = batch.targets.to(device)
            target_lengths = batch.target_lengths.to(device)

            log_probs, input_lengths = model(images)
            loss = criterion(log_probs, targets, input_lengths, target_lengths)

            if train:
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                # Clip gradients: CTC training is prone to exploding grads.
                torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
                optimizer.step()

            total_loss += loss.item()
            n_batches += 1

            # Collect a few qualitative samples for the validation log.
            if (not train) and itos is not None and len(samples_preview) < 3:
                preds = greedy_decode(log_probs.detach().cpu(), itos)
                for gt, pr in zip(batch.texts, preds):
                    if len(samples_preview) < 3:
                        samples_preview.append((gt, pr))

    # max(..., 1) guards against an empty loader (avoids ZeroDivisionError).
    return total_loss / max(n_batches, 1), samples_preview
|
|
|
|
|
|
|
|
|
def _show_preview(pairs: List[Tuple[str, str]]) -> None:
    # Print a few ground-truth / prediction pairs from the validation pass.
    print("\nPreview (GT -> Pred):")
    for truth, guess in pairs:
        print("GT :", truth)
        print("PR :", guess)
        print("-" * 40)


def main():
    """Train the CRNN with CTC loss, checkpointing the best validation model."""
    os.makedirs("models", exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)

    train_loader, val_loader, stoi = make_loaders(batch_size=8, img_height=64, num_workers=0)
    itos = invert_vocab(stoi)

    model = CRNN(num_classes=len(stoi)).to(device)
    # blank=0 matches greedy_decode; zero_infinity guards degenerate CTC losses.
    criterion = nn.CTCLoss(blank=0, zero_infinity=True)
    optimizer = Adam(model.parameters(), lr=1e-3)

    epochs = 10
    best_val = float("inf")
    for epoch in range(1, epochs + 1):
        print(f"\nEpoch {epoch}/{epochs}")

        train_loss, _ = run_epoch(model, train_loader, optimizer, criterion, device, train=True)
        val_loss, preview = run_epoch(model, val_loader, optimizer, criterion, device, train=False, itos=itos)

        print(f"Train loss: {train_loss:.4f} | Val loss: {val_loss:.4f}")
        if preview:
            _show_preview(preview)

        # Keep only the checkpoint with the lowest validation loss so far.
        if val_loss < best_val:
            best_val = val_loss
            torch.save(
                {
                    "model_state": model.state_dict(),
                    "stoi": stoi,
                    "img_height": 64,
                },
                "models/best.pt",
            )
            print("✅ Saved models/best.pt")

    print("\nDone. Best val loss:", best_val)
|
|
|
|
|
|
|
|
|
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|
|
|