Spaces:
Sleeping
Sleeping
| """ | |
| Paper: "UTRNet: High-Resolution Urdu Text Recognition In Printed Documents" presented at ICDAR 2023 | |
| Authors: Abdur Rahman, Arjun Ghosh, Chetan Arora | |
| GitHub Repository: https://github.com/abdur75648/UTRNet-High-Resolution-Urdu-Text-Recognition | |
| Project Website: https://abdur75648.github.io/UTRNet/ | |
| Copyright (c) 2023-present: This work is licensed under the Creative Commons Attribution-NonCommercial | |
| 4.0 International License (http://creativecommons.org/licenses/by-nc/4.0/) | |
| """ | |
| import pytz | |
| import torch | |
| import numpy as np | |
| from datetime import datetime | |
| import matplotlib.pyplot as plt | |
| from torch.autograd import Variable | |
| import os,random,shutil | |
| import matplotlib.pyplot as plt | |
| import warnings | |
| warnings.filterwarnings("ignore", category=UserWarning) | |
class CTCLabelConverter(object):
    """Convert between text labels and index tensors for CTC training.

    Index 0 is reserved for the CTC blank token; the characters handed to the
    constructor are numbered from 1 in the order they are given.
    """

    def __init__(self, character):
        """Build the character <-> index lookup tables.

        Args:
            character (str): The set of possible characters.
        """
        dict_character = list(character)
        # NOTE: 0 is reserved for the 'CTCblank' token required by CTCLoss,
        # so real characters start at index 1.
        self.dict = {char: i + 1 for i, char in enumerate(dict_character)}
        # Dummy '[CTCblank]' token for CTCLoss (index 0).
        self.character = ['[CTCblank]'] + dict_character

    def encode(self, text, batch_max_length=25):
        """Convert text labels into a zero-padded index tensor.

        Args:
            text: Text labels of each image. [batch_size]
            batch_max_length (int): Width of the returned tensor. 25 by default.

        Returns:
            tuple: ``(batch_text, length)`` where ``batch_text`` is a
            LongTensor of shape [batch_size, batch_max_length] and ``length``
            is an IntTensor holding the length of each label. [batch_size]

        Raises:
            KeyError: If a label contains a character that is not in the
                character set given to the constructor.
        """
        length = [len(s) for s in text]
        # The index used for padding (=0) would not affect the CTC loss calculation.
        batch_text = torch.LongTensor(len(text), batch_max_length).fill_(0)
        for i, label in enumerate(text):
            # Use a separate name so the `text` argument is not rebound mid-loop.
            indices = [self.dict[char] for char in label]
            batch_text[i][:len(indices)] = torch.LongTensor(indices)
        return (batch_text, torch.IntTensor(length))

    def decode(self, text_index, length):
        """Convert index sequences back into text labels.

        Args:
            text_index: 2-D index container, indexed as ``text_index[row, :]``.
            length: Number of leading positions to decode in each row.

        Returns:
            list[str]: One decoded string per entry of ``length``.
        """
        texts = []
        for row, l in enumerate(length):
            t = text_index[row, :]
            chars = [
                self.character[t[i]]
                for i in range(l)
                # Drop blanks and collapse consecutive repeated indices.
                if t[i] != 0 and not (i > 0 and t[i - 1] == t[i])
            ]
            texts.append(''.join(chars))
        return texts
