import json
import os
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
from torch.optim import Adam
from tqdm import tqdm

from src.data import Batch, make_loaders
from src.model import CRNN


def load_vocab(path: str = "models/vocab.json") -> Dict[str, int]:
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def invert_vocab(stoi: Dict[str, int]) -> Dict[int, str]:
    return {i: ch for ch, i in stoi.items()}


def greedy_decode(log_probs: torch.Tensor, itos: Dict[int, str]) -> List[str]:
    """
    log_probs: [T, B, C]
    Greedy CTC decode: argmax at each timestep, collapse repeats, drop the blank (index 0).
    """
    preds = torch.argmax(log_probs, dim=-1)  # [T, B]
    _, B = preds.shape
    out = []
    for b in range(B):
        seq = preds[:, b].tolist()
        collapsed = []
        prev = None
        for p in seq:
            if p == prev:
                continue
            if p != 0:  # skip blank
                collapsed.append(p)
            prev = p
        # Blanks were already dropped during the collapse above.
        out.append("".join(itos[i] for i in collapsed if i in itos))
    return out


def run_epoch(
    model: nn.Module,
    loader,
    optimizer,
    criterion,
    device: torch.device,
    train: bool,
    itos=None,
) -> Tuple[float, List[Tuple[str, str]]]:
    model.train(train)
    total_loss = 0.0
    n_batches = 0
    samples_preview = []
    for batch in tqdm(loader, desc="train" if train else "val"):
        batch: Batch
        images = batch.images.to(device)
        targets = batch.targets.to(device)
        target_lengths = batch.target_lengths.to(device)

        # Skip autograd bookkeeping during validation.
        with torch.set_grad_enabled(train):
            log_probs, input_lengths = model(images)  # log_probs [T,B,C], input_lengths [B]
            # CTCLoss expects:
            #   log_probs:      [T, B, C]
            #   targets:        [sum(target_lengths)]
            #   input_lengths:  [B]
            #   target_lengths: [B]
            loss = criterion(log_probs, targets, input_lengths, target_lengths)

        if train:
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()

        total_loss += loss.item()
        n_batches += 1

        # Collect a small preview of decoded predictions on the val set.
        if (not train) and itos is not None and len(samples_preview) < 3:
            preds = greedy_decode(log_probs.detach().cpu(), itos)
            for gt, pr in zip(batch.texts, preds):
                if len(samples_preview) < 3:
                    samples_preview.append((gt, pr))

    return total_loss / max(n_batches, 1), samples_preview


def main():
    os.makedirs("models", exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)

    # make_loaders builds the vocab and saves it to models/vocab.json.
    train_loader, val_loader, stoi = make_loaders(batch_size=8, img_height=64, num_workers=0)
    itos = invert_vocab(stoi)

    # stoi is assumed to already reserve index 0 for the CTC blank
    # (it is built by make_loaders, which also sets the class count).
    model = CRNN(num_classes=len(stoi)).to(device)
    criterion = nn.CTCLoss(blank=0, zero_infinity=True)
    optimizer = Adam(model.parameters(), lr=1e-3)

    best_val = float("inf")
    epochs = 10
    for epoch in range(1, epochs + 1):
        print(f"\nEpoch {epoch}/{epochs}")
        train_loss, _ = run_epoch(model, train_loader, optimizer, criterion, device, train=True)
        val_loss, preview = run_epoch(model, val_loader, optimizer, criterion, device, train=False, itos=itos)
        print(f"Train loss: {train_loss:.4f} | Val loss: {val_loss:.4f}")

        if preview:
            print("\nPreview (GT -> Pred):")
            for gt, pr in preview:
                print("GT :", gt)
                print("PR :", pr)
                print("-" * 40)

        # Save the best checkpoint by validation loss.
        if val_loss < best_val:
            best_val = val_loss
            ckpt = {
                "model_state": model.state_dict(),
                "stoi": stoi,
                "img_height": 64,
            }
            torch.save(ckpt, "models/best.pt")
            print("✅ Saved models/best.pt")

    print("\nDone. Best val loss:", best_val)


if __name__ == "__main__":
    main()
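
# A quick, self-contained check of greedy_decode (illustrative only; the
# itos mapping below is made up and is not the trained vocab). An argmax
# path of [1, 1, 0, 2, 2] over a [T=5, B=1, C=3] tensor should collapse the
# repeats, drop the blank, and decode to "ab":
#
#   path = torch.tensor([1, 1, 0, 2, 2])
#   lp = torch.full((5, 1, 3), -10.0)
#   lp[torch.arange(5), 0, path] = 0.0  # make `path` the argmax at each step
#   assert greedy_decode(lp, {1: "a", 2: "b"}) == ["ab"]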