| import torch |
| import torch.nn as nn |
|
|
class CRNN(nn.Module):
    """CRNN: a CNN backbone followed by stacked bidirectional LSTMs.

    The backbone collapses the image height to 1 while keeping a wide
    feature map; each width position of that map becomes one time step
    for the recurrent layers, yielding per-step class log-probabilities
    (the usual input to CTC-style sequence losses).

    Args:
        img_channel: number of input image channels.
        img_height: input image height; must be a multiple of 16.
        img_width: input image width; must be a multiple of 4.
        num_class: size of the output alphabet.
        map_to_seq_hidden: width of the linear layer that maps CNN
            features to the RNN input.
        rnn_hidden: hidden size of each LSTM direction.
        leaky_relu: if True, use LeakyReLU(0.2) instead of ReLU in the
            backbone.
    """

    def __init__(self, img_channel, img_height, img_width, num_class,
                 map_to_seq_hidden=64, rnn_hidden=256, leaky_relu=False):
        super().__init__()

        self.cnn, (output_channel, output_height, output_width) = \
            self._cnn_backbone(img_channel, img_height, img_width, leaky_relu)

        # Each feature-map column (channel x height) is flattened into one
        # time-step feature vector before entering the RNN stack.
        self.map_to_seq = nn.Linear(output_channel * output_height, map_to_seq_hidden)

        self.rnn1 = nn.LSTM(map_to_seq_hidden, rnn_hidden, bidirectional=True, batch_first=True)
        # rnn1 is bidirectional, so its output feature size is 2 * rnn_hidden.
        self.rnn2 = nn.LSTM(rnn_hidden * 2, rnn_hidden, bidirectional=True, batch_first=True)

        self.dense = nn.Linear(rnn_hidden * 2, num_class)

    def _cnn_backbone(self, img_channel, img_height, img_width, leaky_relu):
        """Build the 7-conv VGG-style backbone.

        Returns:
            (cnn, (output_channel, output_height, output_width)): the
            nn.Sequential backbone and its output size for an input of
            (img_height, img_width).

        Raises:
            ValueError: if img_height is not a multiple of 16 or
                img_width is not a multiple of 4.
        """
        # Explicit checks instead of `assert`: asserts are stripped under
        # `python -O`, and a wrong size would otherwise only fail later as
        # an opaque shape mismatch inside map_to_seq.
        if img_height % 16 != 0:
            raise ValueError(f'img_height must be a multiple of 16, got {img_height}')
        if img_width % 4 != 0:
            raise ValueError(f'img_width must be a multiple of 4, got {img_width}')

        channels = [img_channel, 64, 128, 256, 256, 512, 512, 512]
        kernel_sizes = [3, 3, 3, 3, 3, 3, 2]
        strides = [1, 1, 1, 1, 1, 1, 1]
        paddings = [1, 1, 1, 1, 1, 1, 0]

        cnn = nn.Sequential()

        def conv_relu(i, batch_normalization=False):
            # Conv -> (optional BatchNorm) -> ReLU/LeakyReLU, layer i.
            n_in = channels[i]
            n_out = channels[i + 1]
            cnn.add_module(f'conv{i}', nn.Conv2d(n_in, n_out, kernel_sizes[i], strides[i], paddings[i]))
            if batch_normalization:
                cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(n_out))
            if leaky_relu:
                cnn.add_module(f'relu{i}', nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module(f'relu{i}', nn.ReLU(inplace=True))

        conv_relu(0)
        cnn.add_module('pooling0', nn.MaxPool2d(kernel_size=2, stride=2))
        conv_relu(1)
        cnn.add_module('pooling1', nn.MaxPool2d(kernel_size=2, stride=2))
        conv_relu(2, True)
        conv_relu(3)
        # Asymmetric stride (2, 1): halve the height but (almost) keep the
        # width, so the sequence length stays long for recognition.
        cnn.add_module('pooling2', nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1)))
        conv_relu(4, True)
        conv_relu(5)
        cnn.add_module('pooling3', nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1)))
        conv_relu(6, True)

        # Closed-form output size of the stack above:
        # height: /2, /2, /2, /2, then the final 2x2 conv takes off 1
        #   -> img_height // 16 - 1
        # width:  /2, /2, then two pads of +1, then the 2x2 conv takes off 1
        #   -> img_width // 4 + 1
        output_channel, output_height, output_width = channels[-1], img_height // 16 - 1, img_width // 4 + 1
        return cnn, (output_channel, output_height, output_width)

    def forward(self, images):
        """Compute per-time-step class log-probabilities.

        Args:
            images: tensor of shape (batch, channel, height, width) with
                the (height, width) the model was constructed for.

        Returns:
            Tensor of shape (batch, seq_len, num_class) holding
            log-softmax scores over the alphabet, where
            seq_len == img_width // 4 + 1.
        """
        conv = self.cnn(images)
        batch, channel, height, width = conv.size()

        # Flatten (channel, height) per column, then move the width axis
        # forward so it acts as the time dimension: (batch, width, C * H).
        conv = conv.view(batch, channel * height, width)
        conv = conv.permute(0, 2, 1)

        seq = self.map_to_seq(conv)

        recurrent, _ = self.rnn1(seq)
        recurrent, _ = self.rnn2(recurrent)

        output = self.dense(recurrent)

        # Log-probabilities over the class dimension (dim 2).
        return output.log_softmax(2)
|
|
if __name__ == '__main__':
    # Smoke test: run one random grayscale 32x1024 image through the net
    # and report the resulting (batch, seq_len, num_class) shape.
    net = CRNN(img_channel=1, img_height=32, img_width=1024, num_class=80)
    sample = torch.randn(1, 1, 32, 1024)
    logits = net(sample)
    print(f"Output shape: {logits.shape}")
|
|