class CTCLabelConverterForBaiduWarpctc(object):
    """Convert between text labels and flat index tensors for baidu warpctc.

    Unlike a padded 2-D layout, labels are concatenated into one 1-D sequence
    and split again using the per-sample lengths.
    """

    def __init__(self, character):
        """Build the character <-> index lookup tables.

        Args:
            character (str): The set of possible characters.
        """
        dict_character = list(character)
        # NOTE: 0 is reserved for the 'CTCblank' token required by CTCLoss,
        # so real characters start at index 1.
        self.dict = {char: i + 1 for i, char in enumerate(dict_character)}
        # Dummy '[CTCblank]' token for CTCLoss (index 0).
        self.character = ['[CTCblank]'] + dict_character

    def encode(self, text, batch_max_length=25):
        """Convert text labels into one concatenated index tensor.

        Args:
            text: Text labels of each image. [batch_size]
            batch_max_length (int): Not used by this converter; it is accepted
                so the signature matches the other converters in this file.

        Returns:
            tuple: ``(text, length)`` where ``text`` is the concatenated index
            IntTensor, [sum(text_lengths)] = [text_index_0 + text_index_1 +
            ... + text_index_(n - 1)], and ``length`` is an IntTensor with the
            length of each label. [batch_size]

        Raises:
            KeyError: If a label contains a character that is not in the
                character set given to the constructor.
        """
        length = [len(s) for s in text]
        flat_indices = [self.dict[char] for char in ''.join(text)]
        return (torch.IntTensor(flat_indices), torch.IntTensor(length))

    def decode(self, text_index, length):
        """Convert a concatenated index sequence back into text labels.

        Args:
            text_index: 1-D index sequence holding all samples back to back.
            length: Number of positions belonging to each sample, in order.

        Returns:
            list[str]: One decoded string per entry of ``length``.
        """
        texts = []
        index = 0
        for l in length:
            t = text_index[index:index + l]
            chars = [
                self.character[t[i]]
                for i in range(l)
                # Drop blanks and collapse consecutive repeated indices.
                if t[i] != 0 and not (i > 0 and t[i - 1] == t[i])
            ]
            texts.append(''.join(chars))
            index += l
        return texts
class AttnLabelConverter(object):
    """Convert between text labels and index tensors for an attention decoder.

    Index 0 is the '[GO]' start token and index 1 is the '[s]'
    end-of-sentence token; real characters follow from index 2.
    """

    def __init__(self, character):
        """Build the character <-> index lookup tables.

        Args:
            character (str): The set of possible characters.
        """
        # [GO] for the start token of the attention decoder. [s] for end-of-sentence token.
        list_token = ['[GO]', '[s]']  # ['[s]','[UNK]','[PAD]','[GO]']
        self.character = list_token + list(character)
        self.dict = {char: i for i, char in enumerate(self.character)}

    def encode(self, text, batch_max_length=25):
        """Convert text labels into a padded index tensor.

        Args:
            text: Text labels of each image. [batch_size]
            batch_max_length (int): Max length of text label in the batch.
                25 by default.

        Returns:
            tuple: ``(batch_text, length)``.
            ``batch_text`` is the input of the attention decoder,
            [batch_size x (max_length+2)]: +1 for the [GO] token and +1 for the
            [s] token. ``batch_text[:, 0]`` is the [GO] token and the text is
            padded with the [GO] token after the [s] token.
            ``length`` is the length of the decoder output, which counts the
            [s] token as well. [batch_size]

        Note:
            A label containing a character outside the character set is
            skipped on a best-effort basis: its row stays all-[GO] while its
            entry in ``length`` is unchanged. A ``RuntimeWarning`` is emitted
            so the skip is visible.
        """
        length = [len(s) + 1 for s in text]  # +1 for [s] at end of sentence.
        # batch_max_length = max(length) # this is not allowed for multi-gpu setting
        batch_max_length += 1
        # additional +1 for [GO] at first step. batch_text is padded with [GO] token after [s] token.
        batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(0)
        for i, label in enumerate(text):
            tokens = list(label)
            tokens.append('[s]')
            try:
                indices = [self.dict[char] for char in tokens]
            except KeyError as e:
                # RuntimeWarning rather than UserWarning: this module filters
                # out UserWarning, which would hide the message again.
                warnings.warn(
                    "AttnLabelConverter.encode: skipping label %r, unknown character %s"
                    % (label, e),
                    RuntimeWarning,
                )
                continue
            batch_text[i][1:1 + len(indices)] = torch.LongTensor(indices)  # batch_text[:, 0] = [GO] token
        return (batch_text, torch.IntTensor(length))

    def decode(self, text_index, length):
        """Convert index rows back into text, special tokens included.

        Args:
            text_index: 2-D index container, indexed as ``text_index[row, :]``.
            length: Only its number of entries is used; it decides how many
                rows are decoded. Each row is decoded in full.

        Returns:
            list[str]: One string per entry of ``length``; '[GO]' and '[s]'
            tokens are kept verbatim in the output.
        """
        texts = []
        for row, _ in enumerate(length):
            texts.append(''.join([self.character[i] for i in text_index[row, :]]))
        return texts
