import torch
import torch.nn as nn


class CRNN(nn.Module):
    """Convolutional Recurrent Neural Network for image-based sequence
    recognition (e.g. text/OCR), in the style of Shi et al.'s CRNN.

    Pipeline: VGG-style CNN backbone -> collapse (channel, height) into a
    per-column feature vector -> linear projection -> two stacked
    bidirectional LSTMs -> per-timestep class log-probabilities.

    Args:
        img_channel: number of input image channels (1 for grayscale).
        img_height: input height; must be a multiple of 16.
        img_width: input width; must be a multiple of 4.
        num_class: number of output classes (including the CTC blank).
        map_to_seq_hidden: feature size fed to the first LSTM.
        rnn_hidden: hidden size of each LSTM direction.
        leaky_relu: use LeakyReLU(0.2) instead of ReLU in the backbone.
    """

    def __init__(self, img_channel, img_height, img_width, num_class,
                 map_to_seq_hidden=64, rnn_hidden=256, leaky_relu=False):
        super().__init__()  # fix: modern zero-arg super (was super(CRNN, self))

        self.cnn, (output_channel, output_height, output_width) = \
            self._cnn_backbone(img_channel, img_height, img_width, leaky_relu)

        # Each horizontal position of the CNN feature map becomes one
        # timestep; its (channel * height) values are projected down.
        self.map_to_seq = nn.Linear(output_channel * output_height,
                                    map_to_seq_hidden)

        self.rnn1 = nn.LSTM(map_to_seq_hidden, rnn_hidden,
                            bidirectional=True, batch_first=True)
        # rnn1 is bidirectional, so its output feature size is 2 * rnn_hidden.
        self.rnn2 = nn.LSTM(2 * rnn_hidden, rnn_hidden,
                            bidirectional=True, batch_first=True)

        self.dense = nn.Linear(2 * rnn_hidden, num_class)

    def _cnn_backbone(self, img_channel, img_height, img_width, leaky_relu):
        """Build the 7-conv backbone.

        Returns:
            (cnn, (output_channel, output_height, output_width)): the
            nn.Sequential module and the feature-map size it produces for
            the given input size.
        """
        assert img_height % 16 == 0, 'img_height must be a multiple of 16'
        assert img_width % 4 == 0, 'img_width must be a multiple of 4'

        # Per-layer hyperparameters for the 7 conv layers.
        channels = [img_channel, 64, 128, 256, 256, 512, 512, 512]
        kernel_sizes = [3, 3, 3, 3, 3, 3, 2]
        strides = [1, 1, 1, 1, 1, 1, 1]
        paddings = [1, 1, 1, 1, 1, 1, 0]

        cnn = nn.Sequential()

        def conv_relu(i, batch_normalization=False):
            # Conv -> (optional BN) -> activation, for layer index i.
            n_in = channels[i]
            n_out = channels[i + 1]
            cnn.add_module(
                f'conv{i}',
                nn.Conv2d(n_in, n_out, kernel_sizes[i], strides[i], paddings[i]),
            )
            if batch_normalization:
                cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(n_out))
            if leaky_relu:
                cnn.add_module(f'relu{i}', nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module(f'relu{i}', nn.ReLU(inplace=True))

        # Shape comments assume a 1x32x128 (C x H x W) input.
        conv_relu(0)
        cnn.add_module('pooling0',
                       nn.MaxPool2d(kernel_size=2, stride=2))  # 64x16x64
        conv_relu(1)
        cnn.add_module('pooling1',
                       nn.MaxPool2d(kernel_size=2, stride=2))  # 128x8x32
        conv_relu(2, True)
        conv_relu(3)
        # Pool height only (width stride 1, width pad 1) so the horizontal
        # sequence stays long enough for CTC alignment.
        cnn.add_module('pooling2',
                       nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1),
                                    padding=(0, 1)))  # 256x4x33
        conv_relu(4, True)
        conv_relu(5)
        cnn.add_module('pooling3',
                       nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1),
                                    padding=(0, 1)))  # 512x2x34
        conv_relu(6, True)  # 512x1x33

        # Closed-form output size of the stack above:
        # height: /2, /2, /2, /2, then -1 from the final 2-kernel conv.
        # width:  /2, /2, then +1 from each (2,1)-stride pool, -1 from conv6.
        output_channel = channels[-1]
        output_height = img_height // 16 - 1
        output_width = img_width // 4 + 1
        return cnn, (output_channel, output_height, output_width)

    def forward(self, images):
        """Compute per-timestep class log-probabilities.

        Args:
            images: float tensor of shape (batch, channel, height, width).

        Returns:
            Tensor of shape (batch, seq_len, num_class) with log-softmax
            scores, where seq_len == width // 4 + 1.

        Note:
            PyTorch's ``nn.CTCLoss`` expects (seq_len, batch, num_class);
            the caller must permute, e.g. ``output.permute(1, 0, 2)``,
            before computing the loss.  (Fix: the previous comment claimed
            the permute happened here, but it never did — the LSTMs run
            with ``batch_first=True`` and the batch axis stays first.)
        """
        conv = self.cnn(images)
        batch, channel, height, width = conv.size()

        # Collapse (channel, height) into one feature axis and keep width
        # as the time axis: (batch, width, channel * height).
        conv = conv.view(batch, channel * height, width)
        conv = conv.permute(0, 2, 1)

        seq = self.map_to_seq(conv)

        recurrent, _ = self.rnn1(seq)
        recurrent, _ = self.rnn2(recurrent)

        output = self.dense(recurrent)
        # Log-softmax over the class dimension, ready for CTC loss.
        return output.log_softmax(2)


if __name__ == '__main__':
    # Smoke test.
    dummy_input = torch.randn(1, 1, 32, 1024)
    model = CRNN(img_channel=1, img_height=32, img_width=1024, num_class=80)
    output = model(dummy_input)
    # seq_len = 1024 // 4 + 1 = 257.  (Fix: the old comment said 33,
    # which is the seq_len for img_width=128, not 1024.)
    print(f"Output shape: {output.shape}")  # Expected: (1, 257, 80)