import torch
import torch.nn as nn


class SmallCNN(nn.Module):
    """
    Improved CNN with BatchNorm and a residual block.
    Produces a feature map with total stride 4 along width
    and compresses height to exactly 1 via adaptive pooling.
    """

    def __init__(self, in_ch=1) -> None:
        super().__init__()
        # First conv block: H, W -> H/2, W/2
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_ch, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),  # stride 2 in both dims
        )
        # Second conv block: keep height at H/2, halve width to W/4
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2)),  # height stride 1, width stride 2
        )
        # Residual block at 128 channels
        self.residual = nn.Sequential(
            nn.Conv2d(128, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.BatchNorm2d(128),
        )
        self.residual_relu = nn.ReLU(inplace=True)
        self.height_pool = nn.AdaptiveAvgPool2d((1, None))  # squeeze height to 1

    def forward(self, x):
        # First two conv blocks
        f = self.conv1(x)          # [B, 64, H/2, W/2]
        f = self.conv2(f)          # [B, 128, H/2, W/4]
        # Residual connection
        residual = f
        f = self.residual(f)       # [B, 128, H/2, W/4]
        f = f + residual           # skip connection
        f = self.residual_relu(f)  # [B, 128, H/2, W/4]
        # Height pooling
        f = self.height_pool(f)    # [B, 128, 1, W/4]
        f = f.squeeze(2)           # [B, 128, W/4]
        f = f.permute(2, 0, 1)     # [T(=W/4), B, 128]
        return f


class CRNN(nn.Module):
    def __init__(self, vocab_size: int, in_ch: int = 1, hidden: int = 320,
                 layers: int = 2, dropout: float = 0.05):
        super().__init__()
        self.cnn = SmallCNN(in_ch=in_ch)
        self.rnn = nn.LSTM(input_size=128, hidden_size=hidden, num_layers=layers,
                           bidirectional=True, dropout=dropout, batch_first=False)
        self.norm = nn.LayerNorm(2 * hidden)  # LayerNorm over the bi-LSTM output for stability
        self.fc = nn.Linear(2 * hidden, vocab_size)
        self._init_weights()

    def _init_weights(self):
        # Xavier-initialize the final linear layer and zero its bias
        nn.init.xavier_uniform_(self.fc.weight)
        nn.init.zeros_(self.fc.bias)

    def forward(self, x):
        seq = self.cnn(x)     # [T, B, C=128]
        y, _ = self.rnn(seq)  # [T, B, 2H]
        y = self.norm(y)      # [T, B, 2H]
        logits = self.fc(y)   # [T, B, V]
        return logits
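

if __name__ == "__main__":
    # Minimal smoke test (a sketch added for illustration, not part of the
    # original module): verifies the shape contract documented above and shows
    # one hedged way the [T, B, V] logits could feed nn.CTCLoss. vocab_size,
    # the input sizes, and blank index 0 are illustrative assumptions.
    vocab_size = 80
    model = CRNN(vocab_size=vocab_size)
    x = torch.randn(4, 1, 32, 128)  # [B=4, C=1, H=32, W=128]

    feats = model.cnn(x)
    assert feats.shape == (128 // 4, 4, 128), feats.shape  # [T=W/4, B, 128]

    logits = model(x)
    assert logits.shape == (128 // 4, 4, vocab_size), logits.shape  # [T, B, V]

    # Hedged CTC usage sketch: nn.CTCLoss expects log-probs shaped [T, B, V].
    log_probs = logits.log_softmax(dim=-1)
    targets = torch.randint(1, vocab_size, (4, 10))  # dummy labels, 0 = blank
    input_lengths = torch.full((4,), logits.size(0), dtype=torch.long)
    target_lengths = torch.full((4,), 10, dtype=torch.long)
    loss = nn.CTCLoss(blank=0, zero_infinity=True)(
        log_probs, targets, input_lengths, target_lengths)
    print("logits:", tuple(logits.shape), "ctc loss:", float(loss))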