def imshow(img, title, batch_size=1):
    """Undo per-channel normalization on an image tensor and display it.

    Args:
        img: Tensor exposing ``.numpy()`` with the channel axis first; the
            three-entry corrections below broadcast against a (3, H, W) array.
        title: Title shown above the image.
        batch_size (int): Scales the figure width (``batch_size * 4`` inches).
    """
    # NOTE(review): these look like the usual ImageNet mean/std values; confirm
    # they match the normalization applied when the tensor was created.
    std_correction = np.asarray([0.229, 0.224, 0.225]).reshape(3, 1, 1)
    mean_correction = np.asarray([0.485, 0.456, 0.406]).reshape(3, 1, 1)
    npimg = np.multiply(img.numpy(), std_correction) + mean_correction

    plt.figure(figsize=(batch_size * 4, 4))
    plt.axis("off")
    # matplotlib expects channels last, so move (C, H, W) -> (H, W, C).
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.title(title)
    plt.show()
class Averager(object):
    """Compute a running average for torch.Tensor values, used for loss average."""

    def __init__(self):
        self.reset()

    def add(self, v):
        """Accumulate every element of tensor ``v`` into the running totals."""
        values = v.detach()  # keep the accumulator out of the autograd graph
        self.n_count += values.numel()
        self.sum += values.sum()

    def reset(self):
        """Clear the accumulated sum and element count."""
        self.n_count = 0
        self.sum = 0

    def val(self):
        """Return the mean of all accumulated elements, or 0 if none were added."""
        if self.n_count == 0:
            return 0
        return self.sum / float(self.n_count)
class Logger(object):
    """For logging while training: echoes messages to stdout and to a file."""

    def __init__(self, path):
        """Create (or truncate) the log file and write a timestamp header.

        Args:
            path: Path of the log file. It is opened in "w" mode, so any
                existing content is discarded.
        """
        self.logFile = path
        datetime_now = datetime.now(pytz.timezone('Asia/Kolkata')).strftime("%Y-%m-%d_%H-%M-%S")
        with open(self.logFile, "w", encoding="utf-8") as f:
            f.write("Logging at @ " + datetime_now + "\n")

    def log(self, *args):
        """Print the space-joined arguments and append them to the log file.

        Each argument is converted with ``str``; leading and trailing
        whitespace of the joined message is stripped before it is written.
        """
        message = " ".join(str(x) for x in args).strip()
        print(message)
        with open(self.logFile, "a", encoding="utf-8") as f:
            f.write(message + "\n")
def allign_two_strings(x:str, y:str, pxy:int=1, pgap:int=1):
    """Globally align two strings, marking gaps with '_'.

    A dynamic-programming table of minimum alignment penalties is filled and
    then traced back from the bottom-right corner.

    Source: https://www.geeksforgeeks.org/sequence-alignment-problem/

    Args:
        x: First string.
        y: Second string.
        pxy: Penalty for aligning two different characters.
        pgap: Penalty for aligning a character with a gap.

    Returns:
        tuple: ``(x_seq, y_seq)``, two strings of equal length in which '_'
        marks a gap. The alignment is built directly during traceback, so a
        literal '_' in the inputs does not truncate the result.
    """
    gap = '_'
    m, n = len(x), len(y)

    # dp[i][j] = minimum penalty for aligning x[:i] with y[:j].
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(m + 1):
        dp[i][0] = i * pgap
    for j in range(n + 1):
        dp[0][j] = j * pgap
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if x[i - 1] == y[j - 1]:
                dp[i][j] = dp[i - 1][j - 1]
            else:
                dp[i][j] = min(dp[i - 1][j - 1] + pxy,
                               dp[i - 1][j] + pgap,
                               dp[i][j - 1] + pgap)

    # Trace back from the end; characters are collected in reverse order.
    x_rev = []
    y_rev = []
    i, j = m, n
    while i > 0 and j > 0:
        if x[i - 1] == y[j - 1] or dp[i - 1][j - 1] + pxy == dp[i][j]:
            # Match or substitution: consume one character from each string.
            x_rev.append(x[i - 1])
            y_rev.append(y[j - 1])
            i -= 1
            j -= 1
        elif dp[i - 1][j] + pgap == dp[i][j]:
            # Gap in y.
            x_rev.append(x[i - 1])
            y_rev.append(gap)
            i -= 1
        else:
            # Gap in x (the only remaining way dp[i][j] can have been reached).
            x_rev.append(gap)
            y_rev.append(y[j - 1])
            j -= 1

    # At most one of the strings still has a prefix left; pair it with gaps.
    while i > 0:
        x_rev.append(x[i - 1])
        y_rev.append(gap)
        i -= 1
    while j > 0:
        x_rev.append(gap)
        y_rev.append(y[j - 1])
        j -= 1

    return ''.join(reversed(x_rev)), ''.join(reversed(y_rev))
