"""CRNN + CTC text-line recognizer, simplified for small datasets.

Intended for roughly 5,000-10,000 training samples. The network is a small
CNN feature extractor followed by a bidirectional LSTM and a per-timestep
linear classifier whose outputs are meant to be fed to a CTC loss.

NOTE(review): the original header advertised "~700K parameters". With the
default constructor arguments the layers defined below sum to roughly 1.6M
parameters; run this file as a script to print the exact count.
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class CRNN_CivilRegistry(nn.Module):
    """CNN -> BiLSTM -> Linear model producing per-timestep character logits.

    Args:
        img_height: Height in pixels of the single-channel input images.
            The CNN halves the height four times and then applies a 2x2
            convolution without padding, so the height must be at least 32.
        num_chars: Size of the output alphabet (number of logits per step).
        hidden_size: Hidden units per LSTM direction.
        num_lstm_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied between the LSTM and the
            classifier (inactive in ``eval()`` mode).
    """

    def __init__(self, img_height=64, num_chars=96, hidden_size=128,
                 num_lstm_layers=1, dropout=0.3):
        super().__init__()

        # CNN width reductions for a 512 px wide input:
        #   MaxPool(2, 2): 512 -> 256
        #   MaxPool(2, 2): 256 -> 128
        #   MaxPool((2, 1)): 128 (height only)
        #   MaxPool((2, 1)): 128 (height only)
        #   Conv(k=2, p=0): 127 -> seq_len = 127, fits labels up to 64 chars
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d((2, 1)),

            nn.Conv2d(128, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d((2, 1)),

            nn.Conv2d(256, 256, kernel_size=2, padding=0),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )

        # Derive the CNN output height from a real forward pass rather than
        # a hardcoded formula, so it stays correct if the architecture or
        # img_height changes.
        #
        # The probe runs in eval mode: torch.no_grad() does NOT stop
        # BatchNorm from updating running_mean / running_var /
        # num_batches_tracked in training mode, and an all-zero dummy batch
        # would otherwise leave every fresh model with polluted statistics.
        was_training = self.cnn.training
        self.cnn.eval()
        with torch.no_grad():
            probe = torch.zeros(1, 1, img_height, 32)
            cnn_out_h = self.cnn(probe).shape[2]
        self.cnn.train(was_training)

        # Each timestep fed to the LSTM is one feature-map column, flattened
        # over channels and height.
        rnn_input = 256 * cnn_out_h

        self.rnn = nn.LSTM(
            input_size=rnn_input,
            hidden_size=hidden_size,
            num_layers=num_lstm_layers,
            bidirectional=True,
            batch_first=False,
        )

        # Dropout sits after the BiLSTM output and before the character
        # projection to limit overfitting on small datasets.
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(hidden_size * 2, num_chars)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw logits of shape ``(seq_len, batch, num_chars)``.

        Args:
            x: Image batch of shape ``(batch, 1, img_height, width)``.
        """
        f = self.cnn(x)
        B, C, h, w = f.size()
        # (B, C, h, w) -> (w, B, C*h): width becomes the time axis.
        f = f.permute(3, 0, 1, 2).reshape(w, B, C * h)
        f, _ = self.rnn(f)
        return self.fc(self.dropout(f))


class CRNN_Ensemble(nn.Module):
    """Ensemble of independent ``CRNN_CivilRegistry`` models.

    Args:
        num_models: Number of member models; must be at least 1.
        **kwargs: Forwarded unchanged to every ``CRNN_CivilRegistry``.

    Raises:
        ValueError: If ``num_models`` is smaller than 1.
    """

    def __init__(self, num_models=3, **kwargs):
        super().__init__()
        if num_models < 1:
            raise ValueError(
                f"num_models must be >= 1, got {num_models}"
            )
        self.models = nn.ModuleList(
            [CRNN_CivilRegistry(**kwargs) for _ in range(num_models)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the log of the members' averaged softmax probabilities.

        The result has the same shape as a single member's output,
        ``(seq_len, batch, num_chars)``, and is a normalized
        log-distribution over the last axis.

        The average is computed in log space:
        ``log(mean_i softmax(z_i)) = logsumexp_i(log_softmax(z_i)) - log(N)``.
        Compared with ``log(clamp(mean(softmax)))`` this never hits
        ``log(0)``, keeps the output exactly normalized, and keeps gradients
        flowing for very unlikely classes instead of zeroing them at a clamp
        floor.

        NOTE(review): the original comments say the trainer applies
        ``log_softmax`` to model outputs before CTCLoss. Applying
        ``log_softmax`` to already-normalized log-probabilities leaves them
        unchanged, so this output suits either convention; confirm against
        the trainer.
        """
        member_log_probs = torch.stack(
            [F.log_softmax(m(x), dim=2) for m in self.models]
        )
        return (
            torch.logsumexp(member_log_probs, dim=0)
            - math.log(len(self.models))
        )


def get_crnn_model(model_type='standard', **kwargs):
    """Build a model by name.

    Args:
        model_type: ``'ensemble'`` returns a ``CRNN_Ensemble``; any other
            value returns a single ``CRNN_CivilRegistry``.
        **kwargs: Forwarded to the selected constructor.
    """
    if model_type == 'ensemble':
        return CRNN_Ensemble(**kwargs)
    return CRNN_CivilRegistry(**kwargs)


def initialize_weights(model):
    """Re-initialize Conv2d, BatchNorm2d, Linear and LSTM layers in place.

    * Conv2d: Kaiming-normal weights (fan_out, ReLU), zero bias.
    * BatchNorm2d: weight 1, bias 0 (running statistics are not touched).
    * Linear: weights ~ N(0, 0.01), zero bias.
    * LSTM: orthogonal weights, zero biases, forget-gate slice set to 1.0.
    """
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(
                m.weight, mode='fan_out', nonlinearity='relu'
            )
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LSTM):
            for name, param in m.named_parameters():
                if 'weight' in name:
                    nn.init.orthogonal_(param)
                elif 'bias' in name:
                    nn.init.constant_(param, 0)
                    # A positive forget-gate bias helps the model remember
                    # across long sequences at the start of training.
                    # PyTorch packs LSTM gates as
                    # [input | forget | cell | output], so the forget gate
                    # is the second quarter of each bias vector.
                    # NOTE(review): both bias_ih and bias_hh receive 1.0 and
                    # PyTorch adds them, so the effective forget-gate bias
                    # is 2.0 - confirm this is intended.
                    n = param.size(0)
                    with torch.no_grad():
                        param[n // 4: n // 2].fill_(1.0)


if __name__ == "__main__":
    # Smoke test: build the default model, run one batch, report sizes.
    model = get_crnn_model('standard', img_height=64, num_chars=96,
                           hidden_size=128, num_lstm_layers=1)
    initialize_weights(model)
    x = torch.randn(2, 1, 64, 512)
    out = model(x)
    params = sum(p.numel() for p in model.parameters())
    print(f"Output: {out.shape} seq_len={out.shape[0]}")
    print(f"Params: {params:,} (unchanged — dropout adds no parameters)")
    print(f"Dropout p=0.3 active during training, disabled during model.eval()")