File size: 2,799 Bytes
7ee7e3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch.nn as nn


class CRNN(nn.Module):
	"""Convolutional-recurrent network mapping an image to per-column class scores.

	Pipeline: a 7-layer convolutional backbone shrinks the image (height by 16,
	width by 4, then a final 2x2 unpadded conv), each remaining feature-map
	column is projected by ``map_to_seq`` into one sequence element, two
	stacked bidirectional LSTMs run over that sequence, and ``dense`` emits
	``num_class`` unnormalized scores per time step (no softmax is applied).

	Args:
		img_channel: Number of channels in the input images.
		img_height: Input image height; must be a multiple of 16 and >= 32.
		img_width: Input image width; must be a multiple of 4 and >= 8.
		num_class: Number of output scores per time step.
		map_to_seq_hidden: Size of each sequence element fed to the first LSTM.
		rnn_hidden: Hidden size of each LSTM direction.
		leaky_relu: Use ``LeakyReLU(0.2)`` instead of ``ReLU`` in the backbone.

	Raises:
		ValueError: If ``img_height`` or ``img_width`` violates the constraints above.
	"""

	def __init__(self, img_channel, img_height, img_width, num_class,
			map_to_seq_hidden=64, rnn_hidden=256, leaky_relu=False):
		super(CRNN, self).__init__()

		self.cnn, (output_channel, output_height, output_width) = \
			self._cnn_backbone(img_channel, img_height, img_width, leaky_relu)

		# Each time step is one feature-map column: all channels of all
		# remaining rows, flattened together.
		self.map_to_seq = nn.Linear(output_channel * output_height, map_to_seq_hidden)

		self.rnn1 = nn.LSTM(map_to_seq_hidden, rnn_hidden, bidirectional=True)
		# Bidirectional output concatenates both directions -> 2 * rnn_hidden.
		self.rnn2 = nn.LSTM(2 * rnn_hidden, rnn_hidden, bidirectional=True)

		self.dense = nn.Linear(2 * rnn_hidden, num_class)

	def _cnn_backbone(self, img_channel, img_height, img_width, leaky_relu):
		"""Build the convolutional feature extractor.

		Returns:
			``(cnn, (output_channel, output_height, output_width))`` where the
			tuple is the per-sample shape ``cnn`` produces for an input of
			shape ``(img_channel, img_height, img_width)``.

		Raises:
			ValueError: If the image size is unsupported.
		"""
		# Explicit checks instead of ``assert`` so they survive ``python -O``.
		# Height is halved four times and width twice before a 2x2 unpadded
		# conv, so the pooled map must be at least 2 rows (height >= 32) and
		# 2 columns (width >= 8); smaller sizes would build a model whose
		# first forward pass fails.
		if img_height % 16 != 0 or img_height < 32:
			raise ValueError(
				f'img_height must be a multiple of 16 and at least 32, got {img_height}'
			)
		if img_width % 4 != 0 or img_width < 8:
			raise ValueError(
				f'img_width must be a multiple of 4 and at least 8, got {img_width}'
			)

		channels = [img_channel, 64, 128, 256, 256, 512, 512, 512]
		kernel_sizes = [3, 3, 3, 3, 3, 3, 2]
		strides = [1, 1, 1, 1, 1, 1, 1]
		paddings = [1, 1, 1, 1, 1, 1, 0]

		cnn = nn.Sequential()

		def conv_relu(i, batch_norm=False):
			"""Append conv layer ``i`` (+ optional BatchNorm) and its activation to ``cnn``."""
			# shape of input: (batch, input_channel, height, width)
			input_channel = channels[i]
			output_channel = channels[i + 1]

			cnn.add_module(
				f'conv{i}',
				nn.Conv2d(input_channel, output_channel, kernel_sizes[i], strides[i], paddings[i])
			)

			if batch_norm:
				cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(output_channel))

			relu = nn.LeakyReLU(0.2, inplace=True) if leaky_relu else nn.ReLU(inplace=True)
			cnn.add_module(f'relu{i}', relu)

		# size of image: (channel, height, width) = (img_channel, img_height, img_width)
		conv_relu(0)
		cnn.add_module('pooling0', nn.MaxPool2d(kernel_size=2, stride=2))
		# (64, img_height // 2, img_width // 2)

		conv_relu(1)
		cnn.add_module('pooling1', nn.MaxPool2d(kernel_size=2, stride=2))
		# (128, img_height // 4, img_width // 4)

		conv_relu(2)
		conv_relu(3)
		# stride defaults to kernel_size, so only the height is halved here.
		cnn.add_module(
			'pooling2',
			nn.MaxPool2d(kernel_size=(2, 1))
		)  # (256, img_height // 8, img_width // 4)

		conv_relu(4, batch_norm=True)
		conv_relu(5, batch_norm=True)
		cnn.add_module(
			'pooling3',
			nn.MaxPool2d(kernel_size=(2, 1))
		)  # (512, img_height // 16, img_width // 4)

		conv_relu(6)  # (512, img_height // 16 - 1, img_width // 4 - 1)

		output_channel, output_height, output_width = \
			channels[-1], img_height // 16 - 1, img_width // 4 - 1
		return cnn, (output_channel, output_height, output_width)

	def forward(self, images):
		"""Compute per-time-step class scores.

		Args:
			images: Tensor of shape ``(batch, img_channel, img_height, img_width)``.

		Returns:
			Tensor of shape ``(seq_len, batch, num_class)`` of raw scores (no
			softmax); for inputs of the configured size,
			``seq_len == img_width // 4 - 1``.
		"""
		conv = self.cnn(images)
		batch, channel, height, width = conv.size()

		# reshape rather than view: under channels_last memory format the conv
		# output is not contiguous in the default layout and view() would
		# raise. On contiguous tensors reshape returns the same view.
		conv = conv.reshape(batch, channel * height, width)
		# The LSTMs use the default sequence-first layout.
		conv = conv.permute(2, 0, 1)  # (width, batch, channel * height)
		seq = self.map_to_seq(conv)  # (width, batch, map_to_seq_hidden)

		recurrent, _ = self.rnn1(seq)  # (width, batch, 2 * rnn_hidden)
		recurrent, _ = self.rnn2(recurrent)

		output = self.dense(recurrent)
		return output  # shape: (seq_len, batch, num_class)