def count_parameters(model, precision=2):
    """Count the trainable parameters of a model, in millions.

    Args:
        model: Object exposing ``parameters()`` whose items have ``numel()``
            and ``requires_grad`` (e.g. a ``torch.nn.Module``).
        precision (int): Number of decimal places to round the result to.

    Returns:
        float: Trainable parameter count divided by 10**6, rounded.
    """
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return round(trainable / 10.**6, precision)
def count_model_flops(model, in_channels=1, input_res=(32, 400), multiply_adds=True):
    """Estimate the FLOPs of one forward pass through ``model``.

    Code for counting the number of FLOPs in the CNN backbone during inference.
    Source - https://github.com/fdbtrs/ElasticFace/blob/main/utils/countFLOPS.py

    Forward hooks are attached to every supported leaf module, a random input
    is pushed through the model in eval mode, and the per-layer estimates are
    summed. Unsupported leaf modules are reported with a printed warning and
    contribute nothing to the total.

    Args:
        model: The ``torch.nn.Module`` to profile.
        in_channels (int): Number of channels of the dummy input.
        input_res: Two-element sequence; the dummy input has shape
            ``(1, in_channels, input_res[1], input_res[0])``.
        multiply_adds (bool): If True, count a multiply-accumulate as two
            operations instead of one.

    Returns:
        str: The total formatted in MFLOPS, e.g. ``"12.3456 MFLOPS"``.

    Note:
        The model is left in training mode when the function returns,
        regardless of the mode it was in before the call.
    """
    mac_factor = 2 if multiply_adds else 1
    flop_counts = []  # one entry per hooked layer invocation

    def conv_hook(module, inputs, output):
        batch_size, _, _, _ = inputs[0].size()
        output_channels, output_height, output_width = output[0].size()
        kernel_ops = module.kernel_size[0] * module.kernel_size[1] * (module.in_channels / module.groups)
        bias_ops = 1 if module.bias is not None else 0
        flop_counts.append(
            (kernel_ops * mac_factor + bias_ops)
            * output_channels * output_height * output_width * batch_size
        )

    def linear_hook(module, inputs, output):
        batch_size = inputs[0].size(0) if inputs[0].dim() == 2 else 1
        ops = module.weight.nelement() * mac_factor
        if module.bias is not None:
            ops += module.bias.nelement()
        flop_counts.append(batch_size * ops)

    def bn_hook(module, inputs, output):
        flop_counts.append(inputs[0].nelement() * 2)

    def relu_hook(module, inputs, output):
        flop_counts.append(inputs[0].nelement())

    def pooling_hook(module, inputs, output):
        batch_size, _, _, _ = inputs[0].size()
        output_channels, output_height, output_width = output[0].size()
        # kernel_size may be a tuple (product of its elements) or an int (squared).
        if isinstance(module.kernel_size, tuple):
            kernel_ops = module.kernel_size[0] * module.kernel_size[1]
        else:
            kernel_ops = module.kernel_size * module.kernel_size
        flop_counts.append(kernel_ops * output_channels * output_height * output_width * batch_size)

    def dropout_hook(module, inputs, output):
        # Assume one comparison and one multiplication per element.
        batch_size, input_channels, input_height, input_width = inputs[0].size()
        flop_counts.append(2 * batch_size * input_channels * input_height * input_width)

    def sigmoid_hook(module, inputs, output):
        # Assume two multiplications and one addition per element.
        batch_size, input_channels, input_height, input_width = inputs[0].size()
        flop_counts.append(3 * batch_size * input_channels * input_height * input_width)

    def upsample_hook(module, inputs, output):
        batch_size, _, _, _ = inputs[0].size()
        output_channels, output_height, output_width = output[0].size()
        # NOTE(review): assumes scale_factor is a single number (not None or a tuple).
        kernel_ops = module.scale_factor * module.scale_factor
        flop_counts.append(
            (kernel_ops * mac_factor)
            * output_channels * output_height * output_width * batch_size
        )

    # Checked in order; the first matching entry decides the hook.
    hook_table = (
        ((torch.nn.Conv2d, torch.nn.ConvTranspose2d), conv_hook),
        ((torch.nn.Linear,), linear_hook),
        ((torch.nn.BatchNorm2d, torch.nn.BatchNorm1d), bn_hook),
        ((torch.nn.ReLU, torch.nn.PReLU), relu_hook),
        ((torch.nn.MaxPool2d, torch.nn.AvgPool2d), pooling_hook),
        ((torch.nn.Dropout,), dropout_hook),
        ((torch.nn.Upsample,), upsample_hook),
        ((torch.nn.Sigmoid,), sigmoid_hook),
    )
    handles = []

    def register_hooks(net):
        """Recursively attach the matching hook to every leaf module."""
        children = list(net.children())
        if children:
            for child in children:
                register_hooks(child)
            return
        for module_types, hook in hook_table:
            if isinstance(net, module_types):
                handles.append(net.register_forward_hook(hook))
                return
        print("warning" + str(net))

    def flops_to_string(flops, units='MFLOPS', precision=4):
        if units == 'GFLOPS':
            return str(round(flops / 10.**9, precision)) + ' ' + units
        elif units == 'MFLOPS':
            return str(round(flops / 10.**6, precision)) + ' ' + units
        elif units == 'KFLOPS':
            return str(round(flops / 10.**3, precision)) + ' ' + units
        else:
            return str(flops) + ' FLOPS'

    model.eval()
    register_hooks(model)
    try:
        dummy_input = torch.rand(1, in_channels, input_res[1], input_res[0], requires_grad=True)
        model(dummy_input)
    finally:
        # Always detach the hooks, even if the forward pass raised, so the
        # model is not left permanently instrumented.
        for handle in handles:
            handle.remove()
    model.train()
    return flops_to_string(sum(flop_counts))
