# lakshmi-charan's picture
# Upload 15 files
# 2411029 verified
import os
import json
from typing import Dict, List, Tuple
import torch
import torch.nn as nn
from torch.optim import Adam
from tqdm import tqdm
from src.data import make_loaders, Batch
from src.model import CRNN
def load_vocab(path: str = "models/vocab.json") -> Dict[str, int]:
    """Load the char-to-index vocabulary previously saved as JSON."""
    with open(path, encoding="utf-8") as fh:
        vocab = json.load(fh)
    return vocab
def invert_vocab(stoi: Dict[str, int]) -> Dict[int, str]:
    """Flip a char -> index mapping into an index -> char mapping."""
    return dict((index, char) for char, index in stoi.items())
def greedy_decode(log_probs: torch.Tensor, itos: Dict[int, str]) -> List[str]:
    """Greedy CTC decode of a batch of log-probabilities.

    log_probs: [T, B, C] tensor. Takes the argmax class at each timestep,
    collapses consecutive repeats, drops the blank symbol (index 0), and
    maps the surviving indices through `itos` (unknown indices are ignored).

    Returns one decoded string per batch element.
    """
    best = torch.argmax(log_probs, dim=-1)  # [T, B]
    decoded: List[str] = []
    for column in best.unbind(dim=1):
        chars: List[str] = []
        last = None
        for idx in column.tolist():
            if idx == last:
                continue  # collapse runs of the same label
            last = idx
            if idx != 0 and idx in itos:
                chars.append(itos[idx])
        decoded.append("".join(chars))
    return decoded
def run_epoch(model: nn.Module,
              loader,
              optimizer,
              criterion,
              device: torch.device,
              train: bool,
              itos=None) -> Tuple[float, List[Tuple[str, str]]]:
    """Run one full pass over `loader`.

    Args:
        model: maps images -> (log_probs [T,B,C], input_lengths [B]).
        loader: iterable of batches exposing .images, .targets,
            .target_lengths and .texts (ground-truth strings).
        optimizer: used only when `train` is True.
        criterion: CTC-style loss called as
            criterion(log_probs, targets, input_lengths, target_lengths).
        device: device tensors are moved to.
        train: True -> backprop + optimizer step; False -> eval pass.
        itos: optional index->char map; when given on an eval pass, a small
            (ground_truth, prediction) preview is collected.

    Returns:
        (mean loss over batches, up to 3 (gt, prediction) pairs).
    """
    model.train(train)
    total_loss = 0.0
    n_batches = 0
    samples_preview: List[Tuple[str, str]] = []
    # Fix: the original always built autograd graphs, even on validation
    # passes; disabling grad tracking when not training saves memory/time.
    with torch.set_grad_enabled(train):
        for batch in tqdm(loader, desc=("train" if train else "val")):
            images = batch.images.to(device)
            targets = batch.targets.to(device)
            target_lengths = batch.target_lengths.to(device)
            # log_probs [T,B,C] and input_lengths [B] are exactly the layout
            # nn.CTCLoss expects (targets are the concatenated label ids).
            log_probs, input_lengths = model(images)
            loss = criterion(log_probs, targets, input_lengths, target_lengths)
            if train:
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                # Clip to tame CTC's occasionally large gradients.
                torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
                optimizer.step()
            total_loss += loss.item()
            n_batches += 1
            # Collect a tiny qualitative preview on validation only.
            if (not train) and itos is not None and len(samples_preview) < 3:
                preds = greedy_decode(log_probs.detach().cpu(), itos)
                for gt, pr in zip(batch.texts, preds):
                    if len(samples_preview) < 3:
                        samples_preview.append((gt, pr))
    return total_loss / max(n_batches, 1), samples_preview
def main():
    """Train the CRNN OCR model with CTC loss, checkpointing the best epoch."""
    os.makedirs("models", exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)
    # make_loaders builds the vocab as a side effect and saves models/vocab.json.
    train_loader, val_loader, stoi = make_loaders(batch_size=8, img_height=64, num_workers=0)
    itos = invert_vocab(stoi)
    net = CRNN(num_classes=len(stoi)).to(device)
    ctc = nn.CTCLoss(blank=0, zero_infinity=True)
    opt = Adam(net.parameters(), lr=1e-3)
    best_val = float("inf")
    epochs = 10
    for epoch in range(1, epochs + 1):
        print(f"\nEpoch {epoch}/{epochs}")
        tr_loss, _ = run_epoch(net, train_loader, opt, ctc, device, train=True)
        va_loss, preview = run_epoch(net, val_loader, opt, ctc, device, train=False, itos=itos)
        print(f"Train loss: {tr_loss:.4f} | Val loss: {va_loss:.4f}")
        if preview:
            print("\nPreview (GT -> Pred):")
            for truth, guess in preview:
                print("GT :", truth)
                print("PR :", guess)
                print("-" * 40)
        # Keep only the checkpoint with the lowest validation loss so far.
        if va_loss < best_val:
            best_val = va_loss
            torch.save(
                {
                    "model_state": net.state_dict(),
                    "stoi": stoi,
                    "img_height": 64,
                },
                "models/best.pt",
            )
            print("✅ Saved models/best.pt")
    print("\nDone. Best val loss:", best_val)


if __name__ == "__main__":
    main()