# Hugging Face page artifacts from the original upload (kept as comments so
# the file remains valid Python):
#   triumphh77's picture | Upload 13 files | commit f9a156f (verified)
import torch
import torch.nn as nn
class CRNN(nn.Module):
    """CRNN (CNN + bidirectional LSTM) for sequence recognition, e.g. OCR
    trained with CTC loss.

    A VGG-style CNN backbone reduces an input image of shape
    (batch, img_channel, img_height, img_width) to a feature map of height 1
    (for the required input sizes); each feature-map column is then treated as
    one time step and fed through two stacked bidirectional LSTMs.

    Args:
        img_channel: number of input image channels.
        img_height: input image height; must be a multiple of 16.
        img_width: input image width; must be a multiple of 4.
        num_class: number of output classes (including the CTC blank).
        map_to_seq_hidden: size of the linear projection from a CNN column
            to the RNN input. Default 64.
        rnn_hidden: hidden size of each LSTM direction. Default 256.
        leaky_relu: if True, use LeakyReLU(0.2) instead of ReLU in the CNN.
    """

    def __init__(self, img_channel, img_height, img_width, num_class,
                 map_to_seq_hidden=64, rnn_hidden=256, leaky_relu=False):
        super().__init__()
        self.cnn, (output_channel, output_height, output_width) = \
            self._cnn_backbone(img_channel, img_height, img_width, leaky_relu)
        # Each feature-map column (channel * height values) becomes one
        # sequence step for the RNNs.
        self.map_to_seq = nn.Linear(output_channel * output_height, map_to_seq_hidden)
        self.rnn1 = nn.LSTM(map_to_seq_hidden, rnn_hidden, bidirectional=True, batch_first=True)
        # rnn1 is bidirectional, so its output feature size is 2 * rnn_hidden.
        self.rnn2 = nn.LSTM(rnn_hidden * 2, rnn_hidden, bidirectional=True, batch_first=True)
        self.dense = nn.Linear(rnn_hidden * 2, num_class)

    def _cnn_backbone(self, img_channel, img_height, img_width, leaky_relu):
        """Build the 7-layer convolutional backbone.

        Returns:
            (cnn, (output_channel, output_height, output_width)): the
            nn.Sequential backbone and the feature-map dimensions it produces
            for an input of (img_channel, img_height, img_width).
        """
        # Height is halved 4 times (two 2x2 pools, two (2,1)-stride pools) and
        # width is halved twice, so the input dims must divide accordingly.
        assert img_height % 16 == 0, \
            f'img_height must be a multiple of 16, got {img_height}'
        assert img_width % 4 == 0, \
            f'img_width must be a multiple of 4, got {img_width}'

        channels = [img_channel, 64, 128, 256, 256, 512, 512, 512]
        kernel_sizes = [3, 3, 3, 3, 3, 3, 2]
        strides = [1, 1, 1, 1, 1, 1, 1]
        paddings = [1, 1, 1, 1, 1, 1, 0]

        cnn = nn.Sequential()

        def conv_relu(i, batch_normalization=False):
            # Block i: conv -> (optional) batch norm -> ReLU / LeakyReLU.
            n_in = channels[i]
            n_out = channels[i + 1]
            cnn.add_module(f'conv{i}', nn.Conv2d(n_in, n_out, kernel_sizes[i], strides[i], paddings[i]))
            if batch_normalization:
                cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(n_out))
            if leaky_relu:
                cnn.add_module(f'relu{i}', nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module(f'relu{i}', nn.ReLU(inplace=True))

        # Shape comments below (CxHxW) illustrate a 1x32x128 input.
        conv_relu(0)
        cnn.add_module('pooling0', nn.MaxPool2d(kernel_size=2, stride=2))  # 64x16x64
        conv_relu(1)
        cnn.add_module('pooling1', nn.MaxPool2d(kernel_size=2, stride=2))  # 128x8x32
        conv_relu(2, True)
        conv_relu(3)
        # Width-preserving pools: stride 1 + padding 1 along width grows it by 1.
        cnn.add_module('pooling2', nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1)))  # 256x4x33
        conv_relu(4, True)
        conv_relu(5)
        cnn.add_module('pooling3', nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1)))  # 512x2x34
        conv_relu(6, True)  # 512x1x33

        # Closed-form output sizes for the layer stack above.
        output_channel, output_height, output_width = \
            channels[-1], img_height // 16 - 1, img_width // 4 + 1
        return cnn, (output_channel, output_height, output_width)

    def forward(self, images):
        """Run the network.

        Args:
            images: tensor of shape (batch, channel, height, width), with
                height/width matching the sizes the model was built for.

        Returns:
            Log-probabilities of shape (batch, seq_len, num_class), where
            seq_len == img_width // 4 + 1. Note: nn.CTCLoss by default
            expects (seq_len, batch, num_class) — permute dims 0 and 1
            before passing this output to the loss.
        """
        conv = self.cnn(images)
        batch, channel, height, width = conv.size()
        # Collapse channel and height into one feature vector per column.
        conv = conv.view(batch, channel * height, width)
        conv = conv.permute(0, 2, 1)  # (batch, width, channel * height)
        seq = self.map_to_seq(conv)
        recurrent, _ = self.rnn1(seq)
        recurrent, _ = self.rnn2(recurrent)
        output = self.dense(recurrent)
        # Log-softmax over the class dimension, as required by CTC loss.
        return output.log_softmax(2)
if __name__ == '__main__':
    # Smoke test: push a dummy batch through the model and print the shape.
    dummy_input = torch.randn(1, 1, 32, 1024)
    model = CRNN(img_channel=1, img_height=32, img_width=1024, num_class=80)
    output = model(dummy_input)
    # seq_len = img_width // 4 + 1 = 1024 // 4 + 1 = 257 (the original
    # "(1, 33, 80)" note was copied from a 128-wide input).
    print(f"Output shape: {output.shape}")  # Expected: (1, 257, 80)