| """ | |
| Paper: "UTRNet: High-Resolution Urdu Text Recognition In Printed Documents" presented at ICDAR 2023 | |
| Authors: Abdur Rahman, Arjun Ghosh, Chetan Arora | |
| GitHub Repository: https://github.com/abdur75648/UTRNet-High-Resolution-Urdu-Text-Recognition | |
| Project Website: https://abdur75648.github.io/UTRNet/ | |
| Copyright (c) 2023-present: This work is licensed under the Creative Commons Attribution-NonCommercial | |
| 4.0 International License (http://creativecommons.org/licenses/by-nc/4.0/) | |
| """ | |
import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):

    def __init__(self, input_size, hidden_size, num_classes, device):
        super(Attention, self).__init__()
        self.attention_cell = AttentionCell(input_size, hidden_size, num_classes)
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        self.generator = nn.Linear(hidden_size, num_classes)
        self.device = device

    def _char_to_onehot(self, input_char, onehot_dim=38):
        # Convert a batch of character indices [batch_size] into one-hot
        # vectors [batch_size x onehot_dim] on the target device.
        input_char = input_char.unsqueeze(1)
        batch_size = input_char.size(0)
        one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().to(self.device)
        one_hot = one_hot.scatter_(1, input_char, 1)
        return one_hot

    def forward(self, batch_H, text, is_train=True, batch_max_length=25):
        """
        input:
            batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x contextual_feature_channels]
            text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO].
        output: probability distribution at each step [batch_size x num_steps x num_classes]
        """
        batch_size = batch_H.size(0)
        num_steps = batch_max_length + 1  # +1 for [s] at end of sentence.

        output_hiddens = torch.FloatTensor(batch_size, num_steps, self.hidden_size).fill_(0).to(self.device)
        hidden = (torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(self.device),
                  torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(self.device))

        if is_train:
            # Teacher forcing: feed the ground-truth previous character at each step.
            for i in range(num_steps):
                # one-hot vector for the i-th char in the batch
                char_onehots = self._char_to_onehot(text[:, i], onehot_dim=self.num_classes)
                # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_onehots : one-hot(y_{t-1})
                hidden, _ = self.attention_cell(hidden, batch_H, char_onehots)
                output_hiddens[:, i, :] = hidden[0]  # LSTM hidden index (0: hidden, 1: cell)
            probs = self.generator(output_hiddens)

        else:
            # Greedy autoregressive decoding: feed back the argmax prediction
            # as the next input character.
            targets = torch.LongTensor(batch_size).fill_(0).to(self.device)  # [GO] token
            probs = torch.FloatTensor(batch_size, num_steps, self.num_classes).fill_(0).to(self.device)

            for i in range(num_steps):
                char_onehots = self._char_to_onehot(targets, onehot_dim=self.num_classes)
                hidden, _ = self.attention_cell(hidden, batch_H, char_onehots)
                probs_step = self.generator(hidden[0])
                probs[:, i, :] = probs_step
                _, next_input = probs_step.max(1)
                targets = next_input

        return probs  # batch_size x num_steps x num_classes


class AttentionCell(nn.Module):

    def __init__(self, input_size, hidden_size, num_embeddings):
        super(AttentionCell, self).__init__()
        self.i2h = nn.Linear(input_size, hidden_size, bias=False)
        self.h2h = nn.Linear(hidden_size, hidden_size)  # either i2h or h2h should have bias
        self.score = nn.Linear(hidden_size, 1, bias=False)
        self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size)
        self.hidden_size = hidden_size

    def forward(self, prev_hidden, batch_H, char_onehots):
        # Additive attention: project encoder states and the previous decoder
        # state, then score each encoder step.
        # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size]
        batch_H_proj = self.i2h(batch_H)
        prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1)
        e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj))  # batch_size x num_encoder_step x 1

        alpha = F.softmax(e, dim=1)  # attention weights over encoder steps
        context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1)  # batch_size x num_channel
        concat_context = torch.cat([context, char_onehots], 1)  # batch_size x (num_channel + num_embedding)
        cur_hidden = self.rnn(concat_context, prev_hidden)

        return cur_hidden, alpha
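

# A minimal smoke-test sketch (not part of the original file). The dimensions
# below are illustrative assumptions, not the values UTRNet is trained with;
# only the [GO] token index (0) and the tensor shapes follow the docstring above.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size, num_encoder_steps = 2, 32                 # assumed
    input_size, hidden_size, num_classes = 256, 256, 38   # assumed
    batch_max_length = 25

    decoder = Attention(input_size, hidden_size, num_classes, device).to(device)

    # Encoder output: [batch_size x num_encoder_steps x input_size]
    batch_H = torch.randn(batch_size, num_encoder_steps, input_size, device=device)
    # Ground-truth indices prefixed with [GO]: [batch_size x (batch_max_length + 1)]
    text = torch.randint(0, num_classes, (batch_size, batch_max_length + 1), device=device)
    text[:, 0] = 0  # [GO] token

    # Teacher-forced training pass
    probs_train = decoder(batch_H, text, is_train=True, batch_max_length=batch_max_length)
    print(probs_train.shape)  # torch.Size([2, 26, 38])

    # Greedy inference pass (the text argument is ignored in this branch)
    probs_infer = decoder(batch_H, text, is_train=False, batch_max_length=batch_max_length)
    print(probs_infer.shape)  # torch.Size([2, 26, 38])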