import os
import json
from typing import Dict, List, Tuple
import torch
import torch.nn as nn
from torch.optim import Adam
from tqdm import tqdm
from src.data import make_loaders, Batch
from src.model import CRNN
def load_vocab(path: str = "models/vocab.json") -> Dict[str, int]:
    """Read the character-to-index vocabulary previously saved as JSON."""
    with open(path, "r", encoding="utf-8") as handle:
        vocab = json.load(handle)
    return vocab
def invert_vocab(stoi: Dict[str, int]) -> Dict[int, str]:
    """Swap a char -> index mapping into an index -> char mapping."""
    return dict((index, char) for char, index in stoi.items())
def greedy_decode(log_probs: torch.Tensor, itos: Dict[int, str]) -> List[str]:
    """Greedy CTC decoding.

    Args:
        log_probs: [T, B, C] per-timestep class log-probabilities.
        itos: index-to-character mapping; index 0 is the CTC blank.

    Returns:
        One decoded string per batch element: argmax at each timestep,
        consecutive repeats collapsed, blanks (index 0) removed.
    """
    preds = torch.argmax(log_probs, dim=-1)  # [T, B]
    _, batch_size = preds.shape
    out: List[str] = []
    for b in range(batch_size):
        prev = None
        chars: List[str] = []
        for p in preds[:, b].tolist():
            if p != prev and p != 0:
                # .get keeps the original behavior of silently dropping
                # indices that are missing from the vocabulary.
                chars.append(itos.get(p, ""))
            # Blanks also update prev, so a repeated character separated
            # by a blank ("a-a") decodes to "aa", per CTC semantics.
            prev = p
        out.append("".join(chars))
    return out
def run_epoch(model: nn.Module,
              loader,
              optimizer,
              criterion,
              device: torch.device,
              train: bool,
              itos=None) -> Tuple[float, List[Tuple[str, str]]]:
    """Run one full pass over `loader`.

    Args:
        model: returns (log_probs [T,B,C], input_lengths [B]) for a batch
            of images — exactly the layout nn.CTCLoss expects.
        loader: iterable of Batch objects (images, targets, target_lengths,
            texts).
        optimizer: used only when `train` is True.
        criterion: CTC loss taking (log_probs, targets, input_lengths,
            target_lengths).
        device: device the batch tensors are moved to.
        train: True for a training pass (backward + step), False for eval.
        itos: index->char map; when given on a validation pass, up to three
            (ground-truth, prediction) pairs are collected for preview.

    Returns:
        (mean loss over batches, preview pairs — empty when training).
    """
    model.train(train)
    total_loss = 0.0
    n_batches = 0
    samples_preview: List[Tuple[str, str]] = []
    # Disable autograd on validation passes: no backward() happens there,
    # so building the graph would only waste time and memory.
    with torch.set_grad_enabled(train):
        for batch in tqdm(loader, desc="train" if train else "val"):
            images = batch.images.to(device)
            targets = batch.targets.to(device)
            target_lengths = batch.target_lengths.to(device)
            log_probs, input_lengths = model(images)  # [T,B,C], [B]
            # CTCLoss expects:
            #   log_probs:      [T, B, C]
            #   targets:        [sum(target_lengths)]
            #   input_lengths:  [B]
            #   target_lengths: [B]
            loss = criterion(log_probs, targets, input_lengths, target_lengths)
            if train:
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                # Clip gradients — RNN + CTC training is prone to spikes.
                torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
                optimizer.step()
            total_loss += loss.item()
            n_batches += 1
            # Collect a small (GT, prediction) preview on validation.
            if (not train) and itos is not None and len(samples_preview) < 3:
                preds = greedy_decode(log_probs.detach().cpu(), itos)
                for gt, pr in zip(batch.texts, preds):
                    if len(samples_preview) < 3:
                        samples_preview.append((gt, pr))
    # max(..., 1) guards against an empty loader dividing by zero.
    return total_loss / max(n_batches, 1), samples_preview
def main(epochs: int = 10,
         batch_size: int = 8,
         img_height: int = 64,
         lr: float = 1e-3) -> None:
    """Train the CRNN with CTC loss, checkpointing the best model by val loss.

    Args:
        epochs: number of training epochs.
        batch_size: samples per batch for both loaders.
        img_height: input image height; recorded in the checkpoint so
            inference resizes identically (single source of truth — the
            value is no longer duplicated as a literal in two places).
        lr: Adam learning rate.
    """
    os.makedirs("models", exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)
    # make_loaders builds the vocab and saves models/vocab.json as a side effect.
    train_loader, val_loader, stoi = make_loaders(
        batch_size=batch_size, img_height=img_height, num_workers=0
    )
    itos = invert_vocab(stoi)
    model = CRNN(num_classes=len(stoi)).to(device)
    # blank=0 matches greedy_decode; zero_infinity avoids inf losses when an
    # input sequence is shorter than its target.
    criterion = nn.CTCLoss(blank=0, zero_infinity=True)
    optimizer = Adam(model.parameters(), lr=lr)
    best_val = float("inf")
    for epoch in range(1, epochs + 1):
        print(f"\nEpoch {epoch}/{epochs}")
        train_loss, _ = run_epoch(model, train_loader, optimizer, criterion,
                                  device, train=True)
        val_loss, preview = run_epoch(model, val_loader, optimizer, criterion,
                                      device, train=False, itos=itos)
        print(f"Train loss: {train_loss:.4f} | Val loss: {val_loss:.4f}")
        if preview:
            print("\nPreview (GT -> Pred):")
            for gt, pr in preview:
                print("GT :", gt)
                print("PR :", pr)
                print("-" * 40)
        # Save only when validation improves.
        if val_loss < best_val:
            best_val = val_loss
            ckpt = {
                "model_state": model.state_dict(),
                "stoi": stoi,
                "img_height": img_height,
            }
            torch.save(ckpt, "models/best.pt")
            print("✅ Saved models/best.pt")
    print("\nDone. Best val loss:", best_val)


if __name__ == "__main__":
    main()
|