def draw_feature_map(visual_feature, vis_dir, num_channel=10):
    """Save grayscale heatmaps of randomly chosen feature-map channels.

    Args:
        visual_feature (Tensor): Indexed as ``visual_feature[0, c, :, :]``, so
            a 4-D (N, C, H, W) tensor is expected; only the first sample is
            drawn.
        vis_dir (String): Directory to save the feature maps. It is deleted
            and recreated if it already exists.
        num_channel (int): Number of channels to draw. Channels are sampled
            without replacement, so at most C images are written.
    """
    if os.path.exists(vis_dir):
        shutil.rmtree(vis_dir)
    os.makedirs(vis_dir)

    total_channels = visual_feature.shape[1]
    # Distinct channels: a repeated channel would only overwrite its own file.
    chosen_channels = random.sample(range(total_channels), min(num_channel, total_channels))

    for channel in chosen_channels:
        feature_map = visual_feature[0, channel, :, :].detach().cpu().numpy()
        # Horizontal flip
        feature_map = feature_map[:, ::-1]
        # Normalize to [0, 1]; a constant channel has zero range, so map it to
        # all zeros instead of dividing by zero.
        lowest = feature_map.min()
        value_range = feature_map.max() - lowest
        if value_range > 0:
            feature_map = (feature_map - lowest) / value_range
        else:
            feature_map = np.zeros_like(feature_map)
        # Draw heatmap on its own figure and close it so figures do not pile up.
        plt.figure()
        plt.imshow(feature_map, cmap='gray', interpolation='nearest')
        plt.axis("off")
        plt.savefig(os.path.join(vis_dir, "channel_{}.png".format(channel)), bbox_inches='tight', pad_inches=0)
        plt.close()