# ocr/crnn_model.py — author: hanz245, commit 7111e1a ("set up")
"""
CRNN+CTC Model β€” simplified for small datasets (~5000-10000 samples)
~700K parameters, converges reliably without CTC blank collapse.
"""
import math

import torch
import torch.nn as nn
class CRNN_CivilRegistry(nn.Module):
    """CNN + bidirectional LSTM text recognizer emitting per-timestep logits for CTC.

    Args:
        img_height: Height in pixels of the single-channel input images. With the
            layers below the CNN needs img_height >= 32 (four 2x height poolings
            followed by a 2x2 conv without padding).
        num_chars: Number of output classes (logits per timestep).
        hidden_size: Hidden units per LSTM direction.
        num_lstm_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied to the BiLSTM output before the
            final linear projection.

    ``forward`` takes ``(B, 1, img_height, W)`` and returns raw (un-normalized)
    logits of shape ``(T, B, num_chars)``, time-major, where ``T = W // 4 - 1``.
    """

    def __init__(self, img_height=64, num_chars=96, hidden_size=128, num_lstm_layers=1,
                 dropout=0.3):
        super().__init__()
        # CNN — width reductions for a 512px-wide input:
        #   MaxPool(2,2): 512 → 256, MaxPool(2,2): 256 → 128
        #   MaxPool((2,1)) twice: height halves, width stays 128
        #   Conv(k=2, p=0): 128 → 127, so seq_len = 127 (= W // 4 - 1),
        #   which fits labels up to ~64 chars (CTC needs seq_len >= label
        #   length plus one extra step per repeated adjacent character).
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.MaxPool2d((2, 1)),
            nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True),
            nn.MaxPool2d((2, 1)),
            nn.Conv2d(256, 256, kernel_size=2, padding=0),
            nn.BatchNorm2d(256), nn.ReLU(inplace=True),
        )
        # Derive the CNN output geometry from a real forward pass instead of a
        # hardcoded formula, so changes to the layers or img_height stay
        # consistent with the LSTM input size.
        # The probe must run in eval mode: in train mode BatchNorm would fold
        # the all-zeros probe batch into its running_mean / running_var buffers.
        was_training = self.cnn.training
        self.cnn.eval()
        with torch.no_grad():
            probe = self.cnn(torch.zeros(1, 1, img_height, 32))
        self.cnn.train(was_training)
        _, cnn_out_c, cnn_out_h, _ = probe.shape
        # Each timestep (one CNN output column) is the flattened channels x height.
        rnn_input = cnn_out_c * cnn_out_h
        self.rnn = nn.LSTM(
            input_size=rnn_input,
            hidden_size=hidden_size,
            num_layers=num_lstm_layers,
            bidirectional=True,
            batch_first=False,
        )
        # Dropout before FC to limit overfitting on small datasets. It is
        # applied after the BiLSTM output, before the character projection,
        # and disabled at inference via model.eval().
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(hidden_size * 2, num_chars)

    def forward(self, x):
        """Map a ``(B, 1, H, W)`` image batch to ``(T, B, num_chars)`` raw logits."""
        f = self.cnn(x)
        B, C, h, w = f.size()
        # (B, C, h, w) -> (w, B, C*h): CNN output columns become the time axis.
        f = f.permute(3, 0, 1, 2).reshape(w, B, C * h)
        f, _ = self.rnn(f)
        return self.fc(self.dropout(f))
class CRNN_Ensemble(nn.Module):
    """Average the class probabilities of several independently built CRNN models.

    Args:
        num_models: Number of ``CRNN_CivilRegistry`` members; must be >= 1.
        **kwargs: Forwarded unchanged to every ``CRNN_CivilRegistry``.

    Raises:
        ValueError: If ``num_models < 1``.

    ``forward`` returns ``log(mean_k softmax(logits_k))`` with the same layout
    as each member's output (softmax is taken over dim 2). The result is already
    a normalized log-distribution, so it can go straight into ``nn.CTCLoss``.
    An extra ``log_softmax`` applied by a trainer leaves it numerically unchanged.
    """

    def __init__(self, num_models=3, **kwargs):
        super().__init__()
        if num_models < 1:
            raise ValueError(f"num_models must be >= 1, got {num_models}")
        self.models = nn.ModuleList([CRNN_CivilRegistry(**kwargs) for _ in range(num_models)])

    def forward(self, x):
        # log(mean_k p_k) == logsumexp_k(log p_k) - log(K). Working in log space
        # keeps very small probabilities exact. The previous
        # log(clamp(mean, 1e-9)) floored them at ~-20.7 and zeroed their gradients.
        log_probs = torch.stack(
            [torch.nn.functional.log_softmax(m(x), dim=2) for m in self.models]
        )
        return torch.logsumexp(log_probs, dim=0) - math.log(len(self.models))
def get_crnn_model(model_type='standard', **kwargs):
    """Construct a CRNN model selected by name.

    Args:
        model_type: ``'ensemble'`` builds a ``CRNN_Ensemble``. Every other value,
            including the default ``'standard'``, builds a ``CRNN_CivilRegistry``.
        **kwargs: Passed through unchanged to the selected constructor.

    Returns:
        The newly constructed ``nn.Module``.
    """
    # NOTE(review): unrecognized model_type values silently fall back to the
    # standard model; consider validating if typos are a concern.
    model_cls = CRNN_Ensemble if model_type == 'ensemble' else CRNN_CivilRegistry
    return model_cls(**kwargs)
def initialize_weights(model):
    """Re-initialize, in place, the parameters of supported submodules of *model*.

    - ``nn.Conv2d``: Kaiming-normal weights (fan_out, ReLU gain), zero bias.
    - ``nn.BatchNorm2d`` (affine): weight 1, bias 0.
    - ``nn.Linear``: weights ~ N(0, 0.01), zero bias.
    - ``nn.LSTM``: orthogonal weight matrices and zero biases. The exception is
      the forget gate, whose effective bias (``bias_ih + bias_hh``) is 1.0.

    Other module types are left untouched. BatchNorm running statistics are
    not modified.

    Args:
        model: Any ``nn.Module``; every module from ``model.modules()`` is visited.
    """
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            if m.affine:  # affine=False BatchNorm has no weight/bias to set
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LSTM):
            for name, param in m.named_parameters():
                if 'weight' in name:
                    nn.init.orthogonal_(param)
                elif 'bias' in name:
                    nn.init.constant_(param, 0)
                    # A forget-gate bias of 1.0 helps the model remember across
                    # long sequences early in training. PyTorch adds bias_ih and
                    # bias_hh, so only bias_ih receives the 1.0; filling both
                    # would give an effective forget bias of 2.0.
                    # LSTM gate order: [input | forget | cell | output]
                    if name.startswith('bias_ih'):
                        n = param.size(0)
                        with torch.no_grad():
                            param[n // 4 : n // 2].fill_(1.0)
if __name__ == "__main__":
model = get_crnn_model('standard', img_height=64, num_chars=96, hidden_size=128, num_lstm_layers=1)
initialize_weights(model)
x = torch.randn(2, 1, 64, 512)
out = model(x)
params = sum(p.numel() for p in model.parameters())
print(f"Output: {out.shape} seq_len={out.shape[0]}")
print(f"Params: {params:,} (unchanged β€” dropout adds no parameters)")
print(f"Dropout p=0.3 active during training, disabled during model.eval()")