Alimustoofaa's picture
first commit
7ee7e3a
import torch.nn as nn
class CRNN(nn.Module):
    """Convolutional recurrent network producing per-column class scores.

    A VGG-style CNN reduces an image batch of shape
    ``(batch, img_channel, img_height, img_width)`` to a feature map whose
    width axis is treated as a sequence. Each column is linearly projected to
    ``map_to_seq_hidden`` features, passed through two stacked bidirectional
    LSTMs, and finally mapped to ``num_class`` raw (un-normalised) scores per
    time step. (Presumably these scores feed a CTC-style loss — confirm with
    the training code.)

    Args:
        img_channel: Number of input image channels.
        img_height: Input image height. Must be a multiple of 16 and at least
            32: it is halved four times and the last 2x2 convolution (no
            padding) needs a feature map at least 2 rows tall.
        img_width: Nominal input image width. It does not affect parameter
            shapes, only validation and the reported backbone output width.
            Must be a multiple of 4 and at least 8.
        num_class: Number of output classes per time step.
        map_to_seq_hidden: Size of the per-column projection fed to the RNN.
        rnn_hidden: Hidden size of each LSTM direction.
        leaky_relu: Use ``LeakyReLU(0.2)`` instead of ``ReLU`` in the backbone.

    Raises:
        ValueError: If ``img_height`` or ``img_width`` violate the constraints
            above.
    """

    def __init__(self, img_channel, img_height, img_width, num_class,
                 map_to_seq_hidden=64, rnn_hidden=256, leaky_relu=False):
        super().__init__()

        self.cnn, (output_channel, output_height, output_width) = \
            self._cnn_backbone(img_channel, img_height, img_width, leaky_relu)

        # Each sequence step is one feature-map column: channel x height values.
        self.map_to_seq = nn.Linear(output_channel * output_height, map_to_seq_hidden)

        self.rnn1 = nn.LSTM(map_to_seq_hidden, rnn_hidden, bidirectional=True)
        # Bidirectional outputs concatenate both directions -> 2 * rnn_hidden.
        self.rnn2 = nn.LSTM(2 * rnn_hidden, rnn_hidden, bidirectional=True)

        self.dense = nn.Linear(2 * rnn_hidden, num_class)

    def _cnn_backbone(self, img_channel, img_height, img_width, leaky_relu):
        """Build the convolutional feature extractor.

        Submodule names (``conv{i}``, ``batchnorm{i}``, ``relu{i}``,
        ``pooling{j}``) and their order are part of the ``state_dict`` layout
        and must not change, or saved checkpoints stop loading.

        Returns:
            ``(cnn, (output_channel, output_height, output_width))`` where the
            tuple is the per-sample shape of ``cnn``'s output for an input of
            size ``(img_channel, img_height, img_width)``.

        Raises:
            ValueError: If the image size is not a supported multiple, or is
                too small for the final 2x2 convolution.
        """
        # Explicit checks instead of ``assert`` so validation survives
        # ``python -O``. The lower bounds guarantee the final kernel-2,
        # padding-0 convolution sees at least 2 rows and 2 columns.
        if img_height % 16 != 0 or img_height < 32:
            raise ValueError(
                f'img_height must be a multiple of 16 and >= 32, got {img_height}'
            )
        if img_width % 4 != 0 or img_width < 8:
            raise ValueError(
                f'img_width must be a multiple of 4 and >= 8, got {img_width}'
            )

        channels = [img_channel, 64, 128, 256, 256, 512, 512, 512]
        kernel_sizes = [3, 3, 3, 3, 3, 3, 2]
        strides = [1, 1, 1, 1, 1, 1, 1]
        paddings = [1, 1, 1, 1, 1, 1, 0]

        cnn = nn.Sequential()

        def conv_relu(i, batch_norm=False):
            """Append conv{i} [+ batchnorm{i}] + relu{i} to ``cnn``."""
            # shape of input: (batch, input_channel, height, width)
            input_channel = channels[i]
            output_channel = channels[i + 1]

            cnn.add_module(
                f'conv{i}',
                nn.Conv2d(input_channel, output_channel, kernel_sizes[i], strides[i], paddings[i])
            )

            if batch_norm:
                cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(output_channel))

            relu = nn.LeakyReLU(0.2, inplace=True) if leaky_relu else nn.ReLU(inplace=True)
            cnn.add_module(f'relu{i}', relu)

        # size of image: (channel, height, width) = (img_channel, img_height, img_width)
        conv_relu(0)
        cnn.add_module('pooling0', nn.MaxPool2d(kernel_size=2, stride=2))
        # (64, img_height // 2, img_width // 2)

        conv_relu(1)
        cnn.add_module('pooling1', nn.MaxPool2d(kernel_size=2, stride=2))
        # (128, img_height // 4, img_width // 4)

        conv_relu(2)
        conv_relu(3)
        # Stride defaults to the kernel size: halve height only, keep width.
        cnn.add_module(
            'pooling2',
            nn.MaxPool2d(kernel_size=(2, 1))
        )  # (256, img_height // 8, img_width // 4)

        conv_relu(4, batch_norm=True)
        conv_relu(5, batch_norm=True)
        cnn.add_module(
            'pooling3',
            nn.MaxPool2d(kernel_size=(2, 1))
        )  # (512, img_height // 16, img_width // 4)

        conv_relu(6)  # (512, img_height // 16 - 1, img_width // 4 - 1)

        output_channel, output_height, output_width = \
            channels[-1], img_height // 16 - 1, img_width // 4 - 1
        return cnn, (output_channel, output_height, output_width)

    def forward(self, images):
        """Compute per-time-step class scores.

        Args:
            images: Tensor of shape ``(batch, img_channel, height, width)``.
                ``height`` is expected to match the ``img_height`` used at
                construction, because ``map_to_seq``'s input size is derived
                from it; ``width`` may vary.

        Returns:
            Tensor of shape ``(seq_len, batch, num_class)`` holding raw
            scores, where ``seq_len = width // 4 - 1``.
        """
        conv = self.cnn(images)  # (batch, channel, height', width')
        # Merge channel and height into one feature vector per column.
        # ``flatten`` is equivalent to the former ``view`` but also works on
        # non-contiguous outputs (e.g. channels_last tensors).
        conv = conv.flatten(1, 2)  # (batch, channel * height', width')
        conv = conv.permute(2, 0, 1)  # (width', batch, feature): LSTM default layout

        seq = self.map_to_seq(conv)

        recurrent, _ = self.rnn1(seq)
        recurrent, _ = self.rnn2(recurrent)

        output = self.dense(recurrent)
        return output  # shape: (seq_len, batch, num_class)