File size: 4,989 Bytes
091afb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""
CRNN+CTC Model β€” simplified for small datasets (~5000-10000 samples)
~700K parameters, converges reliably without CTC blank collapse.
"""
import math

import torch
import torch.nn as nn


class CRNN_CivilRegistry(nn.Module):
    """CNN + BiLSTM recognizer that emits per-timestep character logits for CTC.

    The CNN shrinks the image height to a few rows while keeping most of the
    width. Each remaining column becomes one time step for a bidirectional
    LSTM, and a dropout + linear head projects every step onto ``num_chars``
    classes.

    Args:
        img_height: Height in pixels of the input images. It only sizes the
            LSTM input. It must be at least 32 so the height survives the four
            height-halving pools and the final 2x2 convolution.
        num_chars: Number of output classes per time step (presumably
            including the CTC blank — confirm against the label encoder).
        hidden_size: Hidden size of each LSTM direction.
        num_lstm_layers: Number of stacked BiLSTM layers.
        dropout: Dropout probability applied to the BiLSTM output before the
            final projection.
    """

    def __init__(self, img_height=64, num_chars=96, hidden_size=128, num_lstm_layers=1,
                 dropout=0.3):
        super().__init__()

        # CNN width reductions for a 512 px wide input:
        #   MaxPool(2,2): 512 -> 256, MaxPool(2,2): 256 -> 128
        #   MaxPool((2,1)) twice: height only, width stays 128
        #   Conv(k=2, p=0): 128 -> 127  =>  seq_len = 127, fits labels up to 64 chars
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.MaxPool2d((2, 1)),

            nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True),
            nn.MaxPool2d((2, 1)),

            nn.Conv2d(256, 256, kernel_size=2, padding=0),
            nn.BatchNorm2d(256), nn.ReLU(inplace=True),
        )

        # Derive the CNN output height from a real forward pass instead of a
        # hand-written formula, so it stays correct if the layers change.
        # The probe must run in eval mode. Under no_grad, a BatchNorm layer in
        # training mode still updates running_mean / running_var /
        # num_batches_tracked. Probing in train mode would therefore bake the
        # dummy zeros input into the normalization statistics.
        was_training = self.cnn.training
        self.cnn.eval()
        with torch.no_grad():
            probe = torch.zeros(1, 1, img_height, 32)
            cnn_out_h = self.cnn(probe).shape[2]   # actual height after all CNN layers
        self.cnn.train(was_training)
        rnn_input = 256 * cnn_out_h

        self.rnn = nn.LSTM(
            input_size=rnn_input,
            hidden_size=hidden_size,
            num_layers=num_lstm_layers,
            bidirectional=True,
            batch_first=False,
        )
        # Dropout between the BiLSTM output and the character projection, to
        # curb overfitting on small datasets. nn.Dropout is inactive after
        # model.eval().
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(hidden_size * 2, num_chars)

    def forward(self, x):
        """Return un-normalized class logits of shape (seq_len, batch, num_chars).

        ``x`` is (batch, 1, height, width), and its height must equal the
        ``img_height`` given at construction. The output is time-major
        (``batch_first=False``). Apply ``log_softmax`` over the last dim before
        ``nn.CTCLoss``.
        """
        f = self.cnn(x)
        B, C, h, w = f.size()
        # (B, C, h, w) -> (w, B, C*h): each image column becomes one time step.
        f = f.permute(3, 0, 1, 2).reshape(w, B, C * h)
        f, _ = self.rnn(f)
        return self.fc(self.dropout(f))


class CRNN_Ensemble(nn.Module):
    """Average the predictions of several independently initialized CRNNs.

    Each member's logits are converted to per-timestep probabilities. These
    are averaged across members, and the log of the average is returned. The
    output is therefore already a log-probability tensor. Running
    ``log_softmax`` on it again leaves it (numerically) unchanged, so it can
    follow the same log_softmax -> CTCLoss path as the raw
    ``CRNN_CivilRegistry`` logits.

    Args:
        num_models: Number of member models; must be at least 1.
        **kwargs: Forwarded unchanged to every ``CRNN_CivilRegistry``.

    Raises:
        ValueError: If ``num_models`` is less than 1.
    """

    def __init__(self, num_models=3, **kwargs):
        super().__init__()
        if num_models < 1:
            raise ValueError(f"num_models must be >= 1, got {num_models}")
        self.models = nn.ModuleList([CRNN_CivilRegistry(**kwargs) for _ in range(num_models)])

    def forward(self, x):
        """Return log(mean over members of softmax(logits)), shape (seq_len, batch, num_chars).

        The result is computed in log space:
        log(mean_i p_i) = logsumexp_i(log p_i) - log(N).
        This is mathematically the same as softmax -> mean -> log. It never
        takes the log of an underflowed 0, so no clamp is needed. A clamp
        would floor tiny probabilities at log(1e-9) and zero their gradient,
        which can stall CTC on targets the members currently rate as unlikely.
        """
        log_probs = torch.stack([torch.log_softmax(m(x), dim=2) for m in self.models])
        return torch.logsumexp(log_probs, dim=0) - math.log(len(self.models))


def get_crnn_model(model_type='standard', **kwargs):
    """Construct a CRNN by name.

    ``'ensemble'`` selects ``CRNN_Ensemble``; any other ``model_type`` falls
    back to a single ``CRNN_CivilRegistry``. All keyword arguments are passed
    straight to the chosen constructor.
    """
    factory = CRNN_Ensemble if model_type == 'ensemble' else CRNN_CivilRegistry
    return factory(**kwargs)


def initialize_weights(model):
    """Re-initialize, in place, the parameters of every supported layer in ``model``.

    - Conv2d: Kaiming-normal weights (fan_out, ReLU gain), zero bias.
    - BatchNorm2d: weight 1, bias 0. Running statistics are not touched.
    - Linear: N(0, 0.01) weights, zero bias.
    - LSTM: orthogonal weight matrices and zero biases, except the forget gate,
      whose effective bias is set to 1.0.

    Layers built with ``bias=False`` or ``affine=False`` are handled by
    skipping the missing tensor. Other module types are left unchanged.
    """
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # affine=False BatchNorm has no weight/bias tensors.
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LSTM):
            for name, param in m.named_parameters():
                if 'weight' in name:
                    nn.init.orthogonal_(param)
                elif 'bias' in name:
                    nn.init.constant_(param, 0)
                    # Forget-gate bias of 1.0 helps the model keep information
                    # across long sequences early in training.
                    # PyTorch gate order: [input | forget | cell | output].
                    # The LSTM adds bias_ih and bias_hh together. Filling the
                    # forget slice in both would give an effective bias of 2.0,
                    # so it is set only in bias_ih.
                    if name.startswith('bias_ih'):
                        n = param.size(0)
                        with torch.no_grad():
                            param[n // 4 : n // 2].fill_(1.0)


if __name__ == "__main__":
    model = get_crnn_model('standard', img_height=64, num_chars=96, hidden_size=128, num_lstm_layers=1)
    initialize_weights(model)
    x = torch.randn(2, 1, 64, 512)
    out = model(x)
    params = sum(p.numel() for p in model.parameters())
    print(f"Output: {out.shape}  seq_len={out.shape[0]}")
    print(f"Params: {params:,}  (unchanged β€” dropout adds no parameters)")
    print(f"Dropout p=0.3 active during training, disabled during model.eval()")