File size: 4,307 Bytes
2411029
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import os
import json
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
from torch.optim import Adam
from tqdm import tqdm

from src.data import make_loaders, Batch
from src.model import CRNN


def load_vocab(path: str = "models/vocab.json") -> Dict[str, int]:
    """Read the character-to-index vocabulary from a JSON file."""
    with open(path, "r", encoding="utf-8") as fh:
        vocab = json.load(fh)
    return vocab


def invert_vocab(stoi: Dict[str, int]) -> Dict[int, str]:
    """Flip a char->index mapping into an index->char mapping."""
    itos: Dict[int, str] = {}
    for ch, idx in stoi.items():
        itos[idx] = ch
    return itos


def greedy_decode(log_probs: torch.Tensor, itos: Dict[int, str]) -> List[str]:
    """Greedy CTC decoding.

    log_probs: [T, B, C] per-timestep log-probabilities.

    Takes the argmax class at each timestep, merges consecutive
    duplicates, drops the blank symbol (index 0), and maps the
    surviving indices to characters via ``itos``. Indices absent
    from ``itos`` are silently skipped.
    """
    best = log_probs.argmax(dim=-1)  # [T, B]
    decoded: List[str] = []
    for column in best.unbind(dim=1):  # one [T] column per batch item
        chars: List[str] = []
        prev = None
        for idx in column.tolist():
            if idx != prev and idx != 0 and idx in itos:
                chars.append(itos[idx])
            prev = idx
        decoded.append("".join(chars))
    return decoded


def run_epoch(model: nn.Module,
              loader,
              optimizer,
              criterion,
              device: torch.device,
              train: bool,
              itos=None) -> Tuple[float, List[Tuple[str, str]]]:
    """Run one full pass over ``loader``.

    Args:
        model: module returning ``(log_probs [T,B,C], input_lengths [B])``.
        loader: iterable of batches with ``images``, ``targets``,
            ``target_lengths`` and ``texts`` attributes.
        optimizer: used only when ``train`` is True.
        criterion: CTC loss taking (log_probs, targets, input_lengths,
            target_lengths).
        device: device the batch tensors are moved to.
        train: True to backprop/step, False for evaluation.
        itos: optional index->char map; when given on a validation pass,
            up to 3 (ground-truth, prediction) pairs are collected.

    Returns:
        Tuple of (mean loss over batches, preview sample pairs).
    """
    model.train(train)
    total_loss = 0.0
    n_batches = 0
    samples_preview: List[Tuple[str, str]] = []

    # Fix: validation previously ran the forward pass with autograd
    # enabled, building a graph that was never backpropagated — wasted
    # time and memory. Disable gradient tracking when not training.
    grad_ctx = torch.enable_grad() if train else torch.no_grad()

    with grad_ctx:
        for batch in tqdm(loader, desc=("train" if train else "val")):
            images = batch.images.to(device)
            targets = batch.targets.to(device)
            target_lengths = batch.target_lengths.to(device)

            # log_probs: [T,B,C]; input_lengths: [B]
            log_probs, input_lengths = model(images)

            # CTCLoss expects:
            #   log_probs:      [T, B, C]
            #   targets:        [sum(target_lengths)]
            #   input_lengths:  [B]
            #   target_lengths: [B]
            loss = criterion(log_probs, targets, input_lengths, target_lengths)

            if train:
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                # Clip to stabilize RNN training against exploding gradients.
                torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
                optimizer.step()

            total_loss += loss.item()
            n_batches += 1

            # Collect a small qualitative preview on validation passes.
            if (not train) and itos is not None and len(samples_preview) < 3:
                preds = greedy_decode(log_probs.detach().cpu(), itos)
                for gt, pr in zip(batch.texts, preds):
                    if len(samples_preview) < 3:
                        samples_preview.append((gt, pr))

    # max(..., 1) guards against an empty loader (avoids ZeroDivisionError).
    return total_loss / max(n_batches, 1), samples_preview


def main():
    """Train the CRNN with CTC loss and checkpoint the best model."""
    os.makedirs("models", exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)

    # Hyper-parameters hoisted to named locals so each value has a single
    # source of truth (img_height was previously hard-coded both here and
    # in the checkpoint dict below).
    batch_size = 8
    img_height = 64
    epochs = 10
    lr = 1e-3

    # make_loaders builds the vocab and saves models/vocab.json as a side effect.
    train_loader, val_loader, stoi = make_loaders(
        batch_size=batch_size, img_height=img_height, num_workers=0
    )
    itos = invert_vocab(stoi)

    model = CRNN(num_classes=len(stoi)).to(device)

    # blank=0 matches greedy_decode's blank index; zero_infinity avoids
    # inf losses when an input is shorter than its target.
    criterion = nn.CTCLoss(blank=0, zero_infinity=True)
    optimizer = Adam(model.parameters(), lr=lr)

    best_val = float("inf")

    for epoch in range(1, epochs + 1):
        print(f"\nEpoch {epoch}/{epochs}")

        train_loss, _ = run_epoch(model, train_loader, optimizer, criterion, device, train=True)
        val_loss, preview = run_epoch(model, val_loader, optimizer, criterion, device, train=False, itos=itos)

        print(f"Train loss: {train_loss:.4f} | Val loss: {val_loss:.4f}")

        if preview:
            print("\nPreview (GT -> Pred):")
            for gt, pr in preview:
                print("GT :", gt)
                print("PR :", pr)
                print("-" * 40)

        # Checkpoint only when validation improves; stoi and img_height are
        # stored so inference can rebuild the exact preprocessing/decoding.
        if val_loss < best_val:
            best_val = val_loss
            ckpt = {
                "model_state": model.state_dict(),
                "stoi": stoi,
                "img_height": img_height,
            }
            torch.save(ckpt, "models/best.pt")
            print("✅ Saved models/best.pt")

    print("\nDone. Best val loss:", best_val)


# Allow importing this module (e.g. for its decode helpers) without
# kicking off a training run.
if __name__ == "__main__":
    main()