import torch
from .feature_extraction import HRNet_FeatureExtractor
from .sequence_modeling import BidirectionalLSTM
from .dropout_layer import dropout_layer
from .prediction import Attention
import torch.nn as nn
# Other CNN Architectures
from .feature_extraction import DenseNet_FeatureExtractor, InceptionUNet_FeatureExtractor
from .feature_extraction import RCNN_FeatureExtractor, ResNet_FeatureExtractor
from .feature_extraction import ResUnet_FeatureExtractor, AttnUNet_FeatureExtractor
from .feature_extraction import UNet_FeatureExtractor, UNetPlusPlus_FeatureExtractor
from .feature_extraction import VGG_FeatureExtractor
# Other sequential models
from .sequence_modeling import LSTM, GRU, MDLSTM
class Text_recognization_model(nn.Module):
    """Three-stage text-recognition network.

    Pipeline: FeatureExtraction (CNN backbone) -> temporal dropout ->
    SequenceModeling (recurrent layer) -> Prediction (CTC or attention).
    The concrete module used at each stage is selected by name from the
    configuration object ``opt``.
    """

    def __init__(self, opt):
        """Build the three stages described by ``opt``.

        Args:
            opt: configuration namespace. Must provide FeatureExtraction,
                SequenceModeling, Prediction, input_channel, output_channel,
                hidden_size, num_class, device and batch_max_length.

        Raises:
            Exception: when a stage name in ``opt`` is not recognized.
        """
        super(Text_recognization_model, self).__init__()
        self.opt = opt
        # Remember the chosen module name per stage; forward() consults
        # stages['Pred'] to pick the prediction call convention.
        self.stages = {'Feat': opt.FeatureExtraction,
                       'Seq': opt.SequenceModeling,
                       'Pred': opt.Prediction}

        # --- Feature extraction ------------------------------------------
        # Every backbone shares the (input_channel, output_channel)
        # constructor signature, so a name -> class table replaces the
        # original ten-branch if/elif chain.
        feature_extractors = {
            'HRNet': HRNet_FeatureExtractor,          # keeps high-res maps
            'Densenet': DenseNet_FeatureExtractor,
            'InceptionUnet': InceptionUNet_FeatureExtractor,
            'RCNN': RCNN_FeatureExtractor,
            'ResNet': ResNet_FeatureExtractor,
            'ResUnet': ResUnet_FeatureExtractor,
            'AttnUNet': AttnUNet_FeatureExtractor,
            'UNet': UNet_FeatureExtractor,
            'UnetPlusPlus': UNetPlusPlus_FeatureExtractor,
            'VGG': VGG_FeatureExtractor,
        }
        if opt.FeatureExtraction not in feature_extractors:
            raise Exception('No FeatureExtraction module specified')
        self.FeatureExtraction = feature_extractors[opt.FeatureExtraction](
            opt.input_channel, opt.output_channel)
        self.FeatureExtraction_output = opt.output_channel
        # forward() permutes the feature map to [b, w, c, h]; this pool
        # collapses the trailing (height) axis to 1 and keeps the others.
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))

        # --- Temporal dropout --------------------------------------------
        # Five independent dropout layers: only dropout1 is used while
        # training; all five are used at inference for Monte Carlo dropout.
        self.dropout1 = dropout_layer(opt.device)
        self.dropout2 = dropout_layer(opt.device)
        self.dropout3 = dropout_layer(opt.device)
        self.dropout4 = dropout_layer(opt.device)
        self.dropout5 = dropout_layer(opt.device)

        # --- Sequence modeling -------------------------------------------
        if opt.SequenceModeling == 'DBiLSTM':
            # Double BiLSTM: two stacked bidirectional LSTMs.
            self.SequenceModeling = nn.Sequential(
                BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
                BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size))
        else:
            sequence_models = {'LSTM': LSTM, 'GRU': GRU,
                               'MDLSTM': MDLSTM, 'BiLSTM': BidirectionalLSTM}
            if opt.SequenceModeling not in sequence_models:
                raise Exception('No Sequence Modeling module specified')
            self.SequenceModeling = sequence_models[opt.SequenceModeling](
                self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size)
        self.SequenceModeling_output = opt.hidden_size

        # --- Prediction --------------------------------------------------
        if opt.Prediction == 'CTC':
            self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
        elif opt.Prediction == 'Attn':
            self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size,
                                        opt.num_class, opt.device)
        else:
            # Message grammar fixed: was "neither CTC or Attn".
            raise Exception('Prediction is neither CTC nor Attn')

    def forward(self, input, text=None, is_train=True):
        """Run the full recognition pipeline on an image batch.

        Args:
            input: image tensor handed to the feature extractor
                (presumably [b, c, h, w] — the permute below suggests so;
                confirm against the chosen backbone).
            text: target label tensor; required when the attention head
                is selected, unused for CTC.
            is_train: forwarded to the attention predictor
                (teacher forcing on/off).

        Returns:
            Per-timestep class scores produced by the prediction head.

        Raises:
            Exception: if the attention head is selected and ``text`` is None.
        """
        # --- Feature extraction stage ------------------------------------
        visual_feature = self.FeatureExtraction(input)
        # Permute [b, c, h, w] -> [b, w, c, h], then pool the height axis
        # down to 1 (original comment mislabelled this step).
        visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2))
        # Drop the pooled singleton axis: [b, w, c, 1] -> [b, w, c].
        visual_feature = visual_feature.squeeze(3)

        # --- Temporal dropout + sequence modeling stage ------------------
        if self.training:
            contextual_feature = self.SequenceModeling(self.dropout1(visual_feature))
        else:
            # Monte Carlo dropout at inference: push five independently
            # dropped-out copies through the sequence model and average
            # the results.
            passes = [self.SequenceModeling(drop(visual_feature))
                      for drop in (self.dropout1, self.dropout2, self.dropout3,
                                   self.dropout4, self.dropout5)]
            total = passes[0]
            for extra in passes[1:]:
                total = total.add(extra)
            contextual_feature = total * (1 / 5)

        # --- Prediction stage --------------------------------------------
        if self.stages['Pred'] == 'CTC':
            prediction = self.Prediction(contextual_feature.contiguous())
        else:
            if text is None:
                raise Exception('Input text (for prediction) to model is None')
            text = text.to(self.opt.device)
            prediction = self.Prediction(contextual_feature, text, is_train,
                                         batch_max_length=self.opt.batch_max_length)
        return prediction