Spaces:
Build error
Build error
| import torch.nn as nn | |
class CRNN(nn.Module):
    """CNN feature extractor followed by two bidirectional LSTMs and a linear head.

    The convolutional stack shrinks the image height to ``img_height // 16 - 1``
    and the width to ``img_width // 4 - 1``. Every remaining column of the
    feature map becomes one time step of the sequence that is projected by
    ``map_to_seq`` and fed through ``rnn1``, ``rnn2`` and ``dense``.
    """

    def __init__(self, img_channel, img_height, img_width, num_class,
                 map_to_seq_hidden=64, rnn_hidden=256, leaky_relu=False):
        """Assemble the CNN backbone, the recurrent layers and the output layer.

        Args:
            img_channel: number of channels of the input images.
            img_height: input image height; a multiple of 16, at least 32.
            img_width: input image width; a multiple of 4.
            num_class: size of the last dimension of the model output.
            map_to_seq_hidden: output size of the linear layer that maps each
                feature-map column to an LSTM input vector.
            rnn_hidden: hidden size of each direction of both LSTMs.
            leaky_relu: use ``LeakyReLU(0.2)`` instead of ``ReLU`` in the CNN.

        Raises:
            ValueError: if ``img_height`` or ``img_width`` is not supported
                (see ``_cnn_backbone``).
        """
        super().__init__()

        # The reported output width is not needed here: no layer size depends on it.
        self.cnn, (output_channel, output_height, _) = self._cnn_backbone(
            img_channel, img_height, img_width, leaky_relu
        )

        # One feature-map column (channel * height values) -> one sequence step.
        self.map_to_seq = nn.Linear(output_channel * output_height, map_to_seq_hidden)

        self.rnn1 = nn.LSTM(map_to_seq_hidden, rnn_hidden, bidirectional=True)
        # Bidirectional output concatenates both directions: 2 * rnn_hidden.
        self.rnn2 = nn.LSTM(2 * rnn_hidden, rnn_hidden, bidirectional=True)

        self.dense = nn.Linear(2 * rnn_hidden, num_class)

    def _cnn_backbone(self, img_channel, img_height, img_width, leaky_relu):
        """Build the convolutional stack and compute the size of its output.

        Returns:
            A tuple ``(cnn, (output_channel, output_height, output_width))``
            where ``cnn`` is an ``nn.Sequential`` and the inner tuple is the
            per-image feature-map size it produces for the given image size.

        Raises:
            ValueError: if ``img_height`` is not a multiple of 16, is smaller
                than 32, or if ``img_width`` is not a multiple of 4.
        """
        # Explicit raises rather than ``assert``: assertions are removed when
        # Python runs with -O, which would let invalid sizes through.
        if img_height % 16 != 0:
            raise ValueError(
                f'img_height must be a multiple of 16, got {img_height}'
            )
        if img_height < 32:
            # img_height // 16 - 1 would be <= 0: the final 2x2 convolution
            # would have nothing to slide over and map_to_seq would get no
            # input features, so the model could never run a forward pass.
            raise ValueError(
                f'img_height must be at least 32, got {img_height}'
            )
        if img_width % 4 != 0:
            raise ValueError(
                f'img_width must be a multiple of 4, got {img_width}'
            )

        channels = [img_channel, 64, 128, 256, 256, 512, 512, 512]
        kernel_sizes = [3, 3, 3, 3, 3, 3, 2]
        strides = [1, 1, 1, 1, 1, 1, 1]
        paddings = [1, 1, 1, 1, 1, 1, 0]

        cnn = nn.Sequential()

        def conv_relu(i, batch_norm=False):
            """Append conv{i} (+ optional batchnorm{i}) + relu{i} to ``cnn``."""
            # shape of input: (batch, input_channel, height, width)
            input_channel = channels[i]
            output_channel = channels[i + 1]

            cnn.add_module(
                f'conv{i}',
                nn.Conv2d(input_channel, output_channel,
                          kernel_sizes[i], strides[i], paddings[i])
            )

            if batch_norm:
                cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(output_channel))

            if leaky_relu:
                relu = nn.LeakyReLU(0.2, inplace=True)
            else:
                relu = nn.ReLU(inplace=True)
            cnn.add_module(f'relu{i}', relu)

        # size of image: (channel, height, width) = (img_channel, img_height, img_width)
        conv_relu(0)
        cnn.add_module('pooling0', nn.MaxPool2d(kernel_size=2, stride=2))
        # (64, img_height // 2, img_width // 2)

        conv_relu(1)
        cnn.add_module('pooling1', nn.MaxPool2d(kernel_size=2, stride=2))
        # (128, img_height // 4, img_width // 4)

        conv_relu(2)
        conv_relu(3)
        # Pool along the height only, so the width (sequence length) is kept.
        cnn.add_module('pooling2', nn.MaxPool2d(kernel_size=(2, 1)))
        # (256, img_height // 8, img_width // 4)

        conv_relu(4, batch_norm=True)
        conv_relu(5, batch_norm=True)
        cnn.add_module('pooling3', nn.MaxPool2d(kernel_size=(2, 1)))
        # (512, img_height // 16, img_width // 4)

        # Kernel 2 without padding trims one row and one column.
        conv_relu(6)
        # (512, img_height // 16 - 1, img_width // 4 - 1)

        output_channel = channels[-1]
        output_height = img_height // 16 - 1
        output_width = img_width // 4 - 1
        return cnn, (output_channel, output_height, output_width)

    def forward(self, images):
        """Run the network on a batch of images.

        Args:
            images: tensor of shape (batch, channel, height, width). The
                height has to produce the feature size ``map_to_seq`` was
                built for, i.e. match the ``img_height`` given at construction.

        Returns:
            Tensor of shape (seq_len, batch, num_class), where ``seq_len`` is
            the width of the CNN feature map.
        """
        conv = self.cnn(images)
        batch, channel, height, width = conv.size()

        # Merge channel and height into one feature axis. ``reshape`` behaves
        # like ``view`` when the memory layout allows it and copies otherwise,
        # so a non-contiguous CNN output does not raise here.
        conv = conv.reshape(batch, channel * height, width)
        conv = conv.permute(2, 0, 1)  # (width, batch, feature)
        seq = self.map_to_seq(conv)

        recurrent, _ = self.rnn1(seq)
        recurrent, _ = self.rnn2(recurrent)

        output = self.dense(recurrent)
        return output  # shape: (seq_len, batch, num_class)