import torch
import torch.nn as nn

class SmallCNN(nn.Module):
    """
    Improved CNN with BatchNorm and residual connections.
    Produces feature map with total stride 4 along width,
    and compresses height to ~1 via pooling.
    """
    def __init__(self, in_ch=1) -> None:
        super().__init__()
        # First conv block: H,W -> H/2, W/2
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_ch, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2,2), stride=(2,2))  # stride 2x2
        )
        
        # Second conv block: keep height at H/2, halve width only: W/2 -> W/4
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(1,2), stride=(1,2))  # height stride 1, width stride 2
        )
        
        # Residual block at 128 channels
        self.residual = nn.Sequential(
            nn.Conv2d(128, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.BatchNorm2d(128)
        )
        self.residual_relu = nn.ReLU(inplace=True)
        
        self.height_pool = nn.AdaptiveAvgPool2d((1, None))  # squeeze height to 1

    def forward(self, x):
        # First two conv blocks
        f = self.conv1(x)           # [B, 64, H/2, W/2]
        f = self.conv2(f)           # [B, 128, H/2, W/4]
        
        # Residual connection
        residual = f
        f = self.residual(f)        # [B, 128, H/2, W/4]
        f = f + residual            # Skip connection
        f = self.residual_relu(f)   # [B, 128, H/2, W/4]
        
        # Height pooling
        f = self.height_pool(f)     # [B, 128, 1, W/4]
        f = f.squeeze(2)            # [B, 128, W/4]
        f = f.permute(2, 0, 1)      # [T(=W/4), B, 128]
        return f
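
# Shape walk-through for SmallCNN (illustrative; the 32x128 input size is an
# assumed example, not a value taken from this repo's config):
#   x        [B, 1, 32, 128]
#   conv1    [B, 64, 16, 64]    2x2 max-pool halves H and W
#   conv2    [B, 128, 16, 32]   (1,2) max-pool halves W only
#   pool     [B, 128, 1, 32]    adaptive average pool collapses H
#   output   [32, B, 128]       time-major sequence, T = W/4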

class CRNN(nn.Module):
    """
    SmallCNN feature extractor followed by a bidirectional LSTM and a
    linear classifier. Emits per-timestep logits of shape [T, B, vocab_size],
    the layout expected by CTC-style losses and decoders.
    """
    def __init__(self, vocab_size: int, in_ch: int = 1, hidden: int = 320, layers: int = 2, dropout: float = 0.05):
        super().__init__()
        self.cnn = SmallCNN(in_ch=in_ch)
        self.rnn = nn.LSTM(input_size=128, hidden_size=hidden, num_layers=layers,
                           bidirectional=True, dropout=dropout, batch_first=False)
        self.norm = nn.LayerNorm(2 * hidden)  # LayerNorm for training stability
        self.fc = nn.Linear(2 * hidden, vocab_size)
        
        # Initialize the classifier head
        self._init_weights()
    
    def _init_weights(self):
        # Xavier-uniform weights and zero bias for the final linear layer
        nn.init.xavier_uniform_(self.fc.weight)
        nn.init.zeros_(self.fc.bias)

    def forward(self, x):
        seq = self.cnn(x)     # [T, B, 128]
        y, _ = self.rnn(seq)  # [T, B, 2H]
        y = self.norm(y)      # [T, B, 2H]
        logits = self.fc(y)   # [T, B, vocab_size]
        return logits
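

if __name__ == "__main__":
    # Minimal smoke test (a sketch; batch size, image size, and vocab_size are
    # illustrative assumptions, not values from the original file):
    model = CRNN(vocab_size=40)
    x = torch.randn(2, 1, 32, 128)   # [B, C, H, W]
    logits = model(x)                # [T=32, B=2, V=40]
    print(logits.shape)              # torch.Size([32, 2, 40])

    # CTC training sketch on these logits, assuming index 0 is the blank token:
    log_probs = logits.log_softmax(dim=-1)
    targets = torch.randint(1, 40, (2, 10))  # dummy label sequences, length 10
    input_lengths = torch.full((2,), logits.size(0), dtype=torch.long)
    target_lengths = torch.full((2,), 10, dtype=torch.long)
    loss = nn.CTCLoss(blank=0)(log_probs, targets, input_lengths, target_lengths)
    print(float(loss))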