class Hypothesis:
    """A single decoded beam-search hypothesis.

    Attributes:
        seq: decoded token ids (START/STOP trimming is the caller's job).
        score: cumulative log-probability of the sequence.
        attn_weights: per-token attention weights, or None when not tracked.
    """

    seq: List[int]
    score: float
    attn_weights: Optional[List[float]]

    def __init__(
        self,
        seq_tensor: torch.LongTensor,
        score: float,
        weights: Optional[torch.FloatTensor] = None,
    ) -> None:
        self.seq = seq_tensor.tolist()
        self.score = score
        # Bug fix: the original tested `if weights:`, which raises
        # "Boolean value of Tensor with more than one element is ambiguous"
        # for any multi-token weight tensor (and is False for a single
        # all-zero weight). Presence must be tested with `is not None`.
        if weights is not None:
            self.attn_weights = weights.tolist()
            assert len(self.attn_weights) == len(self.seq)
        else:
            self.attn_weights = None

    def __len__(self) -> int:
        # Report at least 1 so length-normalised scoring downstream never
        # divides by zero on an empty sequence.
        return len(self.seq) if self.seq else 1

    def __str__(self) -> str:
        return f"seq: {self.seq}, score: {self.score}, weight: {self.attn_weights}"
hyp_word_ids = top_cand_hyp_pos % vocab_size + + step += 1 + new_hypotheses = [] + new_hyp_scores = [] + + for prev_hyp_id, hyp_word_id, cand_new_hyp_score in zip( + prev_hyp_ids, hyp_word_ids, top_cand_hyp_scores + ): + cand_new_hyp_score = cand_new_hyp_score.detach().item() + self.hypotheses[prev_hyp_id, step] = hyp_word_id + + if hyp_word_id == self.stop_w: + self.completed_hypotheses.append( + Hypothesis( + seq_tensor=self.hypotheses[prev_hyp_id, 1:step+1] + .detach() + .clone(), # remove START_W at first + score=cand_new_hyp_score, + ) + ) + else: + new_hypotheses.append(self.hypotheses[prev_hyp_id].detach().clone()) + new_hyp_scores.append(cand_new_hyp_score) + + return new_hypotheses, new_hyp_scores + + def get_incomplete_inds(self, hyp_word_ids): + return [ind for ind, next_word in enumerate(hyp_word_ids) if + next_word != self.stop_w] + + def get_complete_inds(self, hyp_word_ids, incomplete_inds): + return list(set(range(len(hyp_word_ids))) - set(incomplete_inds)) + + def set_current_state(self, hypotheses): + "Set the outputs for the current timestep." + self.hypotheses = torch.stack(hypotheses, dim=0) + return + + def set_current_score(self, hyp_scores): + "Set the scores for the current timestep." 
+ self.hyp_scores = torch.tensor( + hyp_scores, dtype=torch.float, device=self.device + ) + return + + def done(self, beam_size): + return len(self.completed_hypotheses) == beam_size + + def set_hypothesis(self): + if len(self.completed_hypotheses) == 0: + self.completed_hypotheses.append( + Hypothesis( + seq_tensor=self.hypotheses[0, 1:].detach().clone(), + score=self.hyp_scores[0].detach().item(), + ) + ) + return \ No newline at end of file diff --git a/HybridViT/helper.py b/HybridViT/helper.py new file mode 100644 index 0000000000000000000000000000000000000000..ba5cec3e1b2ca917cadbd914fbe45c0a6095f6b2 --- /dev/null +++ b/HybridViT/helper.py @@ -0,0 +1,182 @@ +import torch +import random +import numpy as np +from PIL import Image +from typing import Dict +import torch.nn.functional as F +import numpy as np +from PIL import Image +import cv2 +import math +import albumentations as alb +from albumentations.pytorch.transforms import ToTensorV2 +from collections import OrderedDict +from itertools import repeat +import collections.abc + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + return parse + +to_3tuple = _ntuple(3) + +def clean_state_dict(state_dict): + # 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training + cleaned_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k[7:] if k.startswith('module.') else k + cleaned_state_dict[name] = v + return cleaned_state_dict + + +def math_transform(mean, std, is_gray: bool): + test_transform = [] + normalize = [ + alb.CLAHE(clip_limit=2, tile_grid_size=(2, 2), always_apply=True), + alb.Normalize(to_3tuple(mean), to_3tuple(std)), + ToTensorV2() + ] + if is_gray: + test_transform += [alb.ToGray(always_apply=True)] + test_transform += normalize + + test_transform = alb.Compose([*test_transform]) + return test_transform + + +def pad(img: Image.Image, divable=32): 
+ """Pad an Image to the next full divisible value of `divable`. Also normalizes the PIL.image and invert if needed. + + Args: + img (PIL.Image): input PIL.image + divable (int, optional): . Defaults to 32. + + Returns: + PIL.Image + """ + data = np.array(img.convert('LA')) + + data = (data-data.min())/(data.max()-data.min())*255 + if data[..., 0].mean() > 128: + gray = 255*(data[..., 0] < 128).astype(np.uint8) # To invert the text to white + else: + gray = 255*(data[..., 0] > 128).astype(np.uint8) + data[..., 0] = 255-data[..., 0] + + coords = cv2.findNonZero(gray) # Find all non-zero points (text) + a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box + rect = data[b:b+h, a:a+w] + + if rect[..., -1].var() == 0: + im = Image.fromarray((rect[..., 0]).astype(np.uint8)).convert('L') + else: + im = Image.fromarray((255-rect[..., -1]).astype(np.uint8)).convert('L') + dims = [] + + for x in [w, h]: + div, mod = divmod(x, divable) + dims.append(divable*(div + (1 if mod > 0 else 0))) + + padded = Image.new('L', dims, 255) + padded.paste(im, im.getbbox()) + + return padded + +def get_divisible_size(ori_h, ori_w, max_dimension=None, scale_factor=32): + new_h, new_w = ori_h, ori_w + if ori_h % scale_factor: + new_h = math.ceil(ori_h/scale_factor)*scale_factor + if new_h > max_dimension[0]: + new_h = math.floor(ori_h/scale_factor)*scale_factor + + if ori_w % scale_factor: + new_w = math.ceil(ori_w/scale_factor)*scale_factor + if new_w > max_dimension[1]: + new_w = math.floor(ori_w/scale_factor)*scale_factor + + return int(new_h),int(new_w) + +def minmax_size(img, max_dimensions=None, min_dimensions=None, is_gray=True): + if max_dimensions is not None: + ratios = [a/b for a, b in zip(list(img.size)[::-1], max_dimensions)] + if any([r > 1 for r in ratios]): + size = np.array(img.size)/max(ratios) + new_h, new_w = get_divisible_size(size[1], size[0], max_dimensions) + img = img.resize((new_w, new_h), Image.LANCZOS) + + if min_dimensions is not None: + 
ratios = [a/b for a, b in zip(list(img.size)[::-1], min_dimensions)] + if any([r < 1 for r in ratios]): + new_h, new_w = img.size[1] / min(ratios), img.size[0] / min(ratios) + new_h, new_w = get_divisible_size(new_h, new_w, max_dimensions) + if is_gray: + MODE = 'L' + BACKGROUND = 255 + padded_im = Image.new(MODE, (new_w, new_h), BACKGROUND) + padded_im.paste(img, img.getbbox()) + img = padded_im + + return img + +def resize(resizer, img: Image.Image, opt: Dict): + # for math recognition problem image alway in grayscale mode + img = img.convert('L') + assert isinstance(opt, Dict) + assert "imgH" in opt + assert "imgW" in opt + expected_H = opt['imgH'] + + if expected_H is None: + max_dimensions = opt['max_dimension'] #can be bigger than max dim in training set + min_dimensions = opt['min_dimension'] + #equal to min dim in trainign set + test_transform = math_transform(opt['mean'], opt['std'], not opt['rgb']) + try: + new_img = minmax_size(pad(img) if opt['pad'] else img, max_dimensions, min_dimensions, not opt['rgb']) + + if not resizer: + new_img = np.asarray(new_img.convert('RGB')).astype('uint8') + new_img = test_transform(image=new_img)['image'] + if not opt['rgb']: new_img = new_img[:1] + new_img = new_img.unsqueeze(0) + new_img = new_img.float() + else: + with torch.no_grad(): + input_image = new_img.convert('RGB').copy() + r, w, h = 1, input_image.size[0], input_image.size[1] + for i in range(20): + h = int(h * r) + new_img = pad(minmax_size(input_image.resize((w, h), Image.BILINEAR if r > 1 else Image.LANCZOS), + max_dimensions, + min_dimensions, + not opt['rgb'] + )) + t = test_transform(image=np.array(new_img.convert('RGB')).astype('uint8'))['image'] + if not opt['rgb']: t = t[:1] + t = t.unsqueeze(0) + t = t.float() + w = (resizer(t.to(opt['device'])).argmax(-1).item()+1)*opt['min_dimension'][1] + + if (w == new_img.size[0]): + break + + r = w/new_img.size[0] + + new_img = t + except ValueError as e: + print('Error:', e) + new_img = 
class LayerNorm2d(nn.LayerNorm):
    """LayerNorm over the channel dimension of BCHW feature maps."""

    def __init__(self, num_channels):
        super().__init__(num_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Move channels last, normalize over them, then restore BCHW layout.
        return F.layer_norm(
            x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps
        ).permute(0, 3, 1, 2)


class DepthwiseSeparableConv2d(nn.Module):
    """Depthwise conv (one filter per input channel) followed by a 1x1 pointwise conv."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, bias=True):
        super().__init__()
        # Depthwise stage: groups=in_channels makes each channel independent.
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            padding=padding,
            stride=stride,
            bias=bias,
            groups=in_channels,
        )
        # Pointwise stage: 1x1 conv mixes channels to the target width.
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1,
                                   groups=groups, bias=bias)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out


class ConvMLP(nn.Module):
    """1x1-conv MLP block: fc1 -> LayerNorm2d -> ReLU -> Dropout -> fc2 on BCHW maps.

    Args:
        in_channels: number of input channels.
        out_channels: output channels; defaults to ``in_channels`` when None.
        hidden_channels: hidden width; defaults to ``in_channels`` when None.
        drop: dropout probability applied after the activation.
    """

    def __init__(self, in_channels, out_channels=None, hidden_channels=None, drop=0.25):
        super().__init__()
        # Bug fix: the original wrote `in_channels or out_channels`, which
        # always evaluates to `in_channels` (non-zero), silently discarding
        # any explicit out_channels / hidden_channels argument.
        out_channels = out_channels or in_channels
        hidden_channels = hidden_channels or in_channels
        self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1, bias=True)
        self.norm = LayerNorm2d(hidden_channels)
        self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1, bias=True)
        self.act = nn.ReLU()
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.norm(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        return x
self.with_explicit_padding: + assert padding_mode in list(nonofficial_padding_mode), "Not implemented padding algorithm" + self.padding_layer = nonofficial_padding_mode[padding_mode] + + # reset padding to 0 for conv module + conv_padding = 0 if self.with_explicit_padding else padding + + self.conv = conv_layer( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=conv_padding, + dilation=dilation, + groups=groups, + bias=bias + ) + self.in_channels = self.conv.in_channels + self.out_channels = self.conv.out_channels + self.kernel_size = self.conv.kernel_size + self.stride = self.conv.stride + self.padding = padding + self.dilation = self.conv.dilation + self.transposed = self.conv.transposed + self.output_padding = self.conv.output_padding + self.groups = self.conv.groups + + if self.with_spectral_norm: + self.conv = nn.utils.spectral_norm(self.conv) + + # build normalization layers + if self.with_norm: + # norm layer is after conv layer + if order.index('norm') > order.index('conv'): + norm_channels = out_channels + else: + norm_channels = in_channels + self.norm = norm_layer(norm_channels) + + if self.with_act: + if act_layer not in [nn.Tanh, nn.PReLU, nn.Sigmoid]: + self.act = act_layer() + else: + self.act = act_layer(inplace=inplace) + + def forward(self, x, activate=True, norm=True): + for layer in self.order: + if layer == 'conv': + if self.with_explicit_padding: + x = self.padding_layer(x) + x = self.conv(x) + elif layer == 'norm' and norm and self.with_norm: + x = self.norm(x) + elif layer == 'act' and activate and self.with_act: + x = self.act(x) + return x \ No newline at end of file diff --git a/HybridViT/module/component/common/droppath.py b/HybridViT/module/component/common/droppath.py new file mode 100644 index 0000000000000000000000000000000000000000..e54a07c886136353a1f0a47c6c96962f8a28422c --- /dev/null +++ b/HybridViT/module/component/common/droppath.py @@ -0,0 +1,36 @@ +from torch import nn + +__all__ = ['DropPath'] + + +def 
drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f'drop_prob={round(self.drop_prob,3):0.3f}' \ No newline at end of file diff --git a/HybridViT/module/component/common/gated_sum.py b/HybridViT/module/component/common/gated_sum.py new file mode 100644 index 0000000000000000000000000000000000000000..387ef11bbe1f1d90a91a5a44fc4c517235d6add7 --- /dev/null +++ b/HybridViT/module/component/common/gated_sum.py @@ -0,0 +1,36 @@ +import torch + + +class GatedSum(torch.nn.Module): + """ + This `Module` represents a gated sum of two tensors `a` and `b`. 
Specifically: + ``` + f = activation(W [a; b]) + out = f * a + (1 - f) * b + ``` + # Parameters + input_dim : `int`, required + The dimensionality of the input. We assume the input have shape `(..., input_dim)`. + activation : `Activation`, optional (default = `torch.nn.Sigmoid()`) + The activation function to use. + """ + + def __init__(self, input_dim: int, activation = torch.nn.Sigmoid()) -> None: + super().__init__() + self.input_dim = input_dim + self._gate = torch.nn.Linear(input_dim * 2, 1) + self._activation = activation + + def get_input_dim(self): + return self.input_dim + + def get_output_dim(self): + return self.input_dim + + def forward(self, input_a: torch.Tensor, input_b: torch.Tensor) -> torch.Tensor: + if input_a.size() != input_b.size(): + raise ValueError("The input must have the same size.") + if input_a.size(-1) != self.input_dim: + raise ValueError("Input size must match `input_dim`.") + gate_value = self._activation(self._gate(torch.cat([input_a, input_b], -1))) + return gate_value * input_a + (1 - gate_value) * input_b \ No newline at end of file diff --git a/HybridViT/module/component/common/mae_posembed.py b/HybridViT/module/component/common/mae_posembed.py new file mode 100644 index 0000000000000000000000000000000000000000..187ecd981c59df11794df9d5be8c02d8a37e04d9 --- /dev/null +++ b/HybridViT/module/component/common/mae_posembed.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# -------------------------------------------------------- +# Position embedding utils +# -------------------------------------------------------- + +import numpy as np +# -------------------------------------------------------- +# 2D sine-cosine position embedding +# References: +# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py +# MoCo v3: https://github.com/facebookresearch/moco-v3 +# -------------------------------------------------------- + +def get_2d_sincos_pos_embed(embed_dim, grid_size_H, grid_size_W, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size_H, dtype=np.float32) + grid_w = np.arange(grid_size_W, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size_H, grid_size_W]) + + print('new grid shape', grid.shape) + + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2. + omega = 1. 
class Maxout(nn.Module):
    """Maxout activation over the trailing dimension.

    Splits the last dimension into consecutive pools of ``pool_size``
    elements and keeps only the maximum of each pool, shrinking the last
    dimension by a factor of ``pool_size``.
    """

    def __init__(self, pool_size):
        """
        Args:
            pool_size (int): Number of elements per pool.
        """
        super().__init__()
        self.pool_size = pool_size

    def forward(self, x):
        # (..., D) -> (..., D // pool_size, pool_size), then max per pool.
        *lead, last = x.shape
        pools = x.view(*lead, last // self.pool_size, self.pool_size)
        return pools.max(dim=-1).values
+ n_height (int): Max height of the 2D feature output. + n_width (int): Max width of the 2D feature output. + dropout (int): Size of hidden layers of the model. + """ + + def __init__(self, + d_hid=512, + n_height=100, + n_width=100, + dropout=0.1, + ): + super().__init__() + + h_position_encoder = self._get_sinusoid_encoding_table(n_height, d_hid) + h_position_encoder = h_position_encoder.transpose(0, 1) + h_position_encoder = h_position_encoder.view(1, d_hid, n_height, 1) + + w_position_encoder = self._get_sinusoid_encoding_table(n_width, d_hid) + w_position_encoder = w_position_encoder.transpose(0, 1) + w_position_encoder = w_position_encoder.view(1, d_hid, 1, n_width) + + self.register_buffer('h_position_encoder', h_position_encoder) + self.register_buffer('w_position_encoder', w_position_encoder) + + self.h_scale = self.scale_factor_generate(d_hid) + self.w_scale = self.scale_factor_generate(d_hid) + self.pool = nn.AdaptiveAvgPool2d(1) + self.dropout = nn.Dropout(p=dropout) + + def _get_sinusoid_encoding_table(self, n_position, d_hid): + """Sinusoid position encoding table.""" + denominator = torch.Tensor([ + 1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid) + for hid_j in range(d_hid) + ]) + denominator = denominator.view(1, -1) + pos_tensor = torch.arange(n_position).unsqueeze(-1).float() + sinusoid_table = pos_tensor * denominator + sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2]) + sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2]) + + return sinusoid_table + + def scale_factor_generate(self, d_hid): + scale_factor = nn.Sequential( + nn.Conv2d(d_hid, d_hid, kernel_size=1), nn.ReLU(inplace=True), + nn.Conv2d(d_hid, d_hid, kernel_size=1), nn.Sigmoid()) + + return scale_factor + + def init_weight(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('ReLU')) + + def forward(self, x): + b, c, h, w = x.size() + + avg_pool = self.pool(x) + + h_pos_encoding = \ + 
self.h_scale(avg_pool) * self.h_position_encoder[:, :, :h, :] + w_pos_encoding = \ + self.w_scale(avg_pool) * self.w_position_encoder[:, :, :, :w] + + out = x + h_pos_encoding + w_pos_encoding + + out = self.dropout(out) + + return out + +class PositionalEncoding2D(nn.Module): + """2-D positional encodings for the feature maps produced by the encoder. + Following https://arxiv.org/abs/2103.06450 by Sumeet Singh. + Reference: + https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2021-labs/blob/main/lab9/text_recognizer/models/transformer_util.py + """ + + def __init__(self, d_model: int, max_h: int = 2000, max_w: int = 2000) -> None: + super().__init__() + self.d_model = d_model + assert d_model % 2 == 0, f"Embedding depth {d_model} is not even" + pe = self.make_pe(d_model, max_h, max_w) # (d_model, max_h, max_w) + self.register_buffer("pe", pe) + + @staticmethod + def make_pe(d_model: int, max_h: int, max_w: int) -> Tensor: + """Compute positional encoding.""" + pe_h = PositionalEncoding1D.make_pe(d_model=d_model // 2, max_len=max_h) # (max_h, 1 d_model // 2) + pe_h = pe_h.permute(2, 0, 1).expand(-1, -1, max_w) # (d_model // 2, max_h, max_w) + + pe_w = PositionalEncoding1D.make_pe(d_model=d_model // 2, max_len=max_w) # (max_w, 1, d_model // 2) + pe_w = pe_w.permute(2, 1, 0).expand(-1, max_h, -1) # (d_model // 2, max_h, max_w) + + pe = torch.cat([pe_h, pe_w], dim=0) # (d_model, max_h, max_w) + return pe + + def forward(self, x: Tensor) -> Tensor: + """Forward pass. 
+ Args: + x: (B, d_model, H, W) + Returns: + (B, d_model, H, W) + """ + assert x.shape[1] == self.pe.shape[0] # type: ignore + x = x + self.pe[:, : x.size(2), : x.size(3)] # type: ignore + return x + + +class PositionalEncoding1D(nn.Module): + """Classic Attention-is-all-you-need positional encoding.""" + + def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1000) -> None: + super().__init__() + self.dropout = nn.Dropout(p=dropout) + pe = self.make_pe(d_model, max_len) # (max_len, 1, d_model) + self.register_buffer("pe", pe) + + @staticmethod + def make_pe(d_model: int, max_len: int) -> Tensor: + """Compute positional encoding.""" + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(1) + return pe + + def forward(self, x: Tensor) -> Tensor: + """Forward pass. 
Size_ = Tuple[int, int]


class PosConv(nn.Module):
    """Conditional positional encoding via a depthwise 3x3 conv.

    PEG from https://arxiv.org/abs/2102.10882. The class token is detached,
    the patch tokens are folded back to a 2D map, convolved, and (for
    stride 1) added residually before re-flattening.
    """

    def __init__(self, in_chans, embed_dim=768, stride=1):
        super().__init__()
        self.proj = nn.Sequential(nn.Conv2d(in_chans, embed_dim, 3, stride, 1, bias=True, groups=embed_dim), )
        self.stride = stride

    def forward(self, x, size: Size_):
        batch, _, channels = x.shape
        cls_tok, patches = x[:, 0], x[:, 1:]
        # Tokens back to a BCHW map so the conv can see spatial layout.
        feat_2d = patches.transpose(1, 2).view(batch, channels, *size)
        out = self.proj(feat_2d)
        if self.stride == 1:
            out += feat_2d  # residual keeps original features at stride 1
        out = out.flatten(2).transpose(1, 2)
        return torch.cat((cls_tok.unsqueeze(1), out), dim=1)

    def no_weight_decay(self):
        return ['proj.%d.weight' % i for i in range(4)]


class PosConv1D(nn.Module):
    """1D variant of the PEG positional encoding (depthwise Conv1d over tokens)."""

    def __init__(self, in_chans, embed_dim=768, stride=1):
        super().__init__()
        self.proj = nn.Sequential(nn.Conv1d(in_chans, embed_dim, 3, stride, 1, bias=True, groups=embed_dim), )
        self.stride = stride

    def forward(self, x, size: int):
        batch, _, channels = x.shape
        cls_tok, patches = x[:, 0], x[:, 1:]
        feat_1d = patches.transpose(1, 2).view(batch, channels, size)
        out = self.proj(feat_1d)
        if self.stride == 1:
            out += feat_1d  # residual keeps original features at stride 1
        out = out.transpose(1, 2)
        return torch.cat((cls_tok.unsqueeze(1), out), dim=1)

    def no_weight_decay(self):
        return ['proj.%d.weight' % i for i in range(4)]
Adapted from + # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 + + print('Resized position embedding: %s to %s'%(posemb.shape, posemb_new.shape)) + ntok_new = posemb_new.shape[1] + + if num_tokens: + posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] + ntok_new -= num_tokens + else: + posemb_tok, posemb_grid = posemb[:, :0], posemb[0] + + if not len(gs_new): # backwards compatibility + gs_new = [int(math.sqrt(ntok_new))] * 2 + + assert len(gs_new) >= 2 + + print('Position embedding grid-size from %s to %s'%(old_grid_shape, gs_new)) + posemb_grid = posemb_grid.reshape(1, old_grid_shape[0], old_grid_shape[1], -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False) + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb diff --git a/HybridViT/module/component/feature_extractor/__init__.py b/HybridViT/module/component/feature_extractor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/HybridViT/module/component/feature_extractor/addon_module/__init__.py b/HybridViT/module/component/feature_extractor/addon_module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d5809d2e26a7c602ccdbb5d41379a5e15d395009 --- /dev/null +++ b/HybridViT/module/component/feature_extractor/addon_module/__init__.py @@ -0,0 +1,2 @@ +from .aspp import * +from .visual_attention import * \ No newline at end of file diff --git a/HybridViT/module/component/feature_extractor/addon_module/aspp.py b/HybridViT/module/component/feature_extractor/addon_module/aspp.py new file mode 100644 index 0000000000000000000000000000000000000000..d3d333e8c4064a8f31007ad6b0d125e3c277bfad --- /dev/null +++ 
class ASPPModule(nn.Module):
    """One atrous-convolution branch of an ASPP head: dilated conv then ReLU."""

    def __init__(self, inplanes, planes, kernel_size, padding, dilation):
        super().__init__()
        # Dilation enlarges the receptive field without extra parameters.
        self.atrous_conv = nn.Conv2d(
            inplanes,
            planes,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            dilation=dilation,
            bias=False,
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):  # skipcq: PYL-W0221
        return self.relu(self.atrous_conv(x))
return self.dropout(x) \ No newline at end of file diff --git a/HybridViT/module/component/feature_extractor/addon_module/visual_attention.py b/HybridViT/module/component/feature_extractor/addon_module/visual_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..16e4ed2c0e294a4794c75d53a85c2523de40901d --- /dev/null +++ b/HybridViT/module/component/feature_extractor/addon_module/visual_attention.py @@ -0,0 +1,325 @@ +import torch.nn as nn +import torch +from torch.nn import functional as F +from functools import reduce + +__all__ = ['Adaptive_Global_Model', 'GlobalContext', 'SELayer', 'SKBlock', 'CBAM'] + + +def constant_init(module, val, bias=0): + if hasattr(module, 'weight') and module.weight is not None: + nn.init.constant_(module.weight, val) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +def xavier_init(module, gain=1, bias=0, distribution='normal'): + assert distribution in ['uniform', 'normal'] + if distribution == 'uniform': + nn.init.xavier_uniform_(module.weight, gain=gain) + else: + nn.init.xavier_normal_(module.weight, gain=gain) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +def normal_init(module, mean=0, std=1, bias=0): + nn.init.normal_(module.weight, mean, std) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +def uniform_init(module, a=0, b=1, bias=0): + nn.init.uniform_(module.weight, a, b) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +def kaiming_init(module, + a=0, + mode='fan_out', + nonlinearity='relu', + bias=0, + distribution='normal'): + assert distribution in ['uniform', 'normal'] + if distribution == 'uniform': + nn.init.kaiming_uniform_( + module.weight, a=a, mode=mode, nonlinearity=nonlinearity) + else: + nn.init.kaiming_normal_( + module.weight, a=a, mode=mode, nonlinearity=nonlinearity) + if 
hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +def last_zero_init(m): + if isinstance(m, nn.Sequential): + constant_init(m[-1], val=0) + else: + constant_init(m, val=0) + + +class SELayer(nn.Module): + def __init__(self, channel, reduction=16, dropout=0.1): + super(SELayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), + nn.ReLU(inplace=True), + nn.Dropout(dropout), + nn.Linear(channel // reduction, channel), + nn.Sigmoid() + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x + x * y.expand_as(x) # SE-Residual + + +class BasicConv(nn.Module): + def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, + bn=True, bias=False): + super(BasicConv, self).__init__() + self.out_channels = out_planes + self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, + dilation=dilation, groups=groups, bias=bias) + self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None + self.relu = nn.ReLU() if relu else None + + def forward(self, x): + x = self.conv(x) + if self.bn is not None: + x = self.bn(x) + if self.relu is not None: + x = self.relu(x) + return x + + +class Flatten(nn.Module): + def forward(self, x): + return x.view(x.size(0), -1) + + +class ChannelGate(nn.Module): + def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): + super(ChannelGate, self).__init__() + self.gate_channels = gate_channels + self.mlp = nn.Sequential( + Flatten(), + nn.Linear(gate_channels, gate_channels // reduction_ratio), + nn.ReLU(), + nn.Linear(gate_channels // reduction_ratio, gate_channels) + ) + self.pool_types = pool_types + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.maxpool = nn.AdaptiveMaxPool2d(1) + self.sigmoid = 
nn.Sigmoid() + + def forward(self, x): + channel_att_sum = None + for pool_type in self.pool_types: + if pool_type == 'avg': + avg_pool = self.avgpool(x) + channel_att_raw = self.mlp(avg_pool) + elif pool_type == 'max': + max_pool = self.maxpool(x) + channel_att_raw = self.mlp(max_pool) + + if channel_att_sum is None: + channel_att_sum = channel_att_raw + else: + channel_att_sum = channel_att_sum + channel_att_raw + + scale = self.sigmoid(channel_att_sum).unsqueeze(2) \ + .unsqueeze(3).expand_as(x) + return x * scale + + +class ChannelPool(nn.Module): + def forward(self, x): + return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1) \ + .unsqueeze(1)), dim=1) + + +class SpatialGate(nn.Module): + def __init__(self): + super(SpatialGate, self).__init__() + kernel_size = 7 + self.compress = ChannelPool() + self.spatial = BasicConv(2, 1, kernel_size, stride=1 \ + , padding=(kernel_size - 1) // 2, relu=False) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + x_compress = self.compress(x) + x_out = self.spatial(x_compress) + scale = self.sigmoid(x_out) # broadcasting + return x * scale + + +class CBAM(nn.Module): + def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False): + super(CBAM, self).__init__() + self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) + self.no_spatial = no_spatial + if not no_spatial: + self.SpatialGate = SpatialGate() + + def forward(self, x): + x_out = self.ChannelGate(x) + if not self.no_spatial: + x_out = self.SpatialGate(x_out) + return x_out + + +class SKBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride=1, M=2, r=16, L=32): + super(SKBlock, self).__init__() + d = max(in_channels // r, L) + self.M = M + self.out_channels = out_channels + self.conv = nn.ModuleList() + for i in range(M): + self.conv.append(nn.Sequential( + nn.Conv2d(in_channels, out_channels, 3, stride, padding=1 + i, dilation=1 + i, groups=32, bias=False), + 
nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True))) + self.global_pool = nn.AdaptiveAvgPool2d(1) + self.fc1 = nn.Sequential(nn.Conv2d(out_channels, d, 1, bias=False), + nn.BatchNorm2d(d), + nn.ReLU(inplace=True)) + self.fc2 = nn.Conv2d(d, out_channels * M, 1, 1, bias=False) + self.softmax = nn.Softmax(dim=1) + + def forward(self, input): + batch_size = input.size(0) + output = [] + # the part of split + for i, conv in enumerate(self.conv): + # print(i,conv(input).size()) + output.append(conv(input)) + # the part of fusion + U = reduce(lambda x, y: x + y, output) + s = self.global_pool(U) + z = self.fc1(s) + a_b = self.fc2(z) + a_b = a_b.reshape(batch_size, self.M, self.out_channels, -1) + a_b = self.softmax(a_b) + # the part of selection + a_b = list(a_b.chunk(self.M, dim=1)) # split to a and b + a_b = list(map(lambda x: x.reshape(batch_size, self.out_channels, 1, 1), a_b)) + V = list(map(lambda x, y: x * y, output, a_b)) + V = reduce(lambda x, y: x + y, V) + return V + +def make_divisible(v, divisor=8, min_value=None, round_limit=.9): + min_value = min_value or divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. 
+    if new_v < round_limit * v:
+        new_v += divisor
+    return new_v
+
+class LayerNorm2d(nn.LayerNorm):
+    """ LayerNorm for channels of '2D' spatial BCHW tensors """
+    def __init__(self, num_channels):
+        super().__init__(num_channels)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return F.layer_norm(
+            x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
+
+class ConvMLP(nn.Module):
+    def __init__(self, in_channels, out_channels=None, hidden_channels=None, drop=0.25):
+        super().__init__()
+        out_channels = out_channels or in_channels  # BUGFIX: was `in_channels or out_channels`, which always ignored the caller's out_channels
+        hidden_channels = hidden_channels or in_channels  # BUGFIX: same swapped operands; the hidden_channels argument (e.g. rd_channels) was ignored
+        self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1, bias=True)
+        self.norm = LayerNorm2d(hidden_channels)
+        self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1, bias=True)
+        self.act = nn.ReLU()
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.norm(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        return x
+
+class GlobalContext(nn.Module):
+    def __init__(self,
+                 channel,
+                 use_attn=True,
+                 fuse_add=True,
+                 fuse_scale=False,
+                 rd_ratio=1./8,
+                 rd_channels=None
+                 ):
+        super().__init__()
+        self.use_attn = use_attn
+        self.global_cxt = nn.Conv2d(channel, 1, kernel_size=1, bias=True) if use_attn else nn.AdaptiveAvgPool2d(1)
+
+        if rd_channels is None:
+            rd_channels = make_divisible(channel*rd_ratio, divisor=1, round_limit=0.)
+
+        if fuse_add:
+            self.bottleneck_add = ConvMLP(channel, hidden_channels=rd_channels)
+        else:
+            self.bottleneck_add = None
+        if fuse_scale:
+            self.bottleneck_mul = ConvMLP(channel, hidden_channels=rd_channels)
+        else:
+            self.bottleneck_mul = None
+
+        self.init_weight()
+
+    def init_weight(self):
+        if self.use_attn:
+            nn.init.kaiming_normal_(self.global_cxt.weight, mode='fan_in', nonlinearity='relu')
+        if self.bottleneck_add is not None:
+            nn.init.zeros_(self.bottleneck_add.fc2.weight)
+        if self.bottleneck_mul is not None:
+            nn.init.zeros_(self.bottleneck_mul.fc2.weight)
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        if self.use_attn:
+            attn = self.global_cxt(x).reshape(B, 1, H*W).squeeze(1)
+            attn = F.softmax(attn, dim=-1).unsqueeze(-1) #shape BxH*Wx1
+            query = x.reshape(B, C, H*W) #shape BxCxH*W
+            glob_cxt = torch.bmm(query, attn).unsqueeze(-1)
+        else:
+            glob_cxt = self.global_cxt(x)
+        assert len(glob_cxt.shape) == 4
+
+        # Fuse the global context back into x; both branches compose when enabled.
+        x_fuse = x  # BUGFIX: x_fuse was unbound (NameError) when both fuse branches were disabled
+        if self.bottleneck_mul is not None:
+            x_fuse = x_fuse * torch.sigmoid(self.bottleneck_mul(glob_cxt))  # BUGFIX: torch.sigmoid (F.sigmoid is deprecated)
+        if self.bottleneck_add is not None:
+            x_fuse = x_fuse + self.bottleneck_add(glob_cxt)  # BUGFIX: previously the mul branch silently discarded the add branch's result
+
+        return x_fuse
+
+
+class Adaptive_Global_Model(nn.Module):
+    def __init__(self, inplanes, factor=2, ratio=0.0625, dropout=0.1):
+        super(Adaptive_Global_Model, self).__init__()
+        # b, w, h, c => gc_block (b, w, h, c) =>  => b, w, inplanes
+        self.embedding = nn.Linear(inplanes * factor, inplanes)
+        self.gc_block = GlobalContext(inplanes, rd_ratio=ratio)  # BUGFIX: keyword is rd_ratio; `ratio=` raised TypeError at construction
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        x = self.gc_block(x)  # BCHW => BCHW
+        x = x.permute(0, 3, 1, 2)  # BCHW => BWCH
+        b, w, _, _ = x.shape
+        x = x.contiguous().view(b, w, -1)
+        x = self.embedding(x)  # B W C
+        x = self.dropout(x)
+        return x
+
diff --git a/HybridViT/module/component/feature_extractor/clova_impl/__init__.py b/HybridViT/module/component/feature_extractor/clova_impl/__init__.py
new file mode 100644
index
0000000000000000000000000000000000000000..18f311768a9ed4ca490cac0d96b9adcccd4946ac --- /dev/null +++ b/HybridViT/module/component/feature_extractor/clova_impl/__init__.py @@ -0,0 +1,2 @@ +from .resnet import ResNet_FeatureExtractor +from .vgg import VGG_FeatureExtractor \ No newline at end of file diff --git a/HybridViT/module/component/feature_extractor/clova_impl/resnet.py b/HybridViT/module/component/feature_extractor/clova_impl/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..e32faa3bef418369ada99770c3bbc938fb1c0d8a --- /dev/null +++ b/HybridViT/module/component/feature_extractor/clova_impl/resnet.py @@ -0,0 +1,262 @@ +from typing import Dict +from collections import OrderedDict +import torch +import torch.nn as nn +import torch.nn.functional as F +from ..addon_module.visual_attention import GlobalContext +from .....helper import clean_state_dict + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = self._conv3x3(inplanes, planes) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = self._conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def _conv3x3(self, in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn2.weight) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + +class ResNet(nn.Module): + def __init__(self, input_channel, output_channel, block, layers, with_gcb=True, debug=False, zero_init_last_bn=False): + 
super(ResNet, self).__init__() + self.with_gcb = with_gcb + + self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel] + self.inplanes = int(output_channel / 8) + + self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16), + kernel_size=3, stride=1, padding=1, bias=False) + self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16)) + + self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes, + kernel_size=3, stride=1, padding=1, bias=False) + self.bn0_2 = nn.BatchNorm2d(self.inplanes) + self.relu = nn.ReLU(inplace=True) + + self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) + self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0]) + self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[ + 0], kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(self.output_channel_block[0]) + + self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) + self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1) + self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[ + 1], kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(self.output_channel_block[1]) + + self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1)) + self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1) + self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[ + 2], kernel_size=3, stride=1, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(self.output_channel_block[2]) + + self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1) + + self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ + 3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False) + self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3]) + + self.conv4_2 = 
nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ + 3], kernel_size=2, stride=1, padding=0, bias=False) + self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3]) + + self.init_weights(zero_init_last_bn=zero_init_last_bn) + self.debug = debug + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn4_2.weight) + + def init_weights(self, zero_init_last_bn=True): + initialized = ['global_cxt', 'bottleneck_add', 'bottleneck_mul'] + for n, m in self.named_modules(): + if any([d in n for d in initialized]): + continue + elif isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + if zero_init_last_bn: + for m in self.modules(): + if hasattr(m, 'zero_init_last_bn'): + m.zero_init_last_bn() + + def _make_layer(self, block, planes, blocks, with_gcb=False, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + if self.with_gcb: + layers.append(GlobalContext(planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + if self.debug: + print('input shape', x.shape) + + x = self.conv0_1(x) + x = self.bn0_1(x) + x = self.relu(x) + + if self.debug: + print('conv1 shape', x.shape) + + x = self.conv0_2(x) + x = self.bn0_2(x) + x = self.relu(x) + + if self.debug: + print('conv2 shape', x.shape) + + x = self.maxpool1(x) + + if self.debug: + print('pool1 shape', x.shape) + + x = self.layer1(x) + + if self.debug: + print('block1 shape', x.shape) + + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + + if 
self.debug: + print('conv3 shape', x.shape) + + x = self.maxpool2(x) + + if self.debug: + print('pool2 shape', x.shape) + + x = self.layer2(x) + + if self.debug: + print('block2 shape', x.shape) + + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + + if self.debug: + print('conv4 shape', x.shape) + + x = self.maxpool3(x) + + if self.debug: + print('pool3 shape', x.shape) + + x = self.layer3(x) + + if self.debug: + print('block3 shape', x.shape) + + x = self.conv3(x) + x = self.bn3(x) + x = self.relu(x) + + if self.debug: + print('conv5 shape', x.shape) + + x = self.layer4(x) + + if self.debug: + print('block4 shape', x.shape) + + x = self.conv4_1(x) + x = self.bn4_1(x) + x = self.relu(x) + + if self.debug: + print('conv6 shape', x.shape) + + x = self.conv4_2(x) + x = self.bn4_2(x) + x = self.relu(x) + + if self.debug: + print('conv7 shape', x.shape) + + return x + +class ResNet_FeatureExtractor(nn.Module): + """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """ + + def __init__(self, input_channel=3, output_channel=512, gcb=False, pretrained=False, weight_dir=None, debug=False): + super(ResNet_FeatureExtractor, self).__init__() + self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3], gcb, debug) + self.in_chans = input_channel + if pretrained: + assert weight_dir is not None + self.load_pretrained(weight_dir) + + def forward(self, input): + output = self.ConvNet(input) + return output + + def load_pretrained(self, weight_dir): + state_dict: OrderedDict = torch.load(weight_dir) + cleaned_state_dict = clean_state_dict(state_dict) + new_state_dict = OrderedDict() + name: str + param: torch.FloatTensor + for name, param in cleaned_state_dict.items(): + if name.startswith('FeatureExtraction'): + output_name = name.replace('FeatureExtraction.', '') + if output_name == 'ConvNet.conv0_1.weight': + print('Old', param.shape) + new_param = param.repeat(1, 
self.in_chans, 1, 1) + print('New', new_param.shape) + else: new_param = param + new_state_dict[output_name] = new_param + print("=> Loading pretrained weight for ResNet backbone") + self.load_state_dict(new_state_dict) + +if __name__ == '__main__': + model = ResNet_FeatureExtractor(input_channel=1, debug=True) + a = torch.rand(1, 1, 128, 480) + output = model(a) + print(output.shape) \ No newline at end of file diff --git a/HybridViT/module/component/feature_extractor/clova_impl/vgg.py b/HybridViT/module/component/feature_extractor/clova_impl/vgg.py new file mode 100644 index 0000000000000000000000000000000000000000..beb924bd0d407a0e181fdfb5c9d2231a6dddfa48 --- /dev/null +++ b/HybridViT/module/component/feature_extractor/clova_impl/vgg.py @@ -0,0 +1,27 @@ +import torch.nn as nn +import torch.nn.functional as F + +class VGG_FeatureExtractor(nn.Module): + """ FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf) """ + + def __init__(self, input_channel, output_channel=512): + super(VGG_FeatureExtractor, self).__init__() + self.output_channel = [int(output_channel / 8), int(output_channel / 4), + int(output_channel / 2), output_channel] # [64, 128, 256, 512] + self.ConvNet = nn.Sequential( + nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True), + nn.MaxPool2d(2, 2), # 64x16x50 + nn.Conv2d(self.output_channel[0], self.output_channel[1], 3, 1, 1), nn.ReLU(True), + nn.MaxPool2d(2, 2), # 128x8x25 + nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1, 1), nn.ReLU(True), # 256x8x25 + nn.Conv2d(self.output_channel[2], self.output_channel[2], 3, 1, 1), nn.ReLU(True), + nn.MaxPool2d((2, 1), (2, 1)), # 256x4x25 + nn.Conv2d(self.output_channel[2], self.output_channel[3], 3, 1, 1, bias=False), + nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True), # 512x4x25 + nn.Conv2d(self.output_channel[3], self.output_channel[3], 3, 1, 1, bias=False), + nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True), + nn.MaxPool2d((2, 1), (2, 1)), # 512x2x25 
+ nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0), nn.ReLU(True)) # 512x1x24 + + def forward(self, input): + return self.ConvNet(input) diff --git a/HybridViT/module/component/feature_extractor/helpers.py b/HybridViT/module/component/feature_extractor/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..f23da247054997e2d37899ce4133346b68d53933 --- /dev/null +++ b/HybridViT/module/component/feature_extractor/helpers.py @@ -0,0 +1,76 @@ +import math +from typing import List, Tuple + +import torch.nn.functional as F + +# Calculate symmetric padding for a convolution +def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding + + +# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution +def get_same_padding(x: int, k: int, s: int, d: int): + return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0) + + +# Can SAME padding for given args be done statically? 
+def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): + return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 + + +# Dynamically pad input x with 'SAME' padding for conv with specified args +def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0): + ih, iw = x.size()[-2:] + pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1]) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value) + return x + + +def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: + dynamic = False + if isinstance(padding, str): + # for any string padding, the padding will be calculated for you, one of three ways + padding = padding.lower() + if padding == 'same': + # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact + if is_static_pad(kernel_size, **kwargs): + # static case, no extra overhead + padding = get_padding(kernel_size, **kwargs) + else: + # dynamic 'SAME' padding, has runtime/GPU memory overhead + padding = 0 + dynamic = True + elif padding == 'valid': + # 'VALID' padding, same as padding=0 + padding = 0 + else: + # Default to PyTorch style 'same'-ish symmetric padding + padding = get_padding(kernel_size, **kwargs) + return padding, dynamic + + +def adapt_input_conv(in_chans, conv_weight): + conv_type = conv_weight.dtype + conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU + O, I, J, K = conv_weight.shape + if in_chans == 1: + if I > 3: + assert conv_weight.shape[1] % 3 == 0 + # For models with space2depth stems + conv_weight = conv_weight.reshape(O, I // 3, 3, J, K) + conv_weight = conv_weight.sum(dim=2, keepdim=False) + else: + conv_weight = conv_weight.sum(dim=1, keepdim=True) + elif in_chans != 3: + if I != 3: + raise NotImplementedError('Weight format not supported by conversion.') + else: + # NOTE this 
strategy should be better than random init, but there could be other combinations of + # the original RGB input layer weights that'd work better for specific cases. + repeat = int(math.ceil(in_chans / 3)) + conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] + conv_weight *= (3 / float(in_chans)) + conv_weight = conv_weight.to(conv_type) + return conv_weight \ No newline at end of file diff --git a/HybridViT/module/component/feature_extractor/vgg.py b/HybridViT/module/component/feature_extractor/vgg.py new file mode 100644 index 0000000000000000000000000000000000000000..980dac3ad05dc48656e0eaac54606f1be7efd445 --- /dev/null +++ b/HybridViT/module/component/feature_extractor/vgg.py @@ -0,0 +1,96 @@ +import torch +import torch.nn as nn + +__all__ = ['vgg11_bn', 'vgg13_bn', 'vgg16_bn', 'vgg19_bn'] + + +class VGG(nn.Module): + + def __init__(self, features, num_channel_out=512, init_weights=True): + super(VGG, self).__init__() + self.features = features + self.num_out_features = 512 + + self.lastlayer = nn.Sequential( + nn.Conv2d(self.num_out_features, num_channel_out, kernel_size=1, stride=1, padding=0, groups=32, bias=False), + nn.BatchNorm2d(num_channel_out), + nn.ReLU(inplace=True), + ) + + if init_weights: + self._initialize_weights() + + def forward(self, x): + x = self.features(x) + x = self.lastlayer(x) + return x + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0) + + +def make_layers(cfg, down_sample=8, batch_norm=False): + layers = [] + in_channels = 3 + for v in cfg: + if v == 'M': + layers += [nn.MaxPool2d(kernel_size=2, stride=2)] + elif isinstance(v, dict): + cur_size = 
v[down_sample] + layers += [nn.MaxPool2d(kernel_size=cur_size, stride=cur_size)] + else: + conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) + if batch_norm: + layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] + else: + layers += [conv2d, nn.ReLU(inplace=True)] + in_channels = v + return nn.Sequential(*layers) + + +cfgs = { + 'A': [64, 'M', 128, 'M', 256, 256, {4: (2, 1), 8: (2, 2)}, 512, 512, {4: (2, 1), 8: (2, 1)}, 512, 512, + {4: (2, 1), 8: (2, 1)}], + 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, {4: (2, 1), 8: (2, 2)}, 512, 512, {4: (2, 1), 8: (2, 1)}, 512, 512, + {4: (2, 1), 8: (2, 1)}, ], + 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, {4: (2, 1), 8: (2, 2)}, 512, 512, 512, {4: (2, 1), 8: (2, 1)}, 512, + 512, 512, {4: (2, 1), 8: (2, 1)}, ], + 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, {4: (2, 1), 8: (2, 2)}, 512, 512, 512, 512, + {4: (2, 1), 8: (2, 1)}, 512, 512, 512, 512, {4: (2, 1), 8: (2, 1)}, ], +} + + +def _vgg(model_path, cfg, batch_norm, pretrained, progress, num_channel_out, down_sample, **kwargs): + if pretrained: + kwargs['init_weights'] = False + model = VGG(make_layers(cfgs[cfg], down_sample, batch_norm=batch_norm), num_channel_out, **kwargs) + if model_path and pretrained: + state_dict = torch.load(model_path) + model.load_state_dict(state_dict, strict=False) + return model + + +def vgg11_bn(model_path='', num_channel_out=512, down_sample=8, pretrained=True, progress=True, **kwargs): + return _vgg(model_path, 'A', True, pretrained, progress, num_channel_out, down_sample, **kwargs) + + +def vgg13_bn(model_path='', num_channel_out=512, down_sample=8, pretrained=True, progress=True, **kwargs): + return _vgg(model_path, 'B', True, pretrained, progress, num_channel_out, down_sample, **kwargs) + + +def vgg16_bn(model_path='', num_channel_out=512, down_sample=8, pretrained=True, progress=True, **kwargs): + return _vgg(model_path, 'D', True, pretrained, progress, num_channel_out, down_sample, **kwargs) + + +def 
vgg19_bn(model_path='', num_channel_out=512, down_sample=8, pretrained=True, progress=True, **kwargs):
+    return _vgg(model_path, 'E', True, pretrained, progress, num_channel_out, down_sample, **kwargs)
\ No newline at end of file
diff --git a/HybridViT/module/component/prediction_head/__init__.py b/HybridViT/module/component/prediction_head/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..48614426ef60537fb63eaf08cc105c003f79aea1
--- /dev/null
+++ b/HybridViT/module/component/prediction_head/__init__.py
@@ -0,0 +1,5 @@
+from .seq2seq import Attention
+from .seq2seq_v2 import AttentionV2
+from .tfm import TransformerPrediction
+
+__all__ = ['Attention', 'AttentionV2', 'TransformerPrediction']
\ No newline at end of file
diff --git a/HybridViT/module/component/prediction_head/addon_module/__init__.py b/HybridViT/module/component/prediction_head/addon_module/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..37e749bbfd6d410ecac27a1f772961e572ed2ec1
--- /dev/null
+++ b/HybridViT/module/component/prediction_head/addon_module/__init__.py
@@ -0,0 +1,3 @@
+from .attention1D import *
+from .attention2D import *
+from .position_encoding import *
\ No newline at end of file
diff --git a/HybridViT/module/component/prediction_head/addon_module/attention1D.py b/HybridViT/module/component/prediction_head/addon_module/attention1D.py
new file mode 100644
index 0000000000000000000000000000000000000000..df74822445875c572604f0c62060a7e601b5b4a4
--- /dev/null
+++ b/HybridViT/module/component/prediction_head/addon_module/attention1D.py
@@ -0,0 +1,218 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# torch.autograd.set_detect_anomaly(True)  # BUGFIX: debug-only leftover; enabling anomaly detection globally at import slows every backward pass
+
+class LuongAttention(nn.Module):
+    def __init__(self, input_size, hidden_size, num_embeddings, num_classes, method='dot'):
+        super(LuongAttention, self).__init__()
+        self.attn = LuongAttentionCell(hidden_size, method)
+        self.rnn = nn.LSTMCell(num_embeddings,
hidden_size) + self.hidden_size = hidden_size + self.generator = nn.Linear(2*hidden_size, num_classes) + + def forward(self, prev_hidden, batch_H, embed_text): + hidden = self.rnn(embed_text, prev_hidden) + + e = self.attn(hidden[0], batch_H) + # print('Shape e', e.shape) + alpha = F.softmax(e, dim=1) + # print('Shape al', alpha.shape) + + context = torch.bmm(alpha.unsqueeze(1), batch_H).squeeze(1) # batch_size x num_channel + output = torch.cat([context, hidden[0]], 1) # batch_size x (num_channel + num_embedding) + output = torch.tanh(output) + output = self.generator(output) + + return output, hidden, alpha + +class LuongAttentionCell(nn.Module): + def __init__(self, hidden_size, method='dot'): + super(LuongAttentionCell, self).__init__() + self.method = method + self.hidden_size = hidden_size + + # Defining the layers/weights required depending on alignment scoring method + if method == "general": + self.fc = nn.Linear(hidden_size, hidden_size, bias=False) + + elif method == "concat": + self.fc = nn.Linear(hidden_size, hidden_size, bias=False) + self.weight = nn.Parameter(torch.FloatTensor(1, hidden_size)) + + def forward(self, decoder_hidden, encoder_outputs): + decoder_hidden = decoder_hidden.unsqueeze(1) + # print('shape', decoder_hidden.shape) + + if self.method == "dot": + # For the dot scoring method, no weights or linear layers are involved + return encoder_outputs.bmm(decoder_hidden.permute(0, 2, 1)).squeeze(-1) + + elif self.method == "general": + # For general scoring, decoder hidden state is passed through linear layers to introduce a weight matrix + out = self.fc(decoder_hidden) + return encoder_outputs.bmm(out.permute(0, 2 , 1)).squeeze(-1) + + elif self.method == "concat": + # For concat scoring, decoder hidden state and encoder outputs are concatenated first + out = torch.tanh(self.fc(decoder_hidden+encoder_outputs)) + # print('Shape', out.shape) + return out.bmm(self.weight.unsqueeze(-1).repeat(out.shape[0], 1, 1)).squeeze(-1) + +class 
BahdanauAttentionCell(nn.Module): + def __init__(self, input_dim, hidden_dim): + super(BahdanauAttentionCell, self).__init__() + self.i2h = nn.Linear(input_dim, hidden_dim, bias=False) + self.h2h = nn.Linear(hidden_dim, hidden_dim) + self.score = nn.Linear(hidden_dim, 1, bias=False) + + def forward(self, decoder_hidden, encoder_output): + encoder_proj = self.i2h(encoder_output) + hidden_proj = self.h2h(decoder_hidden[0]).unsqueeze(1) + score = self.score(torch.tanh(encoder_proj + hidden_proj)) + return score + +class BahdanauAttention(nn.Module): + def __init__(self, input_size=100, hidden_size=256, num_embeddings=10, num_classes=10): + super(BahdanauAttention, self).__init__() + self.attn = BahdanauAttentionCell(input_size, hidden_size) + self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size) + self.input_size = input_size + self.hidden_size = hidden_size + self.generator = nn.Linear(hidden_size, num_classes) + + def set_mem(self, prev_attn): + pass + + def reset_mem(self): + pass + + def forward(self, prev_hidden, batch_H, embed_text): + # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size] + e = self.attn(prev_hidden, batch_H) + alpha = F.softmax(e, dim=1) + context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1) # batch_size x num_channel + concat_context = torch.cat([context, embed_text], 1) # batch_size x (num_channel + num_embedding) + cur_hidden = self.rnn(concat_context, prev_hidden) + output = self.generator(cur_hidden[0]) + + return output, cur_hidden, alpha + +class ConstituentAttentionCell(nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + +class ConstituentCoverageAttentionCell(ConstituentAttentionCell): + pass + +class LocationAwareAttentionCell(nn.Module): + def __init__(self, kernel_size, kernel_dim, hidden_dim, input_dim): + super().__init__() + self.loc_conv = nn.Conv1d(1, kernel_dim, kernel_size=2*kernel_size+1, padding=kernel_size, 
bias=True) + self.loc_proj = nn.Linear(kernel_dim, hidden_dim) + self.query_proj = nn.Linear(hidden_dim, hidden_dim) + self.key_proj = nn.Linear(input_dim, hidden_dim) + self.score = nn.Linear(hidden_dim, 1) + + def forward(self, decoder_hidden, encoder_output, last_alignment): + batch_size, seq_length, hidden_dim = encoder_output.shape[0], encoder_output.shape[1], decoder_hidden[0].shape[1] + + encoder_proj = self.key_proj(encoder_output) + hidden_proj = self.query_proj(decoder_hidden[0]).unsqueeze(1) + + if last_alignment is None: + last_alignment = decoder_hidden[0].new_zeros(batch_size, seq_length, 1) + + loc_context = self.loc_conv(last_alignment.permute(0, 2, 1)) + loc_context = loc_context.transpose(1, 2) + loc_context = self.loc_proj(loc_context) + + assert len(loc_context.shape) == 3 + assert loc_context.shape[0] == batch_size, f'{loc_context.shape[0]}-{batch_size}' + assert loc_context.shape[1] == seq_length + assert loc_context.shape[2] == hidden_dim + + score = self.score(torch.tanh( + encoder_proj + + hidden_proj + + loc_context + )) + return score + +class CoverageAttention(nn.Module): + def __init__(self, input_dim, hidden_dim, kernel_size, kernel_dim, temperature=1.0, smoothing=False): + super().__init__() + self.smoothing = smoothing + self.temperature = temperature + self.prev_attn = None + self.attn = LocationAwareAttentionCell(kernel_size, kernel_dim, hidden_dim, input_dim) + + def set_mem(self, prev_attn): + self.prev_attn = prev_attn + + def reset_mem(self): + self.prev_attn = None + + def forward(self, prev_hidden, batch_H): + e = self.attn(prev_hidden, batch_H, self.prev_attn) + + if self.smoothing: + e = F.sigmoid(e, dim=1) + alpha = e.div(e.sum(dim=-1).unsqueeze(dim=-1)) + else: + e = e / self.temperature + alpha = F.softmax(e, dim=1) + + context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1) # batch_size x num_channel + + return context, alpha + +class LocationAwareAttention(BahdanauAttention): + def __init__(self, kernel_size, 
kernel_dim, temperature=1.0, smoothing=False, *args, **kwargs): + super().__init__(*args, **kwargs) + self.smoothing = smoothing + self.temperature = temperature + self.prev_attn = None + self.attn = LocationAwareAttentionCell(kernel_size, kernel_dim, self.hidden_size, self.input_size) + + def set_mem(self, prev_attn): + self.prev_attn = prev_attn + + def reset_mem(self): + self.prev_attn = None + + def forward(self, prev_hidden, batch_H, embed_text): + e = self.attn(prev_hidden, batch_H, self.prev_attn) + + if self.smoothing: + e = F.sigmoid(e, dim=1) + alpha = e.div(e.sum(dim=-1).unsqueeze(dim=-1)) + else: + e = e / self.temperature + alpha = F.softmax(e, dim=1) + + context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1) # batch_size x num_channel , batch_H: batch_sizexseq_lengthxnum_channel, alpha: + concat_context = torch.cat([context, embed_text], 1) # batch_size x (num_channel + num_embedding) + cur_hidden = self.rnn(concat_context, prev_hidden) + output = self.generator(cur_hidden[0]) + + return output, cur_hidden, alpha + + +# class MaskAttention(nn.Module): +# def __init__(self): +# super().__init__() + +# class CoverageAttentionCell(nn.Module): +# def __init__(self, ) + +# class CoverageAttention(nn.Module): +# """ +# http://home.ustc.edu.cn/~xysszjs/paper/PR2017.pdf +# """ +# def __init__(self, input_size, hidden_size, num_embedding): +# super().__init__() + +# def forward(self, prev_hidden, batch_H, embed_text): diff --git a/HybridViT/module/component/prediction_head/addon_module/attention2D.py b/HybridViT/module/component/prediction_head/addon_module/attention2D.py new file mode 100644 index 0000000000000000000000000000000000000000..0f5c4c005eb1bde8cee73516d61ff5a87bcdbbce --- /dev/null +++ b/HybridViT/module/component/prediction_head/addon_module/attention2D.py @@ -0,0 +1,88 @@ +import torch +import torch.nn as nn +from .attention1D import LocationAwareAttention +""" +NOTE : +""" +class SARAttention(nn.Module): + def __init__(self, + 
input_size, + attention_size, + backbone_size, + output_size, + + ): + self.conv1x1_1 = nn.Conv2d(output_size, attention_size, kernel_size=1, stride=1) + self.conv3x3 = nn.Conv2d(backbone_size, attention_size, kernel_size=3, stride=1, padding=1) + self.conv1x1_2 = nn.Conv2d(attention_size, 1, kernel_size=1, stride=1) + + self.rnn_decoder_1 = nn.LSTMCell(input_size, input_size) + self.rnn_decoder_2 = nn.LSTMCell(input_size, input_size) + + + def forward( + self, + dec_input, + feature_map, + holistic_feature, + hidden_1, + cell_1, + hidden_2, + cell_2 + ): + _, _, H_feat, W_feat = feature_map.size() + hidden_1, cell_1 = self.rnn_decoder_1(dec_input, (hidden_1, cell_1)) + hidden_2, cell_2 = self.rnn_decoder_2(hidden_1, (hidden_2, cell_2)) + + hidden_2_tile = hidden_2.view(hidden_2.size(0), hidden_2.size(1), 1, 1) + attn_query = self.conv1x1_1(hidden_2_tile) + attn_query = attn_query.expand(-1, -1, H_feat, W_feat) + + attn_key = self.conv3x3(feature_map) + attn_weight = torch.tanh(torch.add(attn_query, attn_key, alpha=1)) + attn_weight = self.conv1x1_2(attn_weight) #shape B, 1, H, W + + #TO DO: apply mask for attention weight + + +class LocationAwareAttentionCell2D(nn.Module): + def __init__(self, kernel_size, kernel_dim, hidden_dim, input_dim): + super().__init__() + self.loc_conv = nn.Conv2d(1, kernel_dim, kernel_size=2*kernel_size+1, padding=kernel_size, bias=True) + self.loc_proj = nn.Linear(kernel_dim, hidden_dim) + self.query_proj = nn.Linear(hidden_dim, hidden_dim) + self.key_proj = nn.Linear(input_dim, hidden_dim) + self.score = nn.Linear(hidden_dim, 1) + + def forward(self, decoder_hidden, encoder_output, last_alignment): + batch_size, enc_h, enc_w, hidden_dim = encoder_output.shape[0], encoder_output.shape[1], encoder_output.shape[2], decoder_hidden[0].shape[1] + + encoder_proj = self.key_proj(encoder_output) + hidden_proj = self.query_proj(decoder_hidden[0]).unsqueeze(1) + + if last_alignment is None: + last_alignment = 
decoder_hidden[0].new_zeros(batch_size, enc_h, enc_w, 1) + + loc_context = self.loc_conv(last_alignment.permute(0, 2, 1)) + loc_context = loc_context.transpose(1, 2) + loc_context = self.loc_proj(loc_context) + + assert len(loc_context.shape) == 3 + assert loc_context.shape[0] == batch_size, f'{loc_context.shape[0]}-{batch_size}' + assert loc_context.shape[1] == enc_h + assert loc_context.shape[2] == enc_w + assert loc_context.shape[3] == hidden_dim + + loc_context = loc_context.reshape(batch_size, enc_h*enc_w, hidden_dim) + + score = self.score(torch.tanh( + encoder_proj + + hidden_proj + + loc_context + )) + return score + +class LocationAwareAttention2D(LocationAwareAttention): + def __init__(self, kernel_size, kernel_dim, temperature=1, smoothing=False, *args, **kwargs): + super().__init__(kernel_size, kernel_dim, temperature, smoothing, *args, **kwargs) + self.attn = LocationAwareAttentionCell2D(kernel_size, kernel_dim, self.hidden_size, self.input_size) diff --git a/HybridViT/module/component/prediction_head/addon_module/position_encoding.py b/HybridViT/module/component/prediction_head/addon_module/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..cba525cc89544df367a1a784ffd94d9f941a068e --- /dev/null +++ b/HybridViT/module/component/prediction_head/addon_module/position_encoding.py @@ -0,0 +1,27 @@ +import torch +from torch import nn + +__all__ = ['WordPosEnc'] + +class WordPosEnc(nn.Module): + def __init__( + self, d_model: int = 512, max_len: int = 500, temperature: float = 10000.0 + ) -> None: + super().__init__() + pe = torch.zeros(max_len, d_model) + + position = torch.arange(0, max_len, dtype=torch.float) + dim_t = torch.arange(0, d_model, 2, dtype=torch.float) + div_term = 1.0 / (temperature ** (dim_t / d_model)) + + inv_freq = torch.einsum("i, j -> i j", position, div_term) + + pe[:, 0::2] = inv_freq.sin() + pe[:, 1::2] = inv_freq.cos() + self.register_buffer("pe", pe) + + def forward(self, x: torch.Tensor) -> 
torch.Tensor: + _, seq_len, _ = x.size() + emb = self.pe[:seq_len, :] + x = x + emb[None, :, :] + return x diff --git a/HybridViT/module/component/prediction_head/seq2seq.py b/HybridViT/module/component/prediction_head/seq2seq.py new file mode 100644 index 0000000000000000000000000000000000000000..0920f5035bc50914b3414a6c78302be2b83fb1cd --- /dev/null +++ b/HybridViT/module/component/prediction_head/seq2seq.py @@ -0,0 +1,268 @@ +import random +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import repeat +from ...converter import AttnLabelConverter as ATTN +from .addon_module import * + +class Attention(nn.Module): + def __init__(self, + kernel_size, + kernel_dim, + input_size, + hidden_size, + num_classes, + embed_dim=None, + attn_type='coverage', + embed_target=False, + enc_init=False, #init hidden state of decoder with enc output + teacher_forcing=1.0, + droprate=0.1, + method='concat', + seqmodel='ViT', + viz_attn: bool = False, + device='cuda' + ): + super(Attention, self).__init__() + if embed_dim is None: embed_dim = input_size + if embed_target: + self.embedding = nn.Embedding(num_classes, embed_dim, padding_idx=ATTN.START()) + + common = { + 'input_size': input_size, + 'hidden_size': hidden_size, + 'num_embeddings': embed_dim if embed_target else num_classes, + 'num_classes': num_classes + } + + if attn_type == 'luong': + common['method'] = method + self.attention_cell = LuongAttention(**common) + elif attn_type == 'loc_aware': + self.attention_cell = LocationAwareAttention(kernel_size=kernel_size, kernel_dim=kernel_dim, **common) + elif attn_type == 'coverage': + self.attention_cell = LocationAwareAttention(kernel_size=kernel_size, kernel_dim=kernel_dim, **common) + else: + self.attention_cell = BahdanauAttention(**common) + + self.dropout = nn.Dropout(droprate) + self.embed_target = embed_target + self.hidden_size = hidden_size + self.input_size = input_size + self.num_classes = num_classes + self.teacher_forcing = 
teacher_forcing + self.device = device + self.attn_type = attn_type + self.enc_init = enc_init + self.viz_attn = viz_attn + self.seqmodel = seqmodel + + if enc_init: self.init_hidden() + + def _embed_text(self, input_char): + return self.embedding(input_char) + + def _char_to_onehot(self, input_char, onehot_dim=38): + input_char = input_char.unsqueeze(1) + batch_size = input_char.size(0) + one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().to(self.device) + one_hot = one_hot.scatter_(1, input_char, 1) + return one_hot + + def init_hidden(self): + self.proj_init_h = nn.Linear(self.input_size, self.hidden_size, bias=True) + self.proj_init_c = nn.Linear(self.input_size, self.hidden_size, bias=True) + + def forward_beam( + self, + batch_H: torch.Tensor, + batch_max_length=25, + beam_size=4, + ): + batch_size = batch_H.size(0) + assert batch_size == 1 + num_steps = batch_max_length + 1 + batch_H = batch_H.squeeze(dim=0) + batch_H = repeat(batch_H, "s e -> b s e", b = beam_size) + + if self.enc_init: + if self.seqmodel == 'BiLSTM': + init_embedding = batch_H.mean(dim=1) + else: + init_embedding = batch_H[:, 0, :] + h_0 = self.proj_init_h(init_embedding) + c_0 = self.proj_init_c(init_embedding) + hidden = (h_0, c_0) + else: + hidden = (torch.zeros(beam_size, self.hidden_size, dtype=torch.float32, device=self.device), + torch.zeros(beam_size, self.hidden_size, dtype=torch.float32, device=self.device)) + + if self.attn_type == 'coverage': + alpha_cum = torch.zeros(beam_size, batch_H.shape[1], 1, dtype=torch.float32, device=self.device) + self.attention_cell.reset_mem() + + k_prev_words = torch.LongTensor([[ATTN.START()]] * beam_size).to(self.device) + seqs = k_prev_words + targets = k_prev_words.squeeze(dim=-1) + top_k_scores = torch.zeros(beam_size, 1).to(self.device) + + if self.viz_attn: + seqs_alpha = torch.ones(beam_size, 1, batch_H.shape[1]).to(self.device) + + complete_seqs = list() + if self.viz_attn: + complete_seqs_alpha = list() + complete_seqs_scores = 
list() + + for step in range(num_steps): + embed_text = self._char_to_onehot(targets, onehot_dim=self.num_classes) if not self.embed_target else self._embed_text(targets) + output, hidden, alpha = self.attention_cell(hidden, batch_H, embed_text) + output = self.dropout(output) + vocab_size = output.shape[1] + + scores = F.log_softmax(output, dim=-1) + scores = top_k_scores.expand_as(scores) + scores + if step == 0: + top_k_scores, top_k_words = scores[0].topk(beam_size, 0, True, True) + else: + top_k_scores, top_k_words = scores.view(-1).topk(beam_size, 0, True, True) + + prev_word_inds = top_k_words // vocab_size + next_word_inds = top_k_words % vocab_size + + seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) + if self.viz_attn: + seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].permute(0, 2, 1)], + dim=1) + + incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if + next_word != ATTN.END()] + + complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds)) + + if len(complete_inds) > 0: + complete_seqs.extend(seqs[complete_inds].tolist()) + if self.viz_attn: + complete_seqs_alpha.extend(seqs_alpha[complete_inds]) + complete_seqs_scores.extend(top_k_scores[complete_inds]) + + beam_size = beam_size - len(complete_inds) + if beam_size == 0: + break + + seqs = seqs[incomplete_inds] + if self.viz_attn: + seqs_alpha = seqs_alpha[incomplete_inds] + hidden = hidden[0][prev_word_inds[incomplete_inds]], \ + hidden[1][prev_word_inds[incomplete_inds]] + batch_H = batch_H[prev_word_inds[incomplete_inds]] + top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1) + targets = next_word_inds[incomplete_inds] + + if self.attn_type == 'coverage': + alpha_cum = alpha_cum + alpha + alpha_cum = alpha_cum[incomplete_inds] + self.attention_cell.set_mem(alpha_cum) + elif self.attn_type == 'loc_aware': + self.attention_cell.set_mem(alpha) + + if len(complete_inds) == 0: + seq = seqs[0][1:].tolist() + seq 
= torch.LongTensor(seq).unsqueeze(0) + score = top_k_scores[0] + if self.viz_attn: + alphas = seqs_alpha[0][1:, ...] + return seq, score, alphas + else: + return seq, score, None + else: + combine_lst = tuple(zip(complete_seqs, complete_seqs_scores)) + best_ind = combine_lst.index(max(combine_lst, key=lambda x: x[1] / len(x[0]))) #https://youtu.be/XXtpJxZBa2c?t=2407 + seq = complete_seqs[best_ind][1:] #not include [GO] token + seq = torch.LongTensor(seq).unsqueeze(0) + score = max(complete_seqs_scores) + + if self.viz_attn: + alphas = complete_seqs_alpha[best_ind][1:, ...] + return seq, score, alphas + else: + return seq, score, None + + def forward_greedy(self, batch_H, text, is_train=True, is_test=False, batch_max_length=25): + batch_size = batch_H.size(0) + num_steps = batch_max_length + 1 + if self.enc_init: + if self.seqmodel == 'BiLSTM': + init_embedding = batch_H.mean(dim=1) + encoder_hidden = batch_H + else: + encoder_hidden = batch_H + init_embedding = batch_H[:, 0, :] + h_0 = self.proj_init_h(init_embedding) + c_0 = self.proj_init_c(init_embedding) + hidden = (h_0, c_0) + else: + encoder_hidden = batch_H + hidden = (torch.zeros(batch_size, self.hidden_size, dtype=torch.float32, device=self.device), + torch.zeros(batch_size, self.hidden_size, dtype=torch.float32, device=self.device)) + + targets = torch.zeros(batch_size, dtype=torch.long, device=self.device) # [GO] token + probs = torch.zeros(batch_size, num_steps, self.num_classes, dtype=torch.float32, device=self.device) + + if self.viz_attn: + self.alpha_stores = torch.zeros(batch_size, num_steps, encoder_hidden.shape[1], 1, dtype=torch.float32, device=self.device) + if self.attn_type == 'coverage': + alpha_cum = torch.zeros(batch_size, encoder_hidden.shape[1], 1, dtype=torch.float32, device=self.device) + + self.attention_cell.reset_mem() + + if is_test: + end_flag = torch.zeros(batch_size, dtype=torch.bool, device=self.device) + + for i in range(num_steps): + embed_text = self._char_to_onehot(targets, 
onehot_dim=self.num_classes) if not self.embed_target else self._embed_text(targets) + output, hidden, alpha = self.attention_cell(hidden, encoder_hidden, embed_text) + output = self.dropout(output) + if self.viz_attn: + self.alpha_stores[:, i] = alpha + if self.attn_type == 'coverage': + alpha_cum = alpha_cum + alpha + self.attention_cell.set_mem(alpha_cum) + elif self.attn_type == 'loc_aware': + self.attention_cell.set_mem(alpha) + + probs_step = output + probs[:, i, :] = probs_step + + if i == num_steps - 1: + break + + if is_train: + if self.teacher_forcing < random.random(): + _, next_input = probs_step.max(1) + targets = next_input + else: + targets = text[:, i+1] + else: + _, next_input = probs_step.max(1) + targets = next_input + + if is_test: + end_flag = end_flag | (next_input == ATTN.END()) + if end_flag.all(): + break + + _, preds_index = probs.max(2) + + return preds_index, probs, None # batch_size x num_steps x num_classes + + def forward(self, beam_size, batch_H, text, batch_max_length, is_train=True, is_test=False): + if is_train: + return self.forward_greedy(batch_H, text, is_train, is_test, batch_max_length) + else: + if beam_size > 1: + return self.forward_beam(batch_H, batch_max_length, beam_size) + else: + return self.forward_greedy(batch_H, text, is_train, is_test, batch_max_length) + diff --git a/HybridViT/module/component/prediction_head/seq2seq_v2.py b/HybridViT/module/component/prediction_head/seq2seq_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..20eec35c9de355720de2a5a1848f1a858b8485fd --- /dev/null +++ b/HybridViT/module/component/prediction_head/seq2seq_v2.py @@ -0,0 +1,218 @@ +import random +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import repeat +from ...converter import AttnLabelConverter as ATTN +from .addon_module import * +from .seq2seq import Attention + + +class AttentionV2(Attention): + def forward_beam( + self, + batch_H: torch.Tensor, + batch_max_length=25, 
+ beam_size=4, + ): + batch_size = batch_H.size(0) + assert batch_size == 1 + num_steps = batch_max_length + 1 + batch_H = batch_H.squeeze(dim=0) + batch_H = repeat(batch_H, "s e -> b s e", b = beam_size) + + encoder_hidden = None + if self.seqmodel in ['BiLSTM', 'VIG']: + encoder_hidden = batch_H + elif self.seqmodel == 'TFM': + encoder_hidden = batch_H[:, 1:, :] + else: + raise ValueError('seqmodel must be either BiLSTM or TFM option') + + if self.enc_init: + init_embedding = None + if self.seqmodel in ['BiLSTM', 'VIG']: + init_embedding = batch_H.mean(dim=1) + elif self.seqmodel == 'TFM': + init_embedding = batch_H[:, 0, :] + else: + raise ValueError('seqmodel must be either BiLSTM or TFM option') + + assert init_embedding is not None + h_0 = self.proj_init_h(init_embedding) + c_0 = self.proj_init_c(init_embedding) + hidden = (h_0, c_0) + else: + hidden = (torch.zeros(beam_size, self.hidden_size, dtype=torch.float32, device=self.device), + torch.zeros(beam_size, self.hidden_size, dtype=torch.float32, device=self.device)) + + assert encoder_hidden is not None + + if self.attn_type == 'coverage': + alpha_cum = torch.zeros(beam_size, encoder_hidden.shape[1], 1, dtype=torch.float32, device=self.device) + self.attention_cell.reset_mem() + + k_prev_words = torch.LongTensor([[ATTN.START()]] * beam_size).to(self.device) + seqs = k_prev_words + targets = k_prev_words.squeeze(dim=-1) + top_k_scores = torch.zeros(beam_size, 1).to(self.device) + + if self.viz_attn: + seqs_alpha = torch.ones(beam_size, 1, encoder_hidden.shape[1]).to(self.device) + + complete_seqs = list() + if self.viz_attn: + complete_seqs_alpha = list() + complete_seqs_scores = list() + + for step in range(num_steps): + embed_text = self._char_to_onehot(targets, onehot_dim=self.num_classes) if not self.embed_target else self._embed_text(targets) + output, hidden, alpha = self.attention_cell(hidden, encoder_hidden, embed_text) + output = self.dropout(output) + vocab_size = output.shape[1] + + scores = 
F.log_softmax(output, dim=-1) + scores = top_k_scores.expand_as(scores) + scores + if step == 0: + top_k_scores, top_k_words = scores[0].topk(beam_size, 0, True, True) + else: + top_k_scores, top_k_words = scores.view(-1).topk(beam_size, 0, True, True) + + prev_word_inds = top_k_words // vocab_size + next_word_inds = top_k_words % vocab_size + + seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) + if self.viz_attn: + seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].permute(0, 2, 1)], + dim=1) + + incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if + next_word != ATTN.END()] + + complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds)) + + if len(complete_inds) > 0: + complete_seqs.extend(seqs[complete_inds].tolist()) + if self.viz_attn: + complete_seqs_alpha.extend(seqs_alpha[complete_inds]) + complete_seqs_scores.extend(top_k_scores[complete_inds]) + + beam_size = beam_size - len(complete_inds) + if beam_size == 0: + break + + seqs = seqs[incomplete_inds] + if self.viz_attn: + seqs_alpha = seqs_alpha[incomplete_inds] + hidden = hidden[0][prev_word_inds[incomplete_inds]], \ + hidden[1][prev_word_inds[incomplete_inds]] + encoder_hidden = encoder_hidden[prev_word_inds[incomplete_inds]] + top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1) + targets = next_word_inds[incomplete_inds] + + if self.attn_type == 'coverage': + alpha_cum = alpha_cum + alpha + alpha_cum = alpha_cum[incomplete_inds] + self.attention_cell.set_mem(alpha_cum) + elif self.attn_type == 'loc_aware': + self.attention_cell.set_mem(alpha) + + if len(complete_inds) == 0: + seq = seqs[0][1:].tolist() + seq = torch.LongTensor(seq).unsqueeze(0) + score = top_k_scores[0] + if self.viz_attn: + alphas = seqs_alpha[0][1:, ...] 
+ return seq, score, alphas + else: + return seq, score, None + else: + combine_lst = tuple(zip(complete_seqs, complete_seqs_scores)) + best_ind = combine_lst.index(max(combine_lst, key=lambda x: x[1] / len(x[0]))) #https://youtu.be/XXtpJxZBa2c?t=2407 + seq = complete_seqs[best_ind][1:] #not include [GO] token + seq = torch.LongTensor(seq).unsqueeze(0) + score = max(complete_seqs_scores) + + if self.viz_attn: + alphas = complete_seqs_alpha[best_ind][1:, ...] + return seq, score, alphas + else: + return seq, score, None + + def forward_greedy(self, batch_H, text, is_train=True, is_test=False, batch_max_length=25): + batch_size = batch_H.size(0) + num_steps = batch_max_length + 1 + encoder_hidden = None + if self.seqmodel in ['BiLSTM', 'VIG']: + encoder_hidden = batch_H + elif self.seqmodel == 'TFM': + encoder_hidden = batch_H[:, 1:, :] + else: + raise ValueError('seqmodel must be either BiLSTM or TFM option') + + if self.enc_init: + init_embedding = None + if self.seqmodel in ['BiLSTM', 'VIG']: + init_embedding = batch_H.mean(dim=1) + elif self.seqmodel == 'TFM': + init_embedding = batch_H[:, 0, :] + else: + raise ValueError('seqmodel must be either BiLSTM or TFM option') + h_0 = self.proj_init_h(init_embedding) + c_0 = self.proj_init_c(init_embedding) + hidden = (h_0, c_0) + else: + hidden = (torch.zeros(batch_size, self.hidden_size, dtype=torch.float32, device=self.device), + torch.zeros(batch_size, self.hidden_size, dtype=torch.float32, device=self.device)) + + targets = torch.zeros(batch_size, dtype=torch.long, device=self.device) # [GO] token + probs = torch.zeros(batch_size, num_steps, self.num_classes, dtype=torch.float32, device=self.device) + + assert encoder_hidden is not None + + if self.viz_attn: + self.alpha_stores = torch.zeros(batch_size, num_steps, encoder_hidden.shape[1], 1, dtype=torch.float32, device=self.device) + if self.attn_type == 'coverage': + alpha_cum = torch.zeros(batch_size, encoder_hidden.shape[1], 1, dtype=torch.float32, 
device=self.device) + + self.attention_cell.reset_mem() + + if is_test: + end_flag = torch.zeros(batch_size, dtype=torch.bool, device=self.device) + + for i in range(num_steps): + embed_text = self._char_to_onehot(targets, onehot_dim=self.num_classes) if not self.embed_target else self._embed_text(targets) + output, hidden, alpha = self.attention_cell(hidden, encoder_hidden, embed_text) + output = self.dropout(output) + if self.viz_attn: + self.alpha_stores[:, i] = alpha + if self.attn_type == 'coverage': + alpha_cum = alpha_cum + alpha + self.attention_cell.set_mem(alpha_cum) + elif self.attn_type == 'loc_aware': + self.attention_cell.set_mem(alpha) + + probs_step = output + probs[:, i, :] = probs_step + + if i == num_steps - 1: + break + + if is_train: + if self.teacher_forcing < random.random(): + _, next_input = probs_step.max(1) + targets = next_input + else: + targets = text[:, i+1] + else: + _, next_input = probs_step.max(1) + targets = next_input + + if is_test: + end_flag = end_flag | (next_input == ATTN.END()) + if end_flag.all(): + break + + _, preds_index = probs.max(2) + + return preds_index, probs, None # batch_size x num_steps x num_classes diff --git a/HybridViT/module/component/prediction_head/tfm.py b/HybridViT/module/component/prediction_head/tfm.py new file mode 100644 index 0000000000000000000000000000000000000000..e62a5f1d7d4991d61a3ad5d74f16a2ad4e257478 --- /dev/null +++ b/HybridViT/module/component/prediction_head/tfm.py @@ -0,0 +1,207 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from torch import FloatTensor, LongTensor +from .addon_module import WordPosEnc +from ...converter.tfm_converter import TFMLabelConverter as TFM +from ....beam import Beam + +def _build_transformer_decoder( + d_model: int, + nhead: int, + num_decoder_layers: int, + dim_feedforward: int, + dropout: float, +) -> nn.TransformerDecoder: + decoder_layer = nn.TransformerDecoderLayer( + 
d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + ) + + decoder = nn.TransformerDecoder(decoder_layer, num_decoder_layers) + + for p in decoder.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + return decoder + + +class TransformerPrediction(nn.Module): + def __init__( + self, + d_model: int, + nhead: int, + num_decoder_layers: int, + dim_feedforward: int, + dropout: float, + num_classes: int, + max_seq_len: int, + padding_idx: int, + device: str = 'cuda:1' + ): + super().__init__() + self.max_seq_len = max_seq_len + self.padding_idx = padding_idx + self.num_classes = num_classes + self.device = device + self.word_embed = nn.Embedding( + num_classes, d_model, padding_idx=padding_idx + ) + + self.pos_enc = WordPosEnc(d_model=d_model) + self.d_model = d_model + self.model = _build_transformer_decoder( + d_model=d_model, + nhead=nhead, + num_decoder_layers=num_decoder_layers, + dim_feedforward=dim_feedforward, + dropout=dropout, + ) + + self.proj = nn.Linear(d_model, num_classes) + self.beam = Beam( + ignore_w=TFM.PAD(), + start_w=TFM.START(), + stop_w=TFM.END(), + max_len=self.max_seq_len, + device=self.device + ) + + def reset_beam(self): + self.beam = Beam( + ignore_w=TFM.PAD(), + start_w=TFM.START(), + stop_w=TFM.END(), + max_len=self.max_seq_len, + device=self.device + ) + + def _build_attention_mask(self, length): + mask = torch.full( + (length, length), + fill_value=1, + dtype=torch.bool, + device=self.device + ) + mask = torch.triu(mask).transpose(0, 1) + mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) + return mask + + def _embedd_tgt(self, tgt: LongTensor, tgt_len: int): + tgt_mask = self._build_attention_mask(tgt_len) + if self.training: + tgt_pad_mask = tgt == self.padding_idx + else: tgt_pad_mask = None + tgt = self.word_embed(tgt) + tgt = self.pos_enc(tgt*math.sqrt(self.d_model)) + return tgt, tgt_mask, tgt_pad_mask + + def forward_greedy( + self, src: 
FloatTensor, tgt: LongTensor, output_weight: bool = False, is_test: bool = False + ) -> FloatTensor: + if self.training: + _, l = tgt.size() + tgt, tgt_mask, tgt_pad_mask = self._embedd_tgt(tgt, l) + + src = rearrange(src, "b t d -> t b d") + tgt = rearrange(tgt, "b l d -> l b d") + + out = self.model( + tgt=tgt, + memory=src, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_pad_mask + ) + + out = rearrange(out, "l b d -> b l d") + out = self.proj(out) + else: + out = None + src = rearrange(src, "b t d -> t b d") + + end_flag = torch.zeros(src.shape[0], dtype=torch.bool, device=self.device) + + for step in range(self.max_seq_len+1): + b, l = tgt.size() + emb_tgt, tgt_mask, tgt_pad_mask = self._embedd_tgt(tgt, l) + emb_tgt = rearrange(emb_tgt, "b l d -> l b d") + + out = self.model( + tgt=emb_tgt, + memory=src, + tgt_mask=tgt_mask + ) + + out = rearrange(out, "l b d -> b l d") + out = self.proj(out) + probs = F.softmax(out, dim=-1) + next_text = torch.argmax(probs[:, -1:, :], dim=-1) + tgt = torch.cat([tgt, next_text], dim=-1) + + end_flag = end_flag | (next_text == TFM.END()) + if end_flag.all() and is_test: + break + + _, preds_index = out.max(dim=2) + return preds_index, out + + def forward_beam(self, + src: torch.FloatTensor, + beam_size: int + ): + assert src.size(0) == 1, f'beam search should only have signle source, encounter with batch size: {src.size(0)}' + out = None + src = src.squeeze(0) + + for step in range(self.max_seq_len + 1): + hypotheses = self.beam.hypotheses + hyp_num = hypotheses.size(0) + l = hypotheses.size(1) + assert hyp_num <= beam_size, f"hyp_num: {hyp_num}, beam_size: {beam_size}" + + emb_tgt = self.word_embed(hypotheses) + emb_tgt = self.pos_enc(emb_tgt*math.sqrt(self.d_model)) + tgt_mask = self._build_attention_mask(l) + emb_tgt = rearrange(emb_tgt, "b l d -> l b d") + + exp_src = repeat(src.squeeze(1), "s e -> s b e", b=hyp_num) + + out = self.model( + tgt=emb_tgt, + memory=exp_src, + tgt_mask=tgt_mask + ) + + out = rearrange(out, "l b 
d -> b l d") + out = self.proj(out) + log_prob = F.log_softmax(out[:, step, :], dim=-1) + new_hypotheses, new_hyp_scores = self.beam.advance(log_prob, step, beam_size=beam_size) + + if self.beam.done(beam_size): + break + + self.beam.set_current_state(new_hypotheses) + self.beam.set_current_score(new_hyp_scores) + + self.beam.set_hypothesis() + best_hyp = max(self.beam.completed_hypotheses, key=lambda h: h.score / len(h)) + output = best_hyp.seq + output = torch.LongTensor(output).unsqueeze(0) + score = best_hyp.score + + return output, score + + def forward(self, beam_size, batch_H, text, is_test): + if self.training: + return self.forward_greedy(batch_H, text) + else: + if beam_size > 1: + return self.forward_beam(batch_H, beam_size) + else: + return self.forward_greedy(batch_H, text, is_test = is_test) + diff --git a/HybridViT/module/component/seq_modeling/__init__.py b/HybridViT/module/component/seq_modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3f86fbf5890ae0ec641b6dbb61dbfdbfd1c4c98f --- /dev/null +++ b/HybridViT/module/component/seq_modeling/__init__.py @@ -0,0 +1,2 @@ +from .bilstm import * +from .vit_encoder import * diff --git a/HybridViT/module/component/seq_modeling/addon_module/__init__.py b/HybridViT/module/component/seq_modeling/addon_module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e2891afee2178461b0a6b62ba12544dc0222b127 --- /dev/null +++ b/HybridViT/module/component/seq_modeling/addon_module/__init__.py @@ -0,0 +1 @@ +from .patchembed import * diff --git a/HybridViT/module/component/seq_modeling/addon_module/patchembed.py b/HybridViT/module/component/seq_modeling/addon_module/patchembed.py new file mode 100644 index 0000000000000000000000000000000000000000..c35148e12b6522c3ce7de660da4f0d77d501234f --- /dev/null +++ b/HybridViT/module/component/seq_modeling/addon_module/patchembed.py @@ -0,0 +1,161 @@ +import math +import torch.nn as nn +import torch +from torch.nn 
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding.

    The constructor rounds the requested image size UP to the next multiple
    of the patch size; forward() pads incoming images the same way, so inputs
    of any spatial size are accepted.
    """
    def __init__(self, img_size=(224, 224), patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(img_size, tuple)
        patch_size = to_2tuple(patch_size)
        # round each image dimension up to a whole number of patches
        div_h, mod_h = divmod(img_size[0], patch_size[0])
        div_w, mod_w = divmod(img_size[1], patch_size[1])
        self.img_size = (patch_size[0]*(div_h + (1 if mod_h > 0 else 0)),
                         patch_size[1]*(div_w + (1 if mod_w > 0 else 0)))
        self.grid_size = (self.img_size[0] // patch_size[0], self.img_size[1] // patch_size[1])
        self.patch_size = patch_size
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        # non-overlapping conv (kernel == stride == patch) == patchify + linear projection
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Pad (bottom/right) to a patch multiple, then project to embeddings.

        Returns:
            proj_x: (B, num_patches, embed_dim) patch embeddings
            size: dict with the padded 'height'/'width' actually embedded
            bool: True when the padded size differs from the build-time
                  img_size (caller must interpolate positional embeddings)
        """
        _, _, H, W = x.shape
        div_h, mod_h = divmod(H, self.patch_size[0])
        div_w, mod_w = divmod(W, self.patch_size[1])
        pad_H = self.patch_size[0]*(div_h + (1 if mod_h > 0 else 0)) - H
        pad_W = self.patch_size[1]*(div_w + (1 if mod_w > 0 else 0)) - W
        x = F.pad(x, (0, pad_W, 0, pad_H))
        assert x.shape[2] % self.patch_size[0] == 0 and x.shape[3] % self.patch_size[1] == 0
        proj_x = self.proj(x).flatten(2).transpose(1, 2)
        return proj_x, {'height': x.shape[2], 'width': x.shape[3]}, (x.shape[2] != self.img_size[0] or x.shape[3] != self.img_size[1])
+ """ + def __init__(self, backbone, img_size: Tuple[int], patch_size=Union[List, int], feature_size=None, in_chans=3, embed_dim=768): + super().__init__() + assert isinstance(backbone, nn.Module) + if isinstance(patch_size, int): + patch_size = to_2tuple(patch_size) + else: + patch_size = tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + # NOTE Most reliable way of determining output dims is to run forward pass + training = backbone.training + if training: + backbone.eval() + o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1])) + if isinstance(o, (list, tuple)): + o = o[-1] # last feature if backbone outputs list/tuple of features + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + if hasattr(self.backbone, 'feature_info'): + feature_dim = self.backbone.feature_info.channels()[-1] + else: + feature_dim = self.backbone.num_features + + assert feature_size[0] >= patch_size[0] and feature_size[1] >= patch_size[1] + + div_h, mod_h = divmod(feature_size[0], patch_size[0]) + div_w, mod_w = divmod(feature_size[1], patch_size[1]) + + self.feature_size = (patch_size[0]*(div_h + (1 if mod_h > 0 else 0)), patch_size[1]*(div_w + (1 if mod_w > 0 else 0))) + assert self.feature_size[0] % patch_size[0] == 0 and self.feature_size[1] % patch_size[1] == 0 + self.grid_size = (self.feature_size[0] // patch_size[0], self.feature_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + origin_size = x.shape[-2:] + x = self.backbone(x) + f_h, f_w = x.shape[2:] + # assert f_h >= self.patch_size[0] and f_w >= self.patch_size[1] + + div_h, mod_h = divmod(f_h, self.patch_size[0]) + div_w, mod_w = divmod(f_w, self.patch_size[1]) + + pad_H 
=self.patch_size[0]*(div_h + (1 if mod_h > 0 else 0)) - f_h + pad_W = self.patch_size[1]*(div_w + (1 if mod_w > 0 else 0)) - f_w + x = F.pad(x, (0, pad_W, 0 , pad_H)) + + assert x.shape[2] % self.patch_size[0] == 0 and x.shape[3] % self.patch_size[1] == 0 + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + proj_x = self.proj(x).flatten(2).transpose(1, 2) + return proj_x, (pad_W, pad_H), {'height': x.shape[2], 'width': x.shape[3]}, (x.shape[2] != self.feature_size[0] or x.shape[3] != self.feature_size[1]) + +class HybridEmbed1D(nn.Module): + """ CNN Feature Map Embedding which using 1D embed patching + from https://arxiv.org/pdf/2111.08314.pdf, which benefits for text recognition task.Check paper for more detail + Extract feature map from CNN, flatten, project to embedding dim. + """ + def __init__(self, backbone, img_size: Tuple[int], feature_size=None, patch_size=1, in_chans=3, embed_dim=768): + super().__init__() + assert isinstance(backbone, nn.Module) + self.img_size = img_size + self.backbone = backbone + self.embed_dim = embed_dim + if feature_size is None: + with torch.no_grad(): + # NOTE Most reliable way of determining output dims is to run forward pass + training = backbone.training + if training: + backbone.eval() + o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1])) + if isinstance(o, (list, tuple)): + o = o[-1] # last feature if backbone outputs list/tuple of features + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + if hasattr(self.backbone, 'feature_info'): + feature_dim = self.backbone.feature_info.channels()[-1] + else: + feature_dim = self.backbone.num_features + + self.window_width = patch_size + assert feature_size[1] >= self.window_width + div_w, mod_w = divmod(feature_size[1], self.window_width) + self.feature_size = (feature_size[0], self.window_width*(div_w + (1 if mod_w > 0 else 0))) 
class BidirectionalLSTM(nn.Module):
    """One bidirectional LSTM layer followed by a linear projection back to
    `output_size` channels."""

    def __init__(self, input_size, hidden_size, output_size):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)
        self.linear = nn.Linear(hidden_size * 2, output_size)

    def forward(self, input):
        """Map a visual feature sequence (B, T, input_size) to a contextual
        feature sequence (B, T, output_size)."""
        # compact the LSTM weights into one contiguous chunk (cuDNN-friendly)
        self.rnn.flatten_parameters()
        context, _ = self.rnn(input)      # (B, T, in) -> (B, T, 2*hidden)
        return self.linear(context)       # (B, T, 2*hidden) -> (B, T, out)


class BiLSTM_Seq_Modeling(nn.Module):
    """`num_layers` stacked BidirectionalLSTM blocks: the first maps
    input_size -> hidden_size, the middle ones hidden -> hidden, and the
    last hidden_size -> output_size."""

    def __init__(self, num_layers, input_size, hidden_size, output_size):
        super(BiLSTM_Seq_Modeling, self).__init__()
        self.num_layers = num_layers
        stack = [BidirectionalLSTM(input_size, hidden_size, hidden_size)]
        stack.extend(
            BidirectionalLSTM(hidden_size, hidden_size, hidden_size)
            for _ in range(num_layers - 2)
        )
        stack.append(BidirectionalLSTM(hidden_size, hidden_size, output_size))
        self.lstm = nn.Sequential(*stack)

    def forward(self, input):
        return self.lstm(input)
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""Fills the input Tensor in-place with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        the input tensor, filled in-place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    # delegate to the no-grad worker (inverse-CDF sampling)
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
class Attention(nn.Module):
    """Multi-head self-attention (ViT style): fused qkv projection, scaled
    dot-product attention, then an output projection."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # default scale is 1/sqrt(head_dim); qk_scale overrides it to stay
        # compatible with checkpoints trained with a different factor
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        fused = self.qkv(x)
        fused = fused.reshape(batch, tokens, 3, self.num_heads, head_dim)
        fused = fused.permute(2, 0, 3, 1, 4)
        # plain indexing keeps torchscript happy (cannot use tensor as tuple)
        q, k, v = fused[0], fused[1], fused[2]

        scores = q @ k.transpose(-2, -1)
        weights = (scores * self.scale).softmax(dim=-1)
        weights = self.attn_drop(weights)

        mixed = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(mixed))
class VisionTransformer(nn.Module):
    """ Vision Transformer

    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
    https://arxiv.org/abs/2010.11929
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module
            norm_layer: (nn.Module): normalization layer
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

        # NOTE(review): this base class leaves patch_embed unset (None) and
        # sizes pos_embed from a fallback of 128 patches; forward_features()
        # calls self.patch_embed(x), so a subclass must assign a real patch
        # embedder (and matching pos_embed) before forward() is usable.
        self.patch_embed = None
        num_patches = getattr(self.patch_embed, 'num_patches', 128)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Representation layer (pre-logits bottleneck), optional
        if representation_size:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(embed_dim, representation_size)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()

        # Classifier head
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # truncated-normal init for linear weights; zeros/ones for LayerNorm
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        # parameter names the optimizer setup should exclude from weight decay
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        # swap the classification head for a new number of classes
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        """Patchify, prepend the CLS token, add positional embeddings, run
        the transformer blocks and return the normalised CLS feature."""
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        # keep only the CLS token feature after the final norm
        x = self.norm(x)[:, 0]
        x = self.pre_logits(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
    def forward_features(self, x):
        """Embed the image, add (possibly interpolated) positional encodings
        and run the transformer blocks.

        Returns:
            (tokens, pad_info, size): the full token sequence (CLS first),
            plus the padding and padded-size info from the patch embedder.
        """
        B, C, _, _ = x.shape

        x, pad_info, size, interpolating_pos = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

        # when the padded feature map differs from the build-time size, the
        # learned pos-embed grid is resized to match before being added
        if interpolating_pos:
            x = x + self.interpolating_pos_embedding(x, size['height'], size['width'])
        else:
            x = x + self.pos_embed

        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)

        return x, pad_info, size
    def interpolating_pos_embedding(self, embedding, height, width):
        """
        Resize the 1-D positional-embedding row to match `width` (bicubic),
        keeping the class-token embedding untouched.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        npatch = embedding.shape[1] - 1  # token count excluding CLS
        N = self.pos_embed.shape[1] - 1
        if npatch == N and height == width:
            return self.pos_embed

        class_pos_embedding = self.pos_embed[:, 0]
        patch_pos_embedding = self.pos_embed[:, 1:]
        dim = self.pos_embed.shape[-1]

        # target number of 1-D windows along the width
        w0 = width // self.patch_embed.window_width

        # add a small number to avoid floating point error
        # https://github.com/facebookresearch/dino/issues/8
        w0 = w0 + 0.1

        # embeddings live on a (1, emb_width) grid here; only width is scaled
        patch_pos_embedding = nn.functional.interpolate(
            patch_pos_embedding.reshape(1, self.emb_height, self.emb_width, dim).permute(0, 3, 1, 2),
            scale_factor=(1, w0 / self.emb_width),
            mode='bicubic',
            align_corners=False
        )

        assert int(w0) == patch_pos_embedding.shape[-1]
        patch_pos_embedding = patch_pos_embedding.permute(0, 2, 3, 1).view(1, -1, dim)
        class_pos_embedding = class_pos_embedding.unsqueeze(0)

        return torch.cat((class_pos_embedding, patch_pos_embedding), dim=1)
class ViTEncoderV2(ViTEncoder):
    """ViTEncoder variant that never interpolates the positional embeddings:
    it simply truncates pos_embed to the actual number of patch tokens."""

    def forward(self, x):
        B, _, _, _ = x.shape

        x, pad_info, size, _ = self.patch_embed(x)
        _, numpatches, *_ = x.shape
        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

        # +1 accounts for the CLS token prepended above
        x = x + self.pos_embed[:, :(numpatches + 1)]
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)

        return x, pad_info, size

class ViTEncoderV3(ViTEncoder):
    """ViTEncoder variant with fixed (non-learned) 2-D sin-cos positional
    embeddings built by the MAE-style helper."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # replace the parent's learned pos_embed with a frozen parameter
        if hasattr(self, 'pos_embed'):
            del self.pos_embed
        num_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches+1, kwargs['embed_dim']), requires_grad=False)
        self.initialize_posembed()

    def initialize_posembed(self):
        # fill pos_embed with the deterministic sin-cos grid (incl. CLS slot)
        pos_embed = get_2d_sincos_pos_embed(
            self.pos_embed.shape[-1],
            self.patch_embed.grid_size[0],
            self.patch_embed.grid_size[1],
            cls_token=True
        )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

    def forward(self, x):
        B, _, _, _ = x.shape

        x, pad_info, size, _ = self.patch_embed(x)
        _, numpatches, *_ = x.shape

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

        # truncate the fixed grid to the actual token count (+1 for CLS)
        x = x + self.pos_embed[:, :(numpatches + 1)]
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)

        return x, pad_info, size
class AttnLabelConverter(object):
    """ Convert between text-label and text-index for attention decoding.

    The vocabulary is the control tokens ([GO], [s], [UNK]) followed by the
    task characters; [GO] doubles as the padding index.
    """

    list_token = ['[GO]', '[s]', '[UNK]']

    def __init__(self, character, device):
        self.character = AttnLabelConverter.list_token + character
        self.device = device
        # token -> index lookup over the full vocabulary
        self.dict = {tok: idx for idx, tok in enumerate(self.character)}
        self.ignore_idx = self.dict['[GO]']

    @staticmethod
    def START() -> int:
        return AttnLabelConverter.list_token.index('[GO]')

    @staticmethod
    def END() -> int:
        return AttnLabelConverter.list_token.index('[s]')

    @staticmethod
    def UNK() -> int:
        return AttnLabelConverter.list_token.index('[UNK]')

    def encode(self, text, batch_max_length=25):
        """Encode a batch of label strings into a padded LongTensor.

        Row layout: [GO] c1 ... cn [s] [GO] ... (padded with [GO]); unknown
        characters map to [UNK]. Returns (batch_text, length) on self.device.
        """
        lengths = [len(s) + 1 for s in text]  # +1 for [s] at end of sentence.
        # batch_max_length = max(length) # this is not allowed for multi-gpu setting
        batch_max_length += 1
        # additional +1 for [GO] at first step. batch_text is padded with [GO] token after [s] token.
        batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(0)
        for row, label in enumerate(text):
            chars = list(label)

            if len(chars) > batch_max_length:
                chars = chars[:(batch_max_length - 1)]

            chars.append('[s]')
            ids = [self.dict.get(ch, self.dict['[UNK]']) for ch in chars]

            batch_text[row][1:1 + len(ids)] = torch.LongTensor(ids)  # batch_text[:, 0] = [GO] token
        return (batch_text.to(self.device), torch.IntTensor(lengths).to(self.device))

    def decode(self, text_index, token_level='word'):
        """Convert text-index rows into text-labels; tokens are joined with a
        space at 'word' level and directly concatenated otherwise."""
        sep = ' ' if token_level == 'word' else ''
        texts = []
        for row in range(text_index.shape[0]):
            texts.append(sep.join([self.character[i] for i in text_index[row, :]]))
        return texts

    def detokenize(self, token_ids):
        """Convert rows of token ids to lists of tokens, cutting each row at
        the first end-of-sequence ('[s]') token."""
        batches = []
        for ids in token_ids:
            tokens = []
            for i in ids:
                if self.character[i] == '[s]':
                    break
                tokens.append(self.character[i])
            batches.append(tokens)
        return batches
    @staticmethod
    def UNK() -> int:
        return TFMLabelConverter.list_token.index('[UNK]')

    @staticmethod
    def PAD() -> int:
        return TFMLabelConverter.list_token.index('[PAD]')

    def encode(self, text, batch_max_length=25):
        """Encode a batch of label strings into a padded LongTensor.

        Row layout: [GO] c1 ... cn [s] [PAD] ...; unknown characters map to
        [UNK]. Returns (batch_text, length) on self.device, where length is
        len(label) + 1 for the [s] terminator.
        """
        length = [len(s) + 1 for s in text]  # +1 for [s] at end of sentence.
        batch_max_length += 1
        batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(self.ignore_idx)
        for i, t in enumerate(text):
            # NOTE(review): this rebinds the loop's source name `text`; it is
            # safe only because enumerate() already captured the original
            # list, but the shadowing is easy to trip over when editing.
            text = list(t)

            if len(text) > batch_max_length:
                # NOTE(review): `length` above is not updated on truncation —
                # confirm downstream masking tolerates the overstated length.
                text = text[:(batch_max_length-1)]

            text.append('[s]')
            text = [self.dict[char] if char in self.dict else self.dict['[UNK]'] for char in text]
            batch_text[i][0] = torch.LongTensor([self.dict['[GO]']])
            batch_text[i][1:1 + len(text)] = torch.LongTensor(text)  # batch_text[:, 0] = [GO] token
        return (batch_text.to(self.device), torch.IntTensor(length).to(self.device))
""" + texts = [] + batch_size = text_index.shape[0] + for index in range(batch_size): + if token_level == 'word': + text = ' '.join([self.character[i] for i in text_index[index, :]]) + else: + text = ''.join([self.character[i] for i in text_index[index, :]]) + texts.append(text) + return texts + + def detokenize(self, token_ids): + """convert token ids to list of token""" + b_toks = [] + for tok in token_ids: + toks = [] + for i in tok: + if self.character[i] == '[s]': + break + toks.append(self.character[i]) + b_toks.append(toks) + + return b_toks + +if __name__ == '__main__': + vocab = ['S', 'ố', ' ', '2', '5', '3', 'đ', 'ư', 'ờ', 'n', 'g', 'T', 'r', 'ầ', 'P', 'h', 'ú', ',', 'ị', 't', 'ấ', 'N', 'a', 'm', 'á', 'c', 'H', 'u', 'y', 'ệ', 'ả', 'i', 'D', 'ơ', '8', '9', 'Đ', 'B', 'ộ', 'L', 'ĩ', '6', 'Q', 'ậ', 'ì', 'ạ', 'ồ', 'C', 'í', 'M', '4', 'E', '/', 'K', 'p', '1', 'A', 'x', 'ặ', 'ễ', '0', 'â', 'à', 'ế', 'ừ', 'ê', '-', '7', 'o', 'V', 'ô', 'ã', 'G', 'ớ', 'Y', 'I', 'ề', 'ò', 'l', 'R', 'ỹ', 'ủ', 'X', "'", 'e', 'ắ', 'ổ', 'ằ', 'k', 's', '.', 'ợ', 'ù', 'ứ', 'ă', 'ỳ', 'ẵ', 'ý', 'ó', 'ẩ', 'ọ', 'J', 'ũ', 'ữ', 'ự', 'õ', 'ỉ', 'ỏ', 'v', 'd', 'Â', 'W', 'U', 'O', 'é', 'ở', 'ỷ', '(', ')', 'ử', 'è', 'ể', 'ụ', 'ỗ', 'F', 'q', 'ẻ', 'ỡ', 'b', 'ỵ', 'Ứ', '#', 'ẽ', 'Ô', 'Ê', 'Ơ', '+', 'z', 'Ấ', 'w', 'Z', '&', 'Á', '~', 'f', 'Ạ', 'Ắ', 'j', ':', 'Ă', '<', '>', 'ẹ', '_', 'À', 'Ị', 'Ư', 'Ễ'] + text = [ + "190B Trần Quang Khải, Phường Tân Định, Quận 1, TP Hồ Chí Minh", + "164/2B, Quốc lộ 1A, Phường Lê Bình, Quận Cái Răng, Cần Thơ", + "Cẩm Huy, Huyện Cẩm Xuyên, Hà Tĩnh" + ] + tfm_convert = TFMLabelConverter(vocab, 'cpu') + texts, lengths = tfm_convert.encode(text, 70) + print(texts) + for text in texts: + print('Encode', text) + text = text.unsqueeze(0) + decode_text = tfm_convert.decode(text, 'char') + print('Decode', decode_text) \ No newline at end of file diff --git a/HybridViT/recog_flow.py b/HybridViT/recog_flow.py new file mode 100644 index 
0000000000000000000000000000000000000000..0186d84e77e65fe192eb24bf74134e2f7cadd601 --- /dev/null +++ b/HybridViT/recog_flow.py @@ -0,0 +1,113 @@ +import re +from typing import Any +from collections import OrderedDict +import re + +import torch +import torch.nn as nn +from torch.nn import functional as F +from PIL import Image +from timm.models.resnetv2 import ResNetV2 + +from .recognizers.build_model import Model +from .module.converter import AttnLabelConverter, TFMLabelConverter +from .helper import resize + +class MathRecognition(object): + def __init__(self, config, resizer): + self.args = config + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.args["device"] = device + self.device = device + self._prepare_vocab() + self.model = self._get_model() + self.resizer=resizer + + def _mapping_ckpt(self, state_dict): + new_state_dict = OrderedDict() + + for name, param in state_dict.items(): + if name.startswith('Transformation'): + continue + elif name.startswith('FeatureExtraction'): + new_name = name.replace('FeatureExtraction', 'featextractor.FeatureExtraction') + new_state_dict[new_name] = param + elif name.startswith('SequenceModeling'): + new_name = name.replace('SequenceModeling', 'seqmodeler.SequenceModeling') + new_state_dict[new_name] = param + elif name.startswith('Prediction'): + new_name = name.replace('Prediction', 'predicter.Prediction') + new_state_dict[new_name] = param + else: + new_state_dict[name] = param + + return new_state_dict + + def _get_model(self): + model = Model(self.args) + state_dict = torch.load(self.args["weight_path"], map_location='cpu') + new_state_dict = self._mapping_ckpt(state_dict) + model.load_state_dict(new_state_dict) + model=model.eval() + + if self.device == 'cuda': + num_gpu = torch.cuda.device_count() + if num_gpu > 1: + model = nn.DataParallel(model).to(self.device) + else: + model.to(self.device) + + return model + + def _prepare_vocab(self): + with open(self.args["vocab"], 'rt') as f: + for line in f: 
+ self.args["character"] += [line.rstrip()] + f.close() + + if 'Attn' in self.args['Prediction']['name']: + self.converter = AttnLabelConverter(self.args["character"], self.device) + else: + self.converter = TFMLabelConverter(self.args["character"], self.device) + + self.args["num_class"] = len(self.converter.character) + + def _preprocess(self, image: Image.Image): + img_tensor = resize(self.resizer, image, self.args) + return img_tensor + + def _postprocess(self, s: str): + text_reg = r'(\\(operatorname|mathrm|mathbf|mathsf|mathit|mathfrak|mathnormal)\s?\*? {.*?})' + letter = '[a-zA-Z]' + noletter = '[\W_^\d]' + names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)] + s = re.sub(text_reg, lambda match: str(names.pop(0)), s) + news = s + + for space in ["hspace", "vspace"]: + match = re.finditer(space + " {(.*?)}", news) + if match: + new_l = "" + last = 0 + for m in match: + new_l = new_l + news[last:m.start(1)] + m.group(1).replace(" ", "") + last = m.end(1) + new_l = new_l + news[last:] + news = new_l + + return news + + def __call__(self, image: Image.Image, name=None, *arg: Any, **kwargs): + assert image.mode == 'RGB', 'input image must be RGB image' + with torch.no_grad(): + img_tensor = self._preprocess(image).to(self.device) + text_for_pred = torch.LongTensor(1, self.args["batch_max_length"] + 1).fill_(0).to(self.device) + preds_index, _, _ = self.model(img_tensor, text_for_pred, is_train=False, is_test=True) + pred_str = self.converter.decode(preds_index, self.args.get('token_level', 'word'))[0] + + pred_EOS = pred_str.find('[s]') + pred_str = pred_str[:pred_EOS] + + process_str = self._postprocess(pred_str) + + return process_str diff --git a/HybridViT/recognizers/__init__.py b/HybridViT/recognizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/HybridViT/recognizers/build_feat.py b/HybridViT/recognizers/build_feat.py new file mode 100644 index 
def forward(self, input):
    """Extract visual features and reshape them for the sequence stage.

    For BiLSTM-family sequence modules the CNN feature map is either
    height-pooled to [b, w, c] or flattened and projected; for every other
    flow the extractor output is returned as-is.
    """
    features = self.FeatureExtraction(input)

    if self.flow['Seq'] not in ('BiLSTM', 'BiLSTM_3L'):
        return features

    if hasattr(self, 'AdaptiveAvgPool'):
        # [b, c, h, w] -> [b, w, c, 1] -> [b, w, c]
        pooled = self.AdaptiveAvgPool(features.permute(0, 3, 1, 2))
        return pooled.squeeze(3)

    # [b, c, h, w] -> [b, w, c, h] -> [b, w, c*h] -> projected channels
    flattened = features.permute(0, 3, 1, 2).flatten(start_dim=-2)
    return self.proj_feat(flattened)
def forward_encoder(self, input, *args, **kwargs):
    """Run feature extraction then sequence modeling.

    Returns:
        (contextual_feature, output_shape, feat_pad) as produced by the
        sequence-modeling stage.
    """
    visual = self.featextractor(input)
    contextual, shape, pad = self.seqmodeler(visual, *args, **kwargs)
    return contextual, shape, pad

def forward_decoder(self, contextual_feature, text, is_train=True,
                    is_test=False, rtl_text=None
                    ):
    """Run the prediction stage on encoder output.

    Returns:
        (prediction, logits, decoder_attn, addition_outputs) from the
        prediction head.
    """
    return self.predicter(contextual_feature, text, is_train, is_test, rtl_text)
self.forward_encoder(input) + prediction, logits, decoder_attn, addition_outputs = self.forward_decoder( + contextual_feature, text=text, is_train=is_train, is_test=is_test, rtl_text=rtl_text + ) + + if decoder_attn is not None and output_shape is not None: + if self.stages['Pred'] == 'Attn' and self.stages['Seq'] == 'ViT': + decoder_attn = decoder_attn[:, 1:] + decoder_attn = decoder_attn.reshape(-1, output_shape[0], output_shape[1]) + + addition_outputs.update( + { + 'decoder_attn': decoder_attn, + 'feat_width': output_shape[0], + 'feat_height': output_shape[1], + 'feat_pad': feat_pad, + } + ) + + return prediction, logits, addition_outputs + +if __name__ == '__main__': + import yaml + import torch + + with open('/media/huynhtruc0309/DATA/Math_Expression/my_source/Math_Recognition/config/train/paper_experiments/report_paper/best_model_noaugment/experiment_1806.yaml', 'r') as f: + config = yaml.safe_load(f) + + config['num_class'] = 499 + config['device'] = 'cpu' + model = Model(config) + + a = torch.rand(1, 1, 32, 224) + + output = model(a) + + print('pred', output[0].shape) diff --git a/HybridViT/recognizers/build_pred.py b/HybridViT/recognizers/build_pred.py new file mode 100644 index 0000000000000000000000000000000000000000..8caa1be4de6b79477b871a6d00b13ed515c4d3fd --- /dev/null +++ b/HybridViT/recognizers/build_pred.py @@ -0,0 +1,61 @@ +import torch +import torch.nn as nn +from ..module.component.prediction_head import (Attention, + AttentionV2, + TransformerPrediction + ) + + +class PredictBuilder(nn.Module): + def __init__(self, flow, config, SequenceModeling_output): + super().__init__() + self.flow =flow + self.config=config + if flow['Pred'] == 'CTC': + self.Prediction = nn.Linear(SequenceModeling_output, config['num_class']) + + elif flow['Pred'] == 'Attn': + config['Prediction']['params']['num_classes'] = config['num_class'] + config['Prediction']['params']['device'] = config['device'] + self.Prediction = Attention( + **config['Prediction']['params'] + 
def forward(self, contextual_feature, text, is_train=True, is_test=False, rtl_text=None):
    """Dispatch the contextual features to the configured prediction head.

    Args:
        contextual_feature: encoder output tensor.
        text: target token ids (teacher forcing / decoding start tokens).
        is_train / is_test: decoding-mode flags forwarded to the head.
        rtl_text: accepted for interface compatibility; not consumed here.

    Returns:
        (prediction, logits, decoder_attn, addition_outputs); `logits` and
        `decoder_attn` are None for heads that do not produce them.
    """
    beam_size = self.config.get('beam_size', 1)

    addition_outputs = {}
    decoder_attn = None
    # fix: `logits` was unbound on the CTC path, so the return statement
    # raised NameError; default it to None.
    logits = None

    if self.flow['Pred'] == 'CTC':
        prediction = self.Prediction(contextual_feature.contiguous())

    elif self.flow['Pred'] in ['Attn', 'Attnv2']:
        prediction, logits, decoder_attn = self.Prediction(
            beam_size, contextual_feature.contiguous(), text,
            is_train=is_train, is_test=is_test,
            batch_max_length=self.config['batch_max_length'])

    elif self.flow['Pred'] == 'TFM':
        prediction, logits = self.Prediction(beam_size, contextual_feature.contiguous(), text, is_test)
        self.Prediction.reset_beam()

    # NOTE(review): unknown flow names are rejected in __init__, so
    # `prediction` is always bound here.
    return prediction, logits, decoder_attn, addition_outputs
+from ..module.component.common import GatedSum +from ..module.component.common import PositionalEncoding2D, PositionalEncoding1D + + +class SeqModelingBuilder(nn.Module): + def __init__(self, flow: dict, config, FeatureExtraction_output): + super().__init__() + self.config = config + self.flow = flow + + if flow['Seq'] == 'BiLSTM': + hidden_size = config['SequenceModeling']['params']['hidden_size'] + use_pos_enc = config['SequenceModeling']['params'].get('pos_enc', False) + + if use_pos_enc: + self.image_positional_encoder = PositionalEncoding1D(hidden_size) + self.gated = GatedSum(hidden_size) + + self.SequenceModeling = nn.Sequential( + BidirectionalLSTM(FeatureExtraction_output, hidden_size, hidden_size), + BidirectionalLSTM(hidden_size, hidden_size, hidden_size)) + + self.SequenceModeling_output = hidden_size + + elif flow['Seq'] == 'BiLSTM_3L': + hidden_size = config['SequenceModeling']['params']['hidden_size'] + self.SequenceModeling = BiLSTM_Seq_Modeling(3, FeatureExtraction_output, hidden_size, hidden_size) + + self.SequenceModeling_output = hidden_size + + elif flow['Seq'] == 'ViT': + assert config['max_dimension'] is not None, "ViT encoder require exact height or max height and max width" + self.SequenceModeling = create_vit_modeling(config) + else: + print('No SequenceModeling module specified') + if flow['Pred'] == 'TFM': + self.image_positional_encoder = PositionalEncoding2D(FeatureExtraction_output) + + self.SequenceModeling_output = FeatureExtraction_output + + def forward(self, visual_feature, *args, **kwargs): + output_shape = None + pad_info = None + + if self.flow['Seq'] in ['BiLSTM', 'BiLSTM_3L']: + contextual_feature = self.SequenceModeling(visual_feature) + + if hasattr(self, 'image_positional_encoder'): + assert len(contextual_feature.shape) == 3 + contextual_feature_1 = self.image_positional_encoder(visual_feature.permute(1, 0, 2)) + contextual_feature_1 = contextual_feature_1.permute(1, 0, 2) + contextual_feature = 
self.gated(contextual_feature_1, contextual_feature) + + elif self.flow['Seq'] == 'ViT': + contextual_feature, pad_info, _ = self.SequenceModeling(visual_feature) + + return contextual_feature, output_shape, pad_info \ No newline at end of file diff --git a/HybridViT/resizer.py b/HybridViT/resizer.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ScanSSD/IOU_lib/BoundingBox.py b/ScanSSD/IOU_lib/BoundingBox.py new file mode 100644 index 0000000000000000000000000000000000000000..12b82673302031367c28f1fcd25e6f8250369d24 --- /dev/null +++ b/ScanSSD/IOU_lib/BoundingBox.py @@ -0,0 +1,164 @@ +from .iou_utils import * + +class BoundingBox: + def __init__(self, + imageName, + classId, + x, + y, + w, + h, + typeCoordinates=CoordinatesType.Absolute, + imgSize=None, + bbType=BBType.GroundTruth, + classConfidence=None, + format=BBFormat.XYWH): + """Constructor. + Args: + imageName: String representing the image name. + classId: String value representing class id. + x: Float value representing the X upper-left coordinate of the bounding box. + y: Float value representing the Y upper-left coordinate of the bounding box. + w: Float value representing the width bounding box. + h: Float value representing the height bounding box. + typeCoordinates: (optional) Enum (Relative or Absolute) represents if the bounding box + coordinates (x,y,w,h) are absolute or relative to size of the image. Default:'Absolute'. + imgSize: (optional) 2D vector (width, height)=>(int, int) represents the size of the + image of the bounding box. If typeCoordinates is 'Relative', imgSize is required. + bbType: (optional) Enum (Groundtruth or Detection) identifies if the bounding box + represents a ground truth or a detection. If it is a detection, the classConfidence has + to be informed. + classConfidence: (optional) Float value representing the confidence of the detected + class. 
If detectionType is Detection, classConfidence needs to be informed. + format: (optional) Enum (BBFormat.XYWH or BBFormat.XYX2Y2) indicating the format of the + coordinates of the bounding boxes. BBFormat.XYWH: + BBFormat.XYX2Y2: . + """ + self._imageName = imageName + self._typeCoordinates = typeCoordinates + if typeCoordinates == CoordinatesType.Relative and imgSize is None: + raise IOError( + 'Parameter \'imgSize\' is required. It is necessary to inform the image size.') + if bbType == BBType.Detected and classConfidence is None: + raise IOError( + 'For bbType=\'Detection\', it is necessary to inform the classConfidence value.') + # if classConfidence != None and (classConfidence < 0 or classConfidence > 1): + # raise IOError('classConfidence value must be a real value between 0 and 1. Value: %f' % + # classConfidence) + + self._classConfidence = classConfidence + self._bbType = bbType + self._classId = classId + self._format = format + + # If relative coordinates, convert to absolute values + # For relative coords: (x,y,w,h)=(X_center/img_width , Y_center/img_height) + if (typeCoordinates == CoordinatesType.Relative): + (self._x, self._y, self._w, self._h) = convertToAbsoluteValues(imgSize, (x, y, w, h)) + self._width_img = imgSize[0] + self._height_img = imgSize[1] + if format == BBFormat.XYWH: + self._x2 = self._w + self._y2 = self._h + self._w = self._x2 - self._x + self._h = self._y2 - self._y + else: + raise IOError( + 'For relative coordinates, the format must be XYWH (x,y,width,height)') + # For absolute coords: (x,y,w,h)=real bb coords + else: + self._x = x + self._y = y + if format == BBFormat.XYWH: + self._w = w + self._h = h + self._x2 = self._x + self._w + self._y2 = self._y + self._h + else: # format == BBFormat.XYX2Y2: . 
+ self._x2 = w + self._y2 = h + self._w = self._x2 - self._x + self._h = self._y2 - self._y + if imgSize is None: + self._width_img = None + self._height_img = None + else: + self._width_img = imgSize[0] + self._height_img = imgSize[1] + + def __str__(self): + return "{} {} {} {} {}".format(self.getImageName(), self._x, self._y, self._w, self._h) + + def getAbsoluteBoundingBox(self, format=BBFormat.XYWH): + if format == BBFormat.XYWH: + return (self._x, self._y, self._w, self._h) + elif format == BBFormat.XYX2Y2: + return (self._x, self._y, self._x2, self._y2) + + def getRelativeBoundingBox(self, imgSize=None): + if imgSize is None and self._width_img is None and self._height_img is None: + raise IOError( + 'Parameter \'imgSize\' is required. It is necessary to inform the image size.') + if imgSize is None: + return convertToRelativeValues((imgSize[0], imgSize[1]), + (self._x, self._y, self._w, self._h)) + else: + return convertToRelativeValues((self._width_img, self._height_img), + (self._x, self._y, self._w, self._h)) + + def getImageName(self): + return self._imageName + + def getConfidence(self): + return self._classConfidence + + def getFormat(self): + return self._format + + def getClassId(self): + return self._classId + + def getImageSize(self): + return (self._width_img, self._height_img) + + def getCoordinatesType(self): + return self._typeCoordinates + + def getBBType(self): + return self._bbType + + @staticmethod + def compare(det1, det2): + det1BB = det1.getAbsoluteBoundingBox() + det1ImgSize = det1.getImageSize() + det2BB = det2.getAbsoluteBoundingBox() + det2ImgSize = det2.getImageSize() + + if det1.getClassId() == det2.getClassId() and \ + det1.classConfidence == det2.classConfidenc() and \ + det1BB[0] == det2BB[0] and \ + det1BB[1] == det2BB[1] and \ + det1BB[2] == det2BB[2] and \ + det1BB[3] == det2BB[3] and \ + det1ImgSize[0] == det1ImgSize[0] and \ + det2ImgSize[1] == det2ImgSize[1]: + return True + return False + + @staticmethod + def 
class Evaluator:
    """IoU helpers for object-detection evaluation.

    Adapted from Rafael Padilla's object-detection metrics code; all boxes
    are (x1, y1, x2, y2) with inclusive pixel coordinates.
    """

    @staticmethod
    def _getAllIOUs(reference, detections):
        """Score *reference* against every detection.

        Returns (iou, reference, detection) triples sorted by IoU,
        highest first.
        """
        ref_box = reference.getAbsoluteBoundingBox(BBFormat.XYX2Y2)
        scored = [
            (Evaluator.iou(ref_box, det.getAbsoluteBoundingBox(BBFormat.XYX2Y2)),
             reference, det)
            for det in detections
        ]
        scored.sort(key=lambda triple: triple[0], reverse=True)
        return scored

    @staticmethod
    def iou(boxA, boxB):
        """Intersection-over-union of two (x1, y1, x2, y2) boxes."""
        if not Evaluator._boxesIntersect(boxA, boxB):
            return 0
        overlap = Evaluator._getIntersectionArea(boxA, boxB)
        result = overlap / Evaluator._getUnionAreas(boxA, boxB, interArea=overlap)
        assert result >= 0
        return result

    @staticmethod
    def _boxesIntersect(boxA, boxB):
        """True when the two boxes overlap at all."""
        return not (
            boxA[0] > boxB[2]      # A entirely right of B
            or boxB[0] > boxA[2]   # A entirely left of B
            or boxA[3] < boxB[1]   # A entirely above B
            or boxA[1] > boxB[3]   # A entirely below B
        )

    @staticmethod
    def _getIntersectionArea(boxA, boxB):
        """Pixel area of the overlap (inclusive coordinates, hence +1)."""
        left = max(boxA[0], boxB[0])
        top = max(boxA[1], boxB[1])
        right = min(boxA[2], boxB[2])
        bottom = min(boxA[3], boxB[3])
        return (right - left + 1) * (bottom - top + 1)

    @staticmethod
    def _getUnionAreas(boxA, boxB, interArea=None):
        """Union area of the two boxes; reuses *interArea* when provided."""
        if interArea is None:
            interArea = Evaluator._getIntersectionArea(boxA, boxB)
        return float(Evaluator._getArea(boxA) + Evaluator._getArea(boxB) - interArea)

    @staticmethod
    def _getArea(box):
        """Inclusive pixel area of an (x1, y1, x2, y2) box."""
        return (box[2] - box[0] + 1) * (box[3] - box[1] + 1)
read_file(filename, bboxes, flag): + ''' + Parses the input .csv file into map where key as page number and value as a list of bounding box objects + corresponding to each math region in the file. + :param filename: .csv file containing math regions + :param bboxes: Map> + :return: + ''' + fh1 = open(filename, "r") + prev_page = -1 + counter = 1 + for line in fh1: + line = line.replace("\n", "") + if line.replace(' ', '') == '': + continue + splitLine = line.split(",") + idClass = float(splitLine[0]) + if prev_page == -1: + prev_page = idClass + else: + if idClass != prev_page: + counter = 1 + prev_page = idClass + x = float(splitLine[1]) + y = float(splitLine[2]) + x2 = float(splitLine[3]) + y2 = float(splitLine[4]) + bb = BoundingBox( + flag+"_"+str(counter), + 1, + x, + y, + x2, + y2, + CoordinatesType.Absolute, (200, 200), + BBType.GroundTruth, + format=BBFormat.XYX2Y2) + counter += 1 + #print(counter) + if idClass not in bboxes: + bboxes[idClass] = [] + bboxes[idClass].append(bb) + + fh1.close() + + +def extract_zipfile(zip_filename, target_dir): + ''' + Extract zip file into the target directory + :param zip_filename: full-file-path of the zip-file + :param target_dir: target-dir to extract contents of zip-file + :return: + ''' + with ZipFile(zip_filename, 'r') as zip: + # extracting all the files + print('Extracting all the files now...') + zip.extractall(target_dir) + print('Done!') + + +def create_doc_bboxes_map(dir_path,flag): + ''' + Reads all files recursively in directory path and and returns a map containing bboxes for each page in each math + file in directory. 
def unique_values(input_dict):
    """Return groups of GT keys whose best-match detection is shared.

    *input_dict* maps GT names to tuples whose second element is the name
    of the best-matching detection; a group is emitted each time a
    detection is seen again.
    """
    seen_preds = []
    repair_keys = []
    for value in input_dict.values():
        if value[1] in seen_preds:
            matched = [k for k, v in input_dict.items() if v[1] == value[1]]
            repair_keys.append(matched)
        # NOTE(review): reconstructed from a whitespace-mangled diff; this
        # append is assumed to sit outside the `if` — verify upstream.
        seen_preds.append(value[1])
    return repair_keys


def generate_validpairs(pairs):
    """Expand key groups into unique consecutive pairs.

    Groups longer than two become their consecutive (left, right) tuples;
    shorter groups are kept as-is. Duplicates are dropped while preserving
    first-seen order.
    """
    out = []
    for group in pairs:
        if len(group) > 2:
            for left, right in zip(group, group[1:]):
                candidate = (left, right)
                if candidate not in out:
                    out.append(candidate)
        elif group not in out:
            out.append(group)
    return out
ground truth bboxes + :param det_page_bboxes_map: Map> for detection bboxes + :return: + ''' + evaluator = Evaluator() + + correct_pred_coarse=0 + correct_pred_fine=0 + + pdf_gt_boxes=0 + pdf_det_boxes=0 + + coarse_keys = {} + fine_keys = {} + + for page_num in gt_page_bboxes_map: + if page_num not in det_page_bboxes_map: + print('Detections not found for page', str(page_num + 1), ' in', pdf_name) + continue + gt_boxes = gt_page_bboxes_map[page_num] + det_boxes = det_page_bboxes_map[page_num] + + pdf_gt_boxes+=len(gt_boxes) + pdf_det_boxes+=len(det_boxes) + + pred_dict={} + for gt_box in gt_boxes: + ious = evaluator._getAllIOUs(gt_box, det_boxes) + preds=[] + labels=[] + for i in range(len(ious)): + preds.append(round(ious[i][0],2)) + labels.append(ious[i][2].getImageName()) + + pred_dict[gt_box.getImageName()]=preds[0],labels[0],preds,labels + + coarse,coarse_dict=count_true_box(copy.deepcopy(pred_dict),0.5) + fine,fine_dict=count_true_box(copy.deepcopy(pred_dict),0.75) + + coarse_keys[page_num] = coarse_dict.keys() + fine_keys[page_num] = fine_dict.keys() + + #count correct preds for coarse 0.5 and fine 0.75 in one page + correct_pred_coarse= correct_pred_coarse+coarse + correct_pred_fine= correct_pred_fine+fine + #write iou per page + if outdir: + out_file = open(os.path.join(outdir,pdf_name.split(".csv")[0]+"_"+str(page_num)+"_eval.txt"), "w") + out_file.write('#page num '+str(page_num)+", gt_box:"+str(len(gt_boxes))+ + ", pred_box:"+str(len(det_boxes))+"\n") + out_file.write('\n') + out_file.write('#COARSE DETECTION (iou>0.5):\n#number of correct prediction:'+ str(coarse)+ '\n#correctly detected:'+ + str(list(coarse_dict.keys()))+'\n') + out_file.write('\n') + out_file.write('#FINE DETECTION (iou>0.75):\n#number of correct prediction:'+ str(fine)+ '\n#correctly detected:'+ + str(list(fine_dict.keys()))+'\n') + out_file.write('\n') + out_file.write('#Sorted IOU scores for each GT box:\n') + for gt_box in gt_boxes: + ious = evaluator._getAllIOUs(gt_box, 
det_boxes) + out_file.write(gt_box.getImageName()+",") + for i in range(len(ious)-1): + out_file.write("("+str(round(ious[i][0],2))+" "+ str(ious[i][2].getImageName())+"),") + out_file.write( "("+str(round(ious[-1][0],2))+" "+ str(ious[-1][2].getImageName())+")\n" ) + out_file.close() + + return correct_pred_coarse, correct_pred_fine, pdf_gt_boxes, pdf_det_boxes, coarse_keys, fine_keys + +def count_box(input_dict): + count=0 + for pdf in input_dict.values(): + for page in pdf.values(): + count+=len(page) + + return count + +# Zip every uploading files +def archive_iou_txt(username, task_id, sub_id,userpath): + + inputdir=os.path.join(userpath,'iouEval_stats') + + if not os.path.exists(inputdir): + print('No txt file is generated for IOU evaluation') + pass + + dest_uploader = 'IOU_stats_archive' + dest_uploader = os.path.join(userpath, dest_uploader) + + if not os.path.exists(dest_uploader): + os.makedirs(dest_uploader) + + zip_file_name = '/' + task_id + '_' + sub_id + shutil.make_archive(dest_uploader + zip_file_name, 'zip', inputdir) + + # return '/media/' + dest_uploader + +def write_html(gtFile,resultsFile,info,scores,destFile): + + destFile.write('') + destFile.write('') + destFile.write('') + #writeCSS(destFile) + destFile.write ("
CROHME 2019

Formula Detection Results ( TASK 3 )


") + destFile.write("Submitted Files
  • Output: "+ ntpath.basename(resultsFile) +"
  • ") + destFile.write ("
  • Ground-truth: " + ntpath.basename(gtFile) + "
") + if info['allGTbox'] == 0: + sys.stderr.write("Error : no sample in this GT list !\n") + exit(-1) + #all detection and gt boxes + destFile.write ("

Number of ground truth bounding boxes: " + str(info['allGTbox']) + "
Number of detected bounding boxes: " + str(info['allDet'])) + destFile.write ("


") + #coarse results + destFile.write ("

**** Coarse Detection Results (IOU>0.5) ****
") + destFile.write ("

  • "+str(scores['coarse_f']) + " F-score
  • ") + destFile.write ("
  • "+str(scores['coarse_pre']) + " Precision
  • ") + destFile.write ("
  • "+str(scores['coarse_rec']) + " Recall
") + destFile.write ("" + str(info['correctDet_c']) + " Number of correctly detected bounding boxes

") + destFile.write ("
") + #fine results + destFile.write ("

**** Fine Detection Results (IOU>0.75) ****
") + destFile.write ("

  • "+str(scores['fine_f']) + " F-score
  • ") + destFile.write ("
  • "+str(scores['fine_pre']) + " Precision
  • ") + destFile.write ("
  • "+str(scores['fine_rec']) + " Recall
") + destFile.write ("" + str(info['correctDet_f']) + " Number of correctly detected bounding boxes

") + destFile.write ("
") + destFile.write('') + destFile.write('') + +def pre_rec_calculate(count): + + if count['allDet']==0: + print ('No detection boxes found') + scores={'fine_f':0,'coarse_f':0} + else: + pre_f=count['correctDet_f']/float(count['allDet']) + recall_f=count['correctDet_f']/float(count['allGTbox']) + if pre_f==0 and recall_f ==0: + f_f=0 + else: + f_f=2*(pre_f*recall_f)/float(pre_f+recall_f) + + pre_c=count['correctDet_c']/float(count['allDet']) + recall_c=count['correctDet_c']/float(count['allGTbox']) + if pre_c==0 and recall_c==0: + f_c=0 + else: + f_c=2*(pre_c*recall_c)/float(pre_c+recall_c) + print('') + print('**** coarse result : threshold: 0.5 *****') + print(' f =',f_c,' precision =',pre_c,' recall =',recall_c) + print('') + print('**** fine result : threshold: 0.75 *****') + print(' f =',f_f,' precision =',pre_f,' recall =',recall_f) + + scores={'fine_f':round(f_f,4),'fine_pre':round(pre_f,4),'fine_rec':round(recall_f,4), + 'coarse_f':round(f_c,4),'coarse_pre':round(pre_c,4),'coarse_rec':round(recall_c,4)} + return scores + +def IOUeval(ground_truth, detections, outdir=None): #, + + keys=['allGTbox','correctDet_c','correctDet_f','allDet'] + info=dict.fromkeys(keys,0) + + gt_file_name = ground_truth + det_file_name = detections + + #TODO : Mahshad change it to user directory + if outdir: + #outdir='IOU_eval_stats' + if os.path.exists(outdir): + shutil.rmtree(outdir) + os.makedirs(outdir) + + gt_pdfs_bboxes_map = create_doc_bboxes_map(gt_file_name,'gt') + det_pdfs_bboxes_map = create_doc_bboxes_map(det_file_name,'det') + #count boxes + all_gtbox=count_box(gt_pdfs_bboxes_map) + all_detbox=count_box(det_pdfs_bboxes_map) + + + info['allGTbox']=all_gtbox + info['allDet']=all_detbox + + pdf_gt_bbs = 0 + pdf_dt_bbs = 0 + pdf_info = {} + pdf_calcs = {} + + detailed_detections = {} + + for pdf_name in gt_pdfs_bboxes_map: + if pdf_name not in det_pdfs_bboxes_map: + print('Detections not found for ',pdf_name) + continue + + det_page_bboxes_map = 
det_pdfs_bboxes_map[pdf_name] + gt_page_bboxes_map = gt_pdfs_bboxes_map[pdf_name] + + coarse_true_det,fine_true_det,pdf_gt_boxes,pdf_det_boxes,coarse_keys,fine_keys=\ + IoU_page_bboxes(gt_page_bboxes_map, det_page_bboxes_map, pdf_name,outdir) + info['correctDet_c']=info['correctDet_c']+coarse_true_det + info['correctDet_f']=info['correctDet_f']+fine_true_det + + pdf_info['correctDet_c']=coarse_true_det + pdf_info['correctDet_f']=fine_true_det + pdf_info['allGTbox']=pdf_gt_boxes + pdf_info['allDet']=pdf_det_boxes + + print('For pdf: ', pdf_name) + pdf_calcs[pdf_name]=pre_rec_calculate(pdf_info) + detailed_detections[pdf_name] = [coarse_keys, fine_keys] + #print('Pdf score:',pdf_name, " --> ", pre_rec_calculate(pdf_info)) + + print('\n') + print(info) + scores=pre_rec_calculate(info) + + print('\n PDF Level \n') + #print(pdf_calcs) + + #{'fine_f': 0.7843, 'fine_pre': 0.7774, 'fine_rec': 0.7914, 'coarse_f': 0.902, 'coarse_pre': 0.894, 'coarse_rec': 0.9101} + for pdf_name in pdf_calcs: + print(pdf_name,'\t', pdf_calcs[pdf_name]['coarse_f'],'\t',pdf_calcs[pdf_name]['fine_f']) + + #return corase and fine F-scores + return scores['coarse_f'],scores['fine_f'], detailed_detections + +if __name__=='__main__': + + parser = argparse.ArgumentParser() + parser.add_argument("--detections", type=str, required=True, help="detections file path") + parser.add_argument("--ground_truth", type=str, required=True, help="ground_truth file path") + args = parser.parse_args() + + gt_file_name = args.ground_truth + det_file_name = args.detections + + + c_f,f_f=IOUeval(gt_file_name,det_file_name,outdir='IOU_scores_pages/') + + + diff --git a/ScanSSD/IOU_lib/__init__.py b/ScanSSD/IOU_lib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ScanSSD/IOU_lib/iou_utils.py b/ScanSSD/IOU_lib/iou_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..761a9a074ba3ddea58559d0066403d0c7dfbdb32 --- 
/dev/null +++ b/ScanSSD/IOU_lib/iou_utils.py @@ -0,0 +1,113 @@ +from enum import Enum + +import cv2 + + +class CoordinatesType(Enum): + """ + Class representing if the coordinates are relative to the + image size or are absolute values. + + Developed by: Rafael Padilla + Last modification: Apr 28 2018 + """ + Relative = 1 + Absolute = 2 + + +class BBType(Enum): + """ + Class representing if the bounding box is groundtruth or not. + + Developed by: Rafael Padilla + Last modification: May 24 2018 + """ + GroundTruth = 1 + Detected = 2 + + +class BBFormat(Enum): + """ + Class representing the format of a bounding box. + It can be (X,Y,width,height) => XYWH + or (X1,Y1,X2,Y2) => XYX2Y2 + + Developed by: Rafael Padilla + Last modification: May 24 2018 + """ + XYWH = 1 + XYX2Y2 = 2 + + +# size => (width, height) of the image +# box => (X1, X2, Y1, Y2) of the bounding box +def convertToRelativeValues(size, box): + dw = 1. / (size[0]) + dh = 1. / (size[1]) + cx = (box[1] + box[0]) / 2.0 + cy = (box[3] + box[2]) / 2.0 + w = box[1] - box[0] + h = box[3] - box[2] + x = cx * dw + y = cy * dh + w = w * dw + h = h * dh + # x,y => (bounding_box_center)/width_of_the_image + # w => bounding_box_width / width_of_the_image + # h => bounding_box_height / height_of_the_image + return (x, y, w, h) + + +# size => (width, height) of the image +# box => (centerX, centerY, w, h) of the bounding box relative to the image +def convertToAbsoluteValues(size, box): + # w_box = round(size[0] * box[2]) + # h_box = round(size[1] * box[3]) + xIn = round(((2 * float(box[0]) - float(box[2])) * size[0] / 2)) + yIn = round(((2 * float(box[1]) - float(box[3])) * size[1] / 2)) + xEnd = xIn + round(float(box[2]) * size[0]) + yEnd = yIn + round(float(box[3]) * size[1]) + if xIn < 0: + xIn = 0 + if yIn < 0: + yIn = 0 + if xEnd >= size[0]: + xEnd = size[0] - 1 + if yEnd >= size[1]: + yEnd = size[1] - 1 + return (xIn, yIn, xEnd, yEnd) + + +def add_bb_into_image(image, bb, color=(255, 0, 0), thickness=2, 
label=None): + r = int(color[0]) + g = int(color[1]) + b = int(color[2]) + + font = cv2.FONT_HERSHEY_SIMPLEX + fontScale = 0.5 + fontThickness = 1 + + x1, y1, x2, y2 = bb.getAbsoluteBoundingBox(BBFormat.XYX2Y2) + x1 = int(x1) + y1 = int(y1) + x2 = int(x2) + y2 = int(y2) + cv2.rectangle(image, (x1, y1), (x2, y2), (b, g, r), thickness) + # Add label + if label is not None: + # Get size of the text box + (tw, th) = cv2.getTextSize(label, font, fontScale, fontThickness)[0] + # Top-left coord of the textbox + (xin_bb, yin_bb) = (x1 + thickness, y1 - th + int(12.5 * fontScale)) + # Checking position of the text top-left (outside or inside the bb) + if yin_bb - th <= 0: # if outside the image + yin_bb = y1 + th # put it inside the bb + r_Xin = x1 - int(thickness / 2) + r_Yin = y1 - th - int(thickness / 2) + # Draw filled rectangle to put the text in it + cv2.rectangle(image, (r_Xin, r_Yin - thickness), + (r_Xin + tw + thickness * 3, r_Yin + th + int(12.5 * fontScale)), (b, g, r), + -1) + cv2.putText(image, label, (xin_bb, yin_bb), font, fontScale, (0, 0, 0), fontThickness, + cv2.LINE_AA) + return image diff --git a/ScanSSD/README.md b/ScanSSD/README.md new file mode 100644 index 0000000000000000000000000000000000000000..548fd1656229bbc5531436d78c85336004db645e --- /dev/null +++ b/ScanSSD/README.md @@ -0,0 +1,11 @@ +# ScanSSD: Scanning Single Shot Detector for Math in Document Images + + +A [PyTorch](http://pytorch.org/) implementation of ScanSSD [Scanning Single Shot MultiBox Detector](https://paragmali.me/scanning-single-shot-detector-for-math-in-document-images/) by [**Parag Mali**](https://github.com/MaliParag/). It was developed using SSD implementation by [**Max deGroot**](https://github.com/amdegroot). + +All credit goes to the authors of the paper and the original implementation. + +--- + +I have made some changes to the original implementation to make it work with the latest version of PyTorch and Python. 
+I have also removed some unnecessary files, in particular the ones related to dataset. \ No newline at end of file diff --git a/ScanSSD/data/__init__.py b/ScanSSD/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4dea4d182b7f481b141ced1de9806ba2b58f1932 --- /dev/null +++ b/ScanSSD/data/__init__.py @@ -0,0 +1,47 @@ +#from .gtdb import GTDB_CLASSES, GTDB_ROOT, GTDBAnnotationTransform, GTDBDetection +from .gtdb_new import GTDB_CLASSES, GTDB_ROOT, GTDBAnnotationTransform, GTDBDetection +from .config import * +import torch +import cv2 +import numpy as np + +def detection_collate(batch): + """Custom collate fn for dealing with batches of images that have a different + number of associated object annotations (bounding boxes). + + Arguments: + batch: (tuple) A tuple of tensor images and lists of annotations + + Return: + A tuple containing: + 1) (tensor) batch of images stacked on their 0 dim + 2) (list of tensors) annotations for a given image are stacked on + 0 dim + """ + targets = [] + imgs = [] + ids = [] + + for sample in batch: + imgs.append(sample[0]) + targets.append(torch.FloatTensor(sample[1])) + ids.append(sample[2]) + + return torch.stack(imgs, 0), targets, ids + + +def base_transform(image, size, mean): + #print('Image size ', image.shape) + image = image.astype(np.float32) + x = cv2.resize(image, (size, size), interpolation=cv2.INTER_AREA).astype(np.float32) + x -= mean + return x + + +class BaseTransform: + def __init__(self, size, mean): + self.size = size + self.mean = np.array(mean, dtype=np.float32) + + def __call__(self, image, boxes=None, labels=None): + return base_transform(image, self.size, self.mean), boxes, labels diff --git a/ScanSSD/data/config.py b/ScanSSD/data/config.py new file mode 100644 index 0000000000000000000000000000000000000000..8a840a169a7221b074d2b103a9b5e4e384f0c869 --- /dev/null +++ b/ScanSSD/data/config.py @@ -0,0 +1,158 @@ +# config.py +import os.path + +# gets home dir cross platform +HOME 
= os.path.expanduser("~") + +# for making bounding boxes pretty +COLORS = ((255, 0, 0, 128), (0, 255, 0, 128), (0, 0, 255, 128), + (0, 255, 255, 128), (255, 0, 255, 128), (255, 255, 0, 128)) + +MEANS = (246, 246, 246) + +exp_cfg = { + + 'gtdb': { + 'num_classes': 2, + 'lr_steps': (80000, 100000, 120000), + + 'max_iter': 120000, + 'feature_maps': [64, 32, 16, 8, 4, 2, 1], + 'min_dim': 512, + 'steps': [8, 16, 32, 64, 128, 256, 512], + 'min_sizes': [8.00, 76.8, 153.6, 230.4, 307.2, 384.0, 460.8], + 'max_sizes': [76.8, 153.6, 230.4, 307.2, 384.0, 460.8, 537.6], + 'aspect_ratios': [[2, 3, 5], [2, 3, 5, 7], [2, 3, 5, 7], [2, 3], [2, 3], [2], [2]], + + 'variance': [0.1, 0.2], + 'clip': True, + 'name': 'GTDB', + + 'is_vertical_prior_boxes_enabled': True, + + 'mbox': { + '512': [8, 10, 10, 6, 6, 4, 4], + #'512': [5, 6, 6, 4, 4, 3, 3], + '300': [8, 10, 10, 6, 4, 4], # number of boxes per feature map location + }, + 'extras': { + '512': [256, 'S', 512, 128, 'S', 256, 128, 'S', 256, 128, 'S', 256], + '300': [256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256], + } + }, + + 'math_gtdb_512': { + + 'num_classes': 2, + 'lr_steps': (80000, 100000, 120000), + 'max_iter': 240000, + 'feature_maps': [64, 32, 16, 8, 4, 2, 1], + 'min_dim': 512, + 'steps': [8, 16, 32, 64, 128, 256, 512], + 'min_sizes': [8.00, 76.8, 153.6, 230.4, 307.2, 384.0, 460.8], + 'max_sizes': [76.8, 153.6, 230.4, 307.2, 384.0, 460.8, 537.6], + 'aspect_ratios': [[2, 3, 5, 7, 10], [2, 3, 5, 7, 10], [2, 3, 5, 7, 10], [2, 3, 5, 7, 10], + [2, 3, 5, 7, 10], [2, 3, 5, 7, 10], [2, 3, 5, 7, 10]], + 'variance': [0.1, 0.2], + 'clip': True, + 'name': 'math_gtdb_512', + 'is_vertical_prior_boxes_enabled': True, + 'mbox': { + '512': [12,12,12,12,12,12,12], + }, + 'extras': { + '512': [256, 'S', 512, 128, 'S', 256, 128, 'S', 256, 128, 'S', 256], + } + }, + + 'ssd300': { + 'num_classes': 2, + 'lr_steps': (80000, 100000, 120000), + 'max_iter': 132000, + 'feature_maps': [38, 19, 10, 5, 3, 1], + 'min_dim': 300, + 'steps': [8, 16, 32, 64, 
100, 300], + 'min_sizes': [30, 60, 111, 162, 213, 264], + 'max_sizes': [60, 111, 162, 213, 264, 315], + 'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]], + 'variance': [0.1, 0.2], + 'clip': True, + 'name': 'ssd300', + 'is_vertical_prior_boxes_enabled': True, + 'mbox': { + '300': [4, 6, 6, 6, 4, 4], # number of boxes per feature map location + }, + 'extras': { + '300': [256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256], + } + }, + + 'ssd512': { + 'num_classes': 2, + 'lr_steps': (80000, 100000, 120000), + 'max_iter': 132000, + 'feature_maps': [64, 32, 16, 8, 4, 2, 1], + 'min_dim': 512, + 'steps': [8, 16, 32, 64, 128, 256, 512], + 'min_sizes': [35.84, 76.8, 153.6, 230.4, 307.2, 384.0, 460.8], + 'max_sizes': [76.8, 153.6, 230.4, 307.2, 384.0, 460.8, 537.6], + 'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2,3], [2], [2]], + 'variance': [0.1, 0.2], + 'clip': True, + 'name': 'ssd512', + 'is_vertical_prior_boxes_enabled': True, + 'mbox': { + '512': [4,6,6,6,6,4,4], + }, + 'extras': { + '512': [256, 'S', 512, 128, 'S', 256, 128, 'S', 256, 128, 'S', 256], + } + }, + + 'aspect512': { + 'num_classes': 2, + 'lr_steps': (80000, 100000, 120000), + 'max_iter': 132000, + 'feature_maps': [64, 32, 16, 8, 4, 2, 1], + 'min_dim': 512, + 'steps': [8, 16, 32, 64, 128, 256, 512], + 'min_sizes': [35.84, 76.8, 153.6, 230.4, 307.2, 384.0, 460.8], + 'max_sizes': [76.8, 153.6, 230.4, 307.2, 384.0, 460.8, 537.6], + 'aspect_ratios': [[2, 3, 5, 7, 10], [2, 3, 5, 7, 10], [2, 3, 5, 7, 10], [2, 3, 5, 7, 10], + [2, 3, 5, 7, 10], [2, 3, 5, 7, 10], [2, 3, 5, 7, 10]], + 'variance': [0.1, 0.2], + 'clip': True, + 'name': 'ssd512', + 'is_vertical_prior_boxes_enabled': True, + 'mbox': { + '512': [12,12,12,12,12,12,12], + }, + 'extras': { + '512': [256, 'S', 512, 128, 'S', 256, 128, 'S', 256, 128, 'S', 256], + } + }, + + 'hboxes512': { + 'num_classes': 2, + 'lr_steps': (80000, 100000, 120000), + 'max_iter': 132000, + 'feature_maps': [64, 32, 16, 8, 4, 2, 1], + 'min_dim': 512, + 'steps': [8, 16, 32, 
64, 128, 256, 512], + 'min_sizes': [35.84, 76.8, 153.6, 230.4, 307.2, 384.0, 460.8], + 'max_sizes': [76.8, 153.6, 230.4, 307.2, 384.0, 460.8, 537.6], + 'aspect_ratios': [[2, 3, 5, 7, 10], [2, 3, 5, 7, 10], [2, 3, 5, 7, 10], [2, 3, 5, 7, 10], + [2, 3, 5, 7, 10], [2, 3, 5, 7, 10], [2, 3, 5, 7, 10]], + 'variance': [0.1, 0.2], + 'clip': True, + 'name': 'ssd512', + 'is_vertical_prior_boxes_enabled': False, + 'mbox': { + '512': [7,7,7,7,7,7,7], + }, + 'extras': { + '512': [256, 'S', 512, 128, 'S', 256, 128, 'S', 256, 128, 'S', 256], + } + }, + +} \ No newline at end of file diff --git a/ScanSSD/data/gtdb_new.py b/ScanSSD/data/gtdb_new.py new file mode 100644 index 0000000000000000000000000000000000000000..b85e45270805e72da1f5a9b94f78cee4922da3ef --- /dev/null +++ b/ScanSSD/data/gtdb_new.py @@ -0,0 +1,284 @@ +""" +Author: Parag Mali +Data reader for the GTDB dataset +Uses sliding windows to generate sub-images +""" + +from .config import HOME +import os.path as osp +import sys +import torch +import torch.utils.data as data +import cv2 +import numpy as np +from ..gtdb import box_utils +from ..gtdb import feature_extractor +import copy +from ..utils import visualize as visualize + + +GTDB_CLASSES = ( # always index 0 is background + 'math') + +GTDB_ROOT = osp.join(HOME, "data/GTDB/") + + +class GTDBAnnotationTransform(object): + """Transforms a GTDB annotation into a Tensor of bbox coords and label index + Initilized with a dictionary lookup of classnames to indexes + + Arguments: + class_to_ind (dict, optional): dictionary lookup of classnames -> indexes + height (int): height + width (int): width + """ + + def __init__(self, class_to_ind=None): + pass + + def __call__(self, target, width, height): + """ + Arguments: + target (annotation) : the target annotations. 
This will be the list of bounding boxes + Returns: + a list containing lists of bounding boxes [bbox coords, class name] + """ + res = [] + + # read the annotations + for box in target: + res.append([box[0]/width, box[1]/height, box[2]/width, box[3]/height, 0]) + + return res # [[xmin, ymin, xmax, ymax, label_ind], ... ] + + +class GTDBDetection(data.Dataset): + """GTDB Detection Dataset Object + + input is image, target is annotation + + Arguments: + root (string): filepath to GTDB folder. + image_set (string): imageset to use (eg. 'train', 'val', 'test') + transform (callable, optional): transformation to perform on the + input image + target_transform (callable, optional): transformation to perform on the + target `annotation` + (eg: take in caption string, return tensor of word indices) + dataset_name: `GTDB` + """ + + + def __init__(self, args, data_file, split='train', + transform=None, target_transform=GTDBAnnotationTransform(), + dataset_name='GTDB'): + + #split can be train, validate or test + + self.root = args.dataset_root + self.image_set = data_file + self.split = split + self.transform = transform + self.target_transform = target_transform + self.name = dataset_name + self.use_char_info = args.use_char_info + + # The stride to use for the windowing + self.stride = args.stride # 0.1 + self.window = args.window # 1200 + + self._annopath = osp.join('%s', 'annotations', '%s.pmath') + self._imgpath = osp.join('%s', 'images', '%s.png') + self._char_annopath = osp.join('%s', 'char_annotations', '%s.pchar') + + self.ids = list() + + for line in open(osp.join(self.root, self.image_set)): + self.ids.append((self.root, line.strip())) + + # initialize the training images and annotations + self.images = {} + self.math_ground_truth = {} + self.is_math = {} + # for each id store the image patch which has math + # This can be used as a box util + self.metadata = [] + + self.read_all_images() + self.read_gt_annotations() + self.generate_metadata() + + def 
generate_metadata(self): + + for id in self.ids: + math_patches = [] + + height, width, channels = self.images[id[1]].shape + current_page_boxes = self.math_ground_truth[id[1]] + n_horizontal = np.ceil(width / self.window) # 4 + n_vertical = np.ceil(height / self.window) # 5 + + h = np.arange(0, n_horizontal - 1 + self.stride, self.stride) + v = np.arange(0, n_vertical - 1 + self.stride, self.stride) + crop_size = self.window + + if (self.split == 'train' or self.split == 'validate') and self.is_math[id[1]]: + + for i in h: + for j in v: + x_l = int(np.round(crop_size * i)) + x_h = x_l + self.window + + y_l = int(np.round(crop_size * j)) + y_h = y_l + self.window + + # left, top, right, bottom + image_box = [x_l, y_l, x_h, y_h] + + current_page_boxes = copy.deepcopy(self.math_ground_truth[id[1]]) + + # if math intersects only consider the region which + # is part of the current bounding box + for box in current_page_boxes: + if box_utils.intersects(image_box, box): + # left, top, right, bottom + # y increases downwards + + # crop the boxes to fit into image region + box[0] = max(x_l, box[0]) + box[1] = max(y_l, box[1]) + box[2] = min(x_h, box[2]) + box[3] = min(y_h, box[3]) + + # # Translate to origin + box[0] = box[0] - x_l + box[2] = box[2] - x_l + + box[1] = box[1] - y_l + box[3] = box[3] - y_l + + if feature_extractor.width(box) > 0 and feature_extractor.height(box) > 0: + self.metadata.append([id[1], x_l, y_l]) + break + elif self.split=='test': + for i in h: + for j in v: + x_l = int(np.round(crop_size * i)) + y_l = int(np.round(crop_size * j)) + self.metadata.append([id[1], x_l, y_l]) + + def read_all_images(self): + + # This function reads all the images in the training dataset. 
+ # GTDB dataset is small, loading all the images at once is not a problem + + for id in self.ids: + image = cv2.imread(self._imgpath % id, cv2.IMREAD_COLOR) + self.images[id[1]] = image + + def read_gt_annotations(self): + + # This function reads all the annotations for the training images + for id in self.ids: + if osp.exists(self._annopath % id): + gt_regions = np.genfromtxt(self._annopath % id, delimiter=',') + gt_regions = gt_regions.astype(int) + + # if there is only one entry convert it to correct form required + if len(gt_regions.shape) == 1: + gt_regions = gt_regions.reshape(1, -1) + + self.math_ground_truth[id[1]] = gt_regions + self.is_math[id[1]] = True + else: + self.math_ground_truth[id[1]] = np.array([-1,-1,-1,-1]).reshape(1,-1) + self.is_math[id[1]] = False + + def __getitem__(self, index): + im, gt, metadata = self.pull_item(index) + return im, gt, metadata + + def __len__(self): + return len(self.metadata) + + def gen_targets(self, index): + + metadata = self.metadata[index] + + x_l = metadata[1] + y_l = metadata[2] + x_h = x_l + self.window + y_h = y_l + self.window + + current_page_boxes = copy.deepcopy(self.math_ground_truth[metadata[0]]) + targets = [] + + image_box = [x_l, y_l, x_h, y_h] + + # if math intersects only consider the region which + # is part of the current bounding box + for box in current_page_boxes: + if box_utils.intersects(image_box, box): + # left, top, right, bottom + # y increases downwards + + # crop the boxes to fit into image region + box[0] = max(x_l, box[0]) + box[1] = max(y_l, box[1]) + box[2] = min(x_h, box[2]) + box[3] = min(y_h, box[3]) + + # # Translate to origin + box[0] = box[0] - x_l + box[2] = box[2] - x_l + + box[1] = box[1] - y_l + box[3] = box[3] - y_l + + if feature_extractor.width(box) > 0 and feature_extractor.height(box) > 0: + targets.append(box) + + # It is done only for testing, where we do not care about targets + # This avoids IndexError: too many indices for array + # TODO: refactor in future + 
if len(targets) == 0: + targets = [[-1,-1,-1,-1]] + + return targets + + def gen_image(self, index): + + metadata = self.metadata[index] + image = self.images[metadata[0]] + + x_l = metadata[1] + y_l = metadata[2] + x_h = x_l + min(self.window, image.shape[1]-x_l) + y_h = y_l + min(self.window, image.shape[0]-y_l) + + cropped_image = np.full((self.window, self.window, image.shape[2]), 255) + cropped_image[:y_h-y_l, :x_h-x_l, :] = image[y_l: y_h, x_l: x_h, :] + + return cropped_image + + def pull_item(self, index): + + metadata = self.metadata[index] + + target = self.gen_targets(index) + img = self.gen_image(index) + + height, width, channels = img.shape + + if self.target_transform is not None: + target = self.target_transform(target, width, height) + + if self.transform is not None: + target = np.array(target) + img, boxes, labels = self.transform(img, target[:, :4], target[:, 4]) + # to rgb + img = img[:, :, (2, 1, 0)] + target = np.hstack((boxes, np.expand_dims(labels, axis=1))) + + + return torch.from_numpy(img).permute(2, 0, 1), target, metadata diff --git a/ScanSSD/detect_flow.py b/ScanSSD/detect_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..825a75a08a148b8863a4889a44d599844135ae18 --- /dev/null +++ b/ScanSSD/detect_flow.py @@ -0,0 +1,118 @@ +# Fix relative import issues +from collections import OrderedDict +import cv2 +import numpy as np +import torch +from .ssd import build_ssd +from .data import config + +def _img_to_tensor (image): + rimg = cv2.resize(image, (512, 512), + interpolation = cv2.INTER_AREA).astype(np.float32) + rimg -= np.array((246, 246, 246), dtype=np.float32) + rimg = rimg[:, :, (2, 1, 0)] + return torch.from_numpy(rimg).permute(2, 0, 1) + + +def FixImgCoordinates (images, boxes): + new_boxes = [] + if isinstance(images, list): + for i in range(len(images)): + # print(images[i].shape) + bbs = [] + for o_box in boxes[i] : + b = [None] * 4 + b[0] = int(o_box[0] * images[i].shape[0]) + b[1] = int(o_box[1] * 
images[i].shape[1]) + b[2] = int(o_box[2] * images[i].shape[0]) + b[3] = int(o_box[3] * images[i].shape[1]) + bbs.append(b) + + new_boxes.append(bbs) + else: + bbs = [] + for o_box in boxes[0] : + b = [None] * 4 + b[0] = int(o_box[0] * images.shape[0]) + b[1] = int(o_box[1] * images.shape[1]) + b[2] = int(o_box[2] * images.shape[0]) + b[3] = int(o_box[3] * images.shape[1]) + bbs.append(b) + + # this could be + # b[0] = int(o_box[0] * images.shape[0]) ==> b[0] = int(o_box[0] * images.shape[1]) + # b[1] = int(o_box[1] * images.shape[1]) ==> b[1] = int(o_box[1] * images.shape[0]) + # b[2] = int(o_box[2] * images.shape[0]) ==> b[2] = int(o_box[2] * images.shape[1]) + # b[3] = int(o_box[3] * images.shape[1]) ==> b[3] = int(o_box[3] * images.shape[0]) + + new_boxes.append(bbs) + + return new_boxes + + +class MathDetector(): + + def __init__(self, weight_path, args): + self._args = args + net = build_ssd(args, 'test', config.exp_cfg[args.cfg], 0, args.model_type, num_classes = 2) + self._net = net # nn.DataParallel(net) + weights = torch.load(weight_path, map_location = torch.device('cpu')) + + new_weights = OrderedDict() + for k, v in weights.items(): + name = k[7:] # remove `module.` + new_weights[name] = v + + self._net.load_state_dict(new_weights) + + if args.cuda and torch.cuda.is_available(): + self._net = self._net.cuda() + + self._net.eval() + + @torch.no_grad() + def Detect (self, thres, images): + cls = 1 # math class + boxes = [] + scores = [] + + if self._args.cuda and torch.cuda.is_available(): + images = images.cuda() + y, debug_boxes, debug_scores = self._net(images) # forward pass + + y, debug_boxes, debug_scores = y.cpu(), debug_boxes.cpu(), debug_scores.cpu() + detections = y.data + + for k in range(len(images)): + + img_boxes = [] + img_scores = [] + for j in range(detections.size(2)): + + if ( detections[k, cls, j, 0] < thres ): + continue + + pt = detections[k, cls, j, 1:] + coords = (pt[0], pt[1], pt[2], pt[3]) + img_boxes.append(coords) + 
img_scores.append(detections[k, cls, j, 0]) + + boxes.append(img_boxes) + scores.append(img_scores) + + return boxes, scores + + def ShowNetwork (self): + print(self._net) + + @torch.no_grad() + def DetectAny(self, image, thres): + if isinstance(image, list): + t_list = [_img_to_tensor(img) for img in image] + t = torch.stack(t_list, dim = 0) + else: + t = _img_to_tensor(image).unsqueeze(0) + # fix box coordinates to image pixel coordinates + boxes, scores = self.Detect(thres, t) + return FixImgCoordinates(image, boxes), scores + diff --git a/ScanSSD/gtdb/__init__.py b/ScanSSD/gtdb/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ScanSSD/gtdb/adjust_boxes.py b/ScanSSD/gtdb/adjust_boxes.py new file mode 100644 index 0000000000000000000000000000000000000000..067b524f90fd33a1018ff7b4c5d27eacef19dfb7 --- /dev/null +++ b/ScanSSD/gtdb/adjust_boxes.py @@ -0,0 +1,220 @@ +# Author: Parag Mali +# This script stitches back the output generated on the image patches (sub-images) + +# read the image +import sys +sys.path.extend(['/home/psm2208/code', '/home/psm2208/code']) +import cv2 +import os +import csv +import numpy as np +import utils.visualize as visualize +from multiprocessing import Pool +from cv2.dnn import NMSBoxes +from scipy.ndimage.measurements import label +import scipy.ndimage as ndimage +import copy +from gtdb import fit_box +from gtdb import box_utils +from gtdb import feature_extractor +import shutil +import time +from collections import OrderedDict +from collections import deque +import argparse + +def parse_args(): + ''' + Parameters + ''' + parser = argparse.ArgumentParser( + description='Stitching method') + + parser.add_argument('--data_file', default='test', + type=str, help='choose one') + parser.add_argument('--output_dir', default='.', + help='Output directory path') + parser.add_argument('--math_dir', required=True, + type=str, help='detections dir') + 
parser.add_argument('--math_ext', default='.csv', + help='Extention of detection files') + parser.add_argument('--home_data', default='/home/psm2208/data/GTDB/', type = str, + help='database dir') + parser.add_argument('--home_eval', default='/home/psm2208/code/eval/', type = str, + help='Eval dir') + parser.add_argument('--home_images', default='/home/psm2208/data/GTDB/images/', type = str, + help='Images dir') + parser.add_argument('--home_anno', default='/home/psm2208/data/GTDB/annotations/', type = str, + help='Annotations dir') + parser.add_argument('--home_char', default='/home/psm2208/data/GTDB/char_annotations/', type = str, + help='Char anno dir') + parser.add_argument('--num_workers', default=4, type=int, help='Number of workers') + parser.add_argument('--type', default='math', type=str, help='Math or text') + + return parser.parse_args() + +def read_math(args, pdf_name): + ''' + Read math bounding boxes for given PDF + ''' + math_file = os.path.join(args.math_dir, pdf_name + args.math_ext) + data = np.array([]) + + if os.path.exists(math_file): + data = np.genfromtxt(math_file, delimiter=',') + + # if there is only one entry convert it to correct form required + if len(data.shape) == 1: + data = data.reshape(1, -1) + + if args.math_ext == '.char': + data = np.delete(data,1,1) + data = data[:,:5] + + return data.astype(int) + +def read_char(args, pdf_name): + ''' + Read char bounding boxes for given PDF + ''' + data = [] + + path = os.path.join(args.home_char, pdf_name + ".char") + + with open(path, 'r') as csvfile: + reader = csv.reader(csvfile, delimiter=',') + for row in reader: + #print('row is ' + str(row[1])) + # if entry is not in map + data.append(row) + + return np.array(data) + + +def adjust(params): + ''' + Fit the bounding boxes to the characters + ''' + args, math_regions, pdf_name, page_num = params + print('Processing ', pdf_name, ' > ', page_num) + + image = cv2.imread(os.path.join(args.home_images, pdf_name, str(int(page_num + 1)) + 
".png")) + im_bw = fit_box.convert_to_binary(image) + + new_math = [] + for math in math_regions: + box = fit_box.adjust_box(im_bw, math) + + if feature_extractor.width(box) > 0 and feature_extractor.height(box) > 0: + new_math.append(box) + + return new_math + +def adjust_char(params): + ''' + Adjust the character bounding boxes + ''' + try: + args, char_regions, pdf_name, page_num = params + print('Char processing ', pdf_name, ' > ', page_num) + + image = cv2.imread(os.path.join(args.home_images, pdf_name, str(int(page_num) + 1) + ".png")) + im_bw = fit_box.convert_to_binary(image) + + new_chars = [] + + for char in char_regions: + + bb_char = [char[2],char[3],char[4],char[5]] + bb_char = [int(float(k)) for k in bb_char] + box = fit_box.adjust_box(im_bw, bb_char) + if feature_extractor.width(box) > 0 and feature_extractor.height(box) > 0: + char[1] = box[0] + char[2] = box[1] + char[3] = box[2] + char[4] = box[3] + new_chars.append(char) + + return new_chars + except Exception as e: + print('Error while processing ', pdf_name, ' > ', page_num, sys.exc_info()) + return [] + +def adjust_boxes(args): + ''' + Driving function for adjusting the boxes + ''' + + pdf_list = [] + pdf_names_file = open(args.data_file, 'r') + + for pdf_name in pdf_names_file: + pdf_name = pdf_name.strip() + + if pdf_name != '': + pdf_list.append(pdf_name) + + regions = {} + + for pdf_name in pdf_list: + if args.type == 'char': + regions[pdf_name] = read_char(args, pdf_name) + else: + regions[pdf_name] = read_math(args, pdf_name) + + voting_ip_list = [] + + for pdf_name in pdf_list: + + pages = np.unique(regions[pdf_name][:, 0]) + + #args, math_regions, pdf_name, page_num + for page_num in pages: + + current_math = regions[pdf_name][regions[pdf_name][:,0] == page_num] + + if args.type == 'math': + current_math = np.delete(current_math, 0, 1) + + voting_ip_list.append([args, current_math, pdf_name, page_num]) + + pool = Pool(processes=args.num_workers) + + if args.type == 'math': + out = 
pool.map(adjust, voting_ip_list) + else: + out = pool.map(adjust_char, voting_ip_list) + + for ip, final_math in zip(voting_ip_list, out): + pdf_name = ip[2] + page_num = ip[3] + + if args.type == 'math': + col = np.array([int(page_num)] * len(final_math)) + final_math = np.concatenate((col[:, np.newaxis], final_math), axis=1) + + math_file_path = os.path.join(args.output_dir, pdf_name + '.csv') + + if not os.path.exists(os.path.dirname(math_file_path)): + os.makedirs(os.path.dirname(math_file_path)) + + + if args.type == 'math': + math_file = open(math_file_path, 'a') + + np.savetxt(math_file, final_math, fmt='%.2f', delimiter=',') + math_file.close() + + else: + + with open(math_file_path, 'a') as csvfile: + writer = csv.writer(csvfile, delimiter=",") + + for math_region in final_math: + writer.writerow(math_region) + + +if __name__ == '__main__': + + args = parse_args() + adjust_boxes(args) diff --git a/ScanSSD/gtdb/box_utils.py b/ScanSSD/gtdb/box_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2f44d4967d4545e49acea88261f4c23c21295727 --- /dev/null +++ b/ScanSSD/gtdb/box_utils.py @@ -0,0 +1,61 @@ +# Author: Parag Mali +# This file contains many common operations on the bounding boxes like finding intersection +from collections import OrderedDict +from . 
import feature_extractor + +def check_inside(rectA, rectB): + + # returns True if A is inside B + if rectA[0] >= rectB[0] and\ + rectA[1] >= rectB[1] and\ + rectA[2] <= rectB[2] and\ + rectA[3] <= rectB[3]: + + return True + + return False + +# check if two rectangles intersect +def intersects(first, other): + return not (first[2] < other[0] or + first[0] > other[2] or + first[1] > other[3] or + first[3] < other[1]) + +def find_intersecting_boxes(math_regions, sorted=True): + # specific for fusion algo + # returns indices of intersecting math region for each math region + inter_map = {} + + for i in range(len(math_regions)): + inter_map[i] = [] + + for i in range(len(math_regions)): + for j in range(len(math_regions[i+1:])): + if intersects(math_regions[i], math_regions[i+1+j]): + inter_map[i].append(i+1+j) + #inter_map[i+1+j].append(i) + + inter_map = OrderedDict(inter_map) + + return inter_map + +def merge(box1, box2): + + final_box = [0,0,0,0] + final_box[0] = min(box1[0], box2[0]) # left + final_box[1] = min(box1[1], box2[1]) # top + final_box[2] = max(box1[2], box2[2]) # left + width + final_box[3] = max(box1[3], box2[3]) # top + height + + return final_box + + +if __name__ == '__main__': + box1 = [849.00,3797.00,1403.00,3890.00] + box2 = [1169.00,3804.00,1392.00,3886.00] + + print(intersects(box1, box2)) + print(gtdb.feature_extractor.inclusion(box1, box2)) + print(gtdb.feature_extractor.inclusion(box2, box1)) + print(gtdb.feature_extractor.iou(box2, box1)) diff --git a/ScanSSD/gtdb/calculate_means.py b/ScanSSD/gtdb/calculate_means.py new file mode 100644 index 0000000000000000000000000000000000000000..e02bf9b410afac09dfcec9acef6ad09b5e91a86c --- /dev/null +++ b/ScanSSD/gtdb/calculate_means.py @@ -0,0 +1,78 @@ +""" +Author: Parag Mali +Calculate the image per channel mean and standard +deviation in the training set, do not calculate the statistics on the +whole dataset, as per here http://cs231n.github.io/neural-networks-2/#datapre +""" + +import numpy as np 
+from os import listdir +from os.path import join, isdir +from glob import glob +import cv2 +import timeit +import sys +from multiprocessing import Pool +sys.path.extend(['/home/psm2208/code', '/home/psm2208/code']) + +# number of channels of the dataset image, 3 for color jpg, 1 for grayscale img +# you need to change it to reflect your dataset +CHANNEL_NUM = 3 + + +# The script assumes that under train_root, there are separate directories for each class +# of training images. +root = "/home/psm2208/data/GTDB/GTDB1/" +start = timeit.default_timer() + +def task(path): + + pixel_num = 0 # store all pixel number in the dataset + channel_sum = np.zeros(CHANNEL_NUM) + channel_sum_squared = np.zeros(CHANNEL_NUM) + + print('processing image ' + str(path)) + im = cv2.imread(path) # image in M*N*CHANNEL_NUM shape, channel in BGR order + im = im / 255.0 + pixel_num += (im.size / CHANNEL_NUM) + channel_sum += np.sum(im, axis=(0, 1)) + channel_sum_squared += np.sum(np.square(im), axis=(0, 1)) + + return (pixel_num, channel_sum, channel_sum_squared) + +def cal_dir_stat(root): + cls_dirs = [d for d in listdir(root) if isdir(join(root, d))] + + paths = [] + + for idx, d in enumerate(cls_dirs): + im_pths = glob(join(root, d, "*.png")) + paths.extend(im_pths) + + pool = Pool(processes=32) + ans = pool.map(task, paths) + pool.close() + pool.join() + + total_pixel_num = 0 + total_channel_sum = 0 + total_sum_sq = 0 + + for pixel_num, channel_sum, channel_sum_squared in ans: + total_channel_sum += channel_sum + total_pixel_num += pixel_num + total_sum_sq += channel_sum_squared + + bgr_mean = total_channel_sum / total_pixel_num + bgr_std = np.sqrt(total_sum_sq / total_pixel_num - np.square(bgr_mean)) + + # change the format from bgr to rgb + rgb_mean = list(bgr_mean)[::-1] + rgb_std = list(bgr_std)[::-1] + + return rgb_mean, rgb_std + +mean, std = cal_dir_stat(root) +end = timeit.default_timer() +print("elapsed time: {}".format(end - start)) +print("mean:{}\nstd:{}".format(mean, std)) 
diff --git a/ScanSSD/gtdb/create_dataset.py b/ScanSSD/gtdb/create_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..caeebeead9b59dc25d63fb1decc9839a627fef9c --- /dev/null +++ b/ScanSSD/gtdb/create_dataset.py @@ -0,0 +1,127 @@ +# Author: Parag Mali +# This script reads ground truths and normalizes them using image size + +# read the image +import sys +sys.path.extend(['/home/psm2208/code', '/home/psm2208/code']) +import cv2 +import os +import numpy as np +from multiprocessing import Pool +from gtdb import fit_box +from gtdb import feature_extractor +import argparse + +# Default parameters for thr GTDB dataset +def parse_args(): + + parser = argparse.ArgumentParser( + description='Stitching method') + + parser.add_argument('--data_file', default='test', + type=str, help='choose one') + parser.add_argument('--output_dir', default='.', + help='Output directory path') + parser.add_argument('--math_dir', required=True, + type=str, help='detections dir') + parser.add_argument('--math_ext', default='.csv', + help='Extention of detection files') + parser.add_argument('--home_data', default='/home/psm2208/data/GTDB/', type = str, + help='database dir') + parser.add_argument('--home_eval', default='/home/psm2208/code/eval/', type = str, + help='Eval dir') + parser.add_argument('--home_images', default='/home/psm2208/data/GTDB/images/', type = str, + help='Images dir') + parser.add_argument('--home_anno', default='/home/psm2208/data/GTDB/annotations/', type = str, + help='Annotations dir') + parser.add_argument('--home_char', default='/home/psm2208/data/GTDB/char_annotations/', type = str, + help='Char anno dir') + parser.add_argument('--num_workers', default=4, type=int, help='Number of workers') + + return parser.parse_args() + +def read_math(args, pdf_name): + + math_file = os.path.join(args.math_dir, pdf_name + args.math_ext) + data = np.array([]) + + if os.path.exists(math_file): + data = np.genfromtxt(math_file, delimiter=',') + + # if there 
is only one entry convert it to correct form required + if len(data.shape) == 1: + data = data.reshape(1, -1) + + if args.math_ext == '.char': + data = np.delete(data,1,1) + data = data[:,:5] + + return data.astype(int) + +def normalize(params): + + args, math_regions, pdf_name, page_num = params + print('Processing ', pdf_name, ' > ', page_num) + + image = cv2.imread(os.path.join(args.home_images, pdf_name, str(int(page_num + 1)) + ".png")) + im_bw = fit_box.convert_to_binary(image) + + new_math = [] + for math in math_regions: + box = [math[0]/im_bw.shape[1], math[1]/im_bw.shape[0], + math[2]/im_bw.shape[1], math[3]/im_bw.shape[0]]#fit_box.adjust_box(im_bw, math) + + if feature_extractor.width(box) > 0 and feature_extractor.height(box) > 0: + new_math.append(box) + + return new_math + +def normalize_boxes(args): + pdf_list = [] + pdf_names_file = open(args.data_file, 'r') + + for pdf_name in pdf_names_file: + pdf_name = pdf_name.strip() + + if pdf_name != '': + pdf_list.append(pdf_name) + + math_regions = {} + + for pdf_name in pdf_list: + math_regions[pdf_name] = read_math(args, pdf_name) + + voting_ip_list = [] + for pdf_name in pdf_list: + + pages = np.unique(math_regions[pdf_name][:, 0]) + + #args, math_regions, pdf_name, page_num + for page_num in pages: + current_math = math_regions[pdf_name][math_regions[pdf_name][:,0] == page_num] + voting_ip_list.append([args, np.delete(current_math, 0, 1), pdf_name, page_num]) + + pool = Pool(processes=args.num_workers) + out = pool.map(normalize, voting_ip_list) + + for ip, final_math in zip(voting_ip_list, out): + pdf_name = ip[2] + page_num = ip[3] + + col = np.array([int(page_num)] * len(final_math)) + final_math = np.concatenate((col[:, np.newaxis], final_math), axis=1) + + math_file_path = os.path.join(args.output_dir, pdf_name + '.csv') + + if not os.path.exists(os.path.dirname(math_file_path)): + os.makedirs(os.path.dirname(math_file_path)) + + math_file = open(math_file_path, 'a') + + np.savetxt(math_file, 
final_math, fmt='%.2f', delimiter=',') + math_file.close() + +if __name__ == '__main__': + + args = parse_args() + normalize_boxes(args) diff --git a/ScanSSD/gtdb/create_gt_math.py b/ScanSSD/gtdb/create_gt_math.py new file mode 100644 index 0000000000000000000000000000000000000000..a5be58106801c840a5f95d27db73cb65a97192d2 --- /dev/null +++ b/ScanSSD/gtdb/create_gt_math.py @@ -0,0 +1,273 @@ +# Author: Parag Mali +# This script reads ground truth to find the Symbol Layout Tree (SLT) bounding boxes + +import sys +sys.path.extend(['/home/psm2208/code', '/home/psm2208/code']) +import cv2 +import os +import csv +import numpy as np +from multiprocessing import Pool +import shutil + +def find_math(args): + + try: + pdf_name, image_file, char_file, page_num, output_file = args + + char_info = {} + char_map = {} + + image = cv2.imread(image_file) + + with open(char_file) as csvfile: + char_reader = csv.reader(csvfile, delimiter=',') + for row in char_reader: + char_info[row[1]] = row[2:] + + if row[-3] != 'NONE': + if row[1] not in char_map: + char_map[row[1]] = set() + + char_map[row[1]].add(row[-2]) + + if row[-2] not in char_map: + char_map[row[-2]] = set() + + char_map[row[-2]].add(row[1]) + + elif row[-4] == 'MATH_SYMBOL': + if row[1] not in char_map: + char_map[row[1]] = set() + + math_regions_chars = group_math(char_map) + math_regions = create_bb(math_regions_chars, char_info) + + multi_char_math = set({x for v in math_regions_chars for x in v}) + + os.makedirs(os.path.dirname(output_file), exist_ok=True) + writer = csv.writer(open(output_file,"a"), delimiter=",") + + + # with open(char_file) as csvfile: + # char_reader = csv.reader(csvfile, delimiter=',') + # + # for row in char_reader: + # if row[-1] in math_ocr and row[0] not in multi_char_math: + # math_regions.append([row[2],row[3],row[4],row[5]]) + + #math_regions = adjust_all(image, math_regions) + + for math_region in math_regions: + math_region.insert(0, int(page_num) - 1) + writer.writerow(math_region) + + 
print("Saved ", output_file, " > ", page_num, " math ->", len(math_regions)) + except: + print("Exception while processing ", pdf_name, " ", page_num, " ", sys.exc_info()) + + +def create_bb(math_regions_chars, char_info): + + math_regions = [] + + for region in math_regions_chars: + box = [] + + count = 0 + for char_id in region: + + if len(box) == 0: + box = [float(char_info[char_id][0]),float(char_info[char_id][1]), + float(char_info[char_id][2]), float(char_info[char_id][3])] + else: + box[0] = min(float(char_info[char_id][0]), box[0]) # left + box[1] = min(float(char_info[char_id][1]), box[1]) # top + box[2] = max(float(char_info[char_id][2]), box[2]) # left + width + box[3] = max(float(char_info[char_id][3]), box[3]) # top + height + + count = count + 1 + + box.append(count) + math_regions.append(box) + + return math_regions + +def group_math(char_map): + + visited = set() + regions = [] + + for key in char_map: + if key not in visited: + region = dfs(char_map, key) + regions.append(region) + + for k in region: + visited.add(k) + + return regions + + +def dfs(graph, start): + visited, stack = set(), [start] + while stack: + vertex = stack.pop() + if vertex not in visited: + visited.add(vertex) + stack.extend(graph[vertex] - visited) + return visited + + +def adjust_box(args): + im_bw, box = args + box = contract(im_bw, box) + box = expand(im_bw, box) + return box + +def contract(im_bw, box): + + # find first row with one pixel + rows_with_pixels = np.any(im_bw[box[1]:box[3], box[0]:box[2]], axis=1) + cols_with_pixels = np.any(im_bw[box[1]:box[3], box[0]:box[2]], axis=0) + + if len(rows_with_pixels==True) == 0 or len(cols_with_pixels==True) == 0: + box = [0,0,0,0,0] + return box + + left = box[0] + np.argmax(cols_with_pixels==True) + top = box[1] + np.argmax(rows_with_pixels==True) + right = box[0] + len(cols_with_pixels) - np.argmax(cols_with_pixels[::-1]==True) - 1 + bottom = box[1] + len(rows_with_pixels) - np.argmax(rows_with_pixels[::-1]==True) - 1 + + 
box[0] = left + box[1] = top + box[2] = right + box[3] = bottom + + return box + + # find first column with one pixel + # find last row with one pixel + # find last col with pixel + +def expand(im_bw, box): + + im_copy = np.copy(im_bw) + im_copy[box[1]:box[3], box[0]:box[2]] = 1 + + start = (box[1], box[0]) + queue = [start] + visited = set() + + while len(queue) != 0: + front = queue.pop(0) + if front not in visited: + for adjacent_space in get_adjacent_spaces(im_copy, front, visited): + queue.append(adjacent_space) + + box[0] = min(front[1], box[0]) #left + box[1] = min(front[0], box[1]) #top + box[2] = max(front[1], box[2]) # left + width + box[3] = max(front[0], box[3]) # top + height + + visited.add(front) + + return box + +def get_adjacent_spaces(im_bw, space, visited): + + spaces = list() + dirs = [[1,0],[-1,0],[0,1],[0,-1]] + + for dir in dirs: + r = space[0] + dir[0] + c = space[1] + dir[1] + + if r < im_bw.shape[0] and c < im_bw.shape[1] and r >= 0 and c >= 0: + spaces.append((r, c)) + + final = list() + for i in spaces: + if im_bw[i[0]][i[1]] == 1 and i not in visited: + final.append(i) + + return final + +def convert_to_binary(image): + + gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + + im_bw = np.zeros(gray_image.shape) + im_bw[gray_image > 127] = 0 + im_bw[gray_image <= 127] = 1 + + return im_bw + +def adjust_all(image, boxes): + + im_bw = convert_to_binary(image) + adjusted = [] + + for box in boxes: + box = [int(box[0]), int(box[1]), int(box[2]), int(box[3])] + box = adjust_box((im_bw, box)) + adjusted.append(box) + + return adjusted + +def adjust_box(args): + im_bw, box = args + box = contract(im_bw, box) + box = expand(im_bw, box) + return box + +def create_gt_math(filename, image_dir, char_dir, output_dir="/home/psm2208/data/GTDB/annotationsV2/"): + + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + + if not os.path.exists(output_dir): + os.mkdir(output_dir) + + pages_list = [] + pdf_names = open(filename, 'r') + + for 
pdf_name in pdf_names: + pdf_name = pdf_name.strip() + + if pdf_name != '': + + for root, dirs, files in os.walk(os.path.join(char_dir, pdf_name)): + for name in files: + if name.endswith(".pchar"): + + page_num = os.path.splitext(name)[0] + + pages_list.append((pdf_name, + os.path.join(image_dir, + pdf_name, + page_num + ".png"), + os.path.join(root, name), + int(page_num), + os.path.join(output_dir, + pdf_name + ".csv"))) + #page_num + ".pmath"))) + pdf_names.close() + + pool = Pool(processes=32) + pool.map(find_math, pages_list) + pool.close() + pool.join() + +if __name__ == "__main__": + home_data = "/home/psm2208/data/GTDB/" + home_eval = "/home/psm2208/code/eval/" + home_images = "/home/psm2208/data/GTDB/images/" + home_anno = "/home/psm2208/data/GTDB/annotations/" + home_char = "/home/psm2208/data/GTDB/char_annotations/" + output_dir = "/home/psm2208/code/eval/tt_samsung_train/" + + type = sys.argv[1] + + create_gt_math(home_data + type, home_images, home_char, output_dir) diff --git a/ScanSSD/gtdb/create_segmentation_gt.py b/ScanSSD/gtdb/create_segmentation_gt.py new file mode 100644 index 0000000000000000000000000000000000000000..8c209dbe98b318a192115b0f4f07e99788fcb196 --- /dev/null +++ b/ScanSSD/gtdb/create_segmentation_gt.py @@ -0,0 +1,171 @@ +# Rectangles after projection + +import sys +sys.path.extend(['/home/psm2208/code', '/home/psm2208/code']) +import os +import csv +import numpy as np +from multiprocessing import Pool +import shutil +from gtdb import feature_extractor + +def intersects(first, other): + return not (first[2] < other[0] or + first[0] > other[2] or + first[1] > other[3] or + first[3] < other[1]) + + +def create_gt(args): + count = 0 + + try: + output_dir, pdf_name, page_num, gt_page_math, det_page_math = args + + + inside_gt_dict = {} + + # for each det i, find the gt with which it intersects + for i, det in enumerate(det_page_math): + + inside_gt_dict[i] = set() + + for j, gt in enumerate(gt_page_math): + if intersects(det, gt): + 
inside_gt_dict[i].add(j) + + check_dict = {} + for i, gt in enumerate(gt_page_math): + + check_dict[i] = set() + + for j, det in enumerate(det_page_math): + if check_inside(det, gt): + check_dict[i].add(j) + + for key in check_dict: + if len(check_dict[key]) > 1: + count = count + 1 + + segmentation_gt = [] + + for i, det_math1 in enumerate(det_page_math): + + min = float('inf') + min_idx = -1 + + x1 = det_math1[0] + ((det_math1[2] - det_math1[0]) / 2) + y1 = det_math1[1] + ((det_math1[3] - det_math1[1]) / 2) + + for j, det_math in enumerate(det_page_math): + if i != j: + + x2 = det_math[0] + ((det_math[2] - det_math[0]) / 2) + y2 = det_math[1] + ((det_math[3] - det_math[1]) / 2) + + c_dist = np.sqrt((y2 - y1) * (y2 - y1) + (x2 - x1) * (x2 - x1))#feature_extractor.vertical_dist_bb(det_page_math[i], det_page_math[j]) + + if c_dist < min: + min = c_dist + min_idx = j + + if len(inside_gt_dict[i].intersection(inside_gt_dict[min_idx])) > 0: + # positive example + segmentation_gt.append( + feature_extractor.extract_features(det_page_math[i], det_page_math[min_idx], 1)) + else: + #negative example + segmentation_gt.append( + feature_extractor.extract_features(det_page_math[i], det_page_math[min_idx], 0)) + + output_file = os.path.join(output_dir, "gt.csv") + writer = csv.writer(open(output_file,"a"), delimiter=",") + + for gt_row in segmentation_gt: + writer.writerow(gt_row) + + print('Processed ', pdf_name, ' ', page_num) + + except: + print("Exception while processing ", pdf_name, " ", page_num, " ", sys.exc_info()) + + return count + +def create_gt_segmentation(filename, gt_math_dir, det_math_dir, output_dir): + + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + + if not os.path.exists(output_dir): + os.mkdir(output_dir) + + pages_list = [] + pdf_names = open(filename, 'r') + + for pdf_name in pdf_names: + print('Processing-1', pdf_name) + pdf_name = pdf_name.strip() + + if pdf_name != '': + gt_math_file = os.path.join(gt_math_dir, pdf_name + ".csv") + 
gt_math_regions = np.genfromtxt(gt_math_file, delimiter=',', dtype=int) + + det_math_file = os.path.join(det_math_dir, pdf_name + ".csv") + det_math_regions = np.genfromtxt(det_math_file, delimiter=',', dtype=int) + + pages = np.unique(gt_math_regions[:, 0]) + + for page_num in pages: + + gt_page_math = gt_math_regions[np.where(gt_math_regions[:,0]==page_num)] + gt_page_math = gt_page_math[:,1:] + + det_page_math = det_math_regions[np.where(det_math_regions[:, 0] == page_num)] + det_page_math = det_page_math[:, 1:] + + pages_list.append([output_dir, pdf_name, page_num, gt_page_math, det_page_math]) + + pdf_names.close() + + pool = Pool(processes=1) + result = pool.map(create_gt, pages_list) + pool.close() + pool.join() + print('Merged regions', np.sum(result)) + + +def check_inside(rectA, rectB): + + # returns True if A is inside B + #left, top, right, bottom + #If any of the sides from A are outside of B + if rectA[3] < rectB[1]: # if bottom of rectA is less than top of rectB + return False + if rectA[1] > rectB[3]: # if top of rectA is greater than bottom of rectB + return False + if rectA[2] < rectB[0]: # if right of rectA is less than left of rectB + return False + if rectA[0] > rectB[2]: # if left of rectangleA is greater than right of rectB + return False + + #If none of the sides from A are outside B + return True + +if __name__ == "__main__": + home_data = "/home/psm2208/data/GTDB/" + home_eval = "/home/psm2208/code/eval/" + home_images = "/home/psm2208/data/GTDB/images/" + home_anno = "/home/psm2208/data/GTDB/annotations/" + home_char = "/home/psm2208/data/GTDB/char_annotations/" + + output_dir = "/home/psm2208/code/eval/segmentation_gt/" + gt_math = "/home/psm2208/Workspace/Task3_Detection/Train/GT_math_csv/" + + det_math = "/home/psm2208/code/eval/Train3_Focal_10_25/equal_30.0" + + type = sys.argv[1] + + + #filename, gt_math_dir, det_math_dir, output_dir + create_gt_segmentation(home_data + type, gt_math, det_math, output_dir) + diff --git 
a/ScanSSD/gtdb/diagnose.py b/ScanSSD/gtdb/diagnose.py new file mode 100644 index 0000000000000000000000000000000000000000..c161a87b849eb94f1e39da704ba22f1dadaa879d --- /dev/null +++ b/ScanSSD/gtdb/diagnose.py @@ -0,0 +1,353 @@ +# Author: Parag Mali +# This file contains functions that calculate character level detection results + +import os +import sys +sys.path.extend(['/home/psm2208/code', '/home/psm2208/code']) + +import csv +from multiprocessing import Pool +from IOU_lib import IOUevaluater +import copy +from gtdb import box_utils + +# check if two rectangles intersect +def intersects(first, other): + return not (first[2] < other[0] or + first[0] > other[2] or + first[1] > other[3] or + first[3] < other[1]) + +def read_data(training_pdf_names, char_dir, gt_math_dir, det_math_dir): + + char_bbs = {} + args = [] + total_math_char = 0 + + for filename in training_pdf_names: + + path = os.path.join(char_dir, filename + ".csv") + print('Processing ' + path) + count = 0 + + #data = adjust_data[filename] + + map = {} + with open(path, 'r') as csvfile: + reader = csv.reader(csvfile, delimiter=',') + for row in reader: + #print('row is ' + str(row[1])) + # if entry is not in map + if str(int(float(row[0]))) not in map: + map[str(int(float(row[0])))] = [] + + if row[6] == 'MATH_SYMBOL': + total_math_char = total_math_char + 1 + + # row[2] = data[count][1] + # row[3] = data[count][2] + # row[4] = data[count][3] + # row[5] = data[count][4] + + map[str(int(float(row[0])))].append(row) + count = count + 1 + + char_bbs[filename] = map + + det_math_bbs = {} + + for filename in training_pdf_names: + + #path = os.path.join(math_dir, filename + ".math") + path = os.path.join(det_math_dir, filename + ".csv") + + map = {} + with open(path, 'r') as csvfile: + reader = csv.reader(csvfile, delimiter=',') + for row in reader: + # if entry is not in map + if str(int(float(row[0]))) not in map: + map[str(int(float(row[0])))] = [] + + + map[str(int(float(row[0])))].append(row) + + 
det_math_bbs[filename] = map + + gt_math_bbs = {} + + for filename in training_pdf_names: + + path = os.path.join(gt_math_dir, filename + ".csv") + + map = {} + with open(path, 'r') as csvfile: + reader = csv.reader(csvfile, delimiter=',') + for row in reader: + # if entry is not in map + if str(int(float(row[0]))) not in map: + map[str(int(float(row[0])))] = [] + + map[str(int(float(row[0])))].append(row) + + gt_math_bbs[filename] = map + + return training_pdf_names, total_math_char, gt_math_bbs, det_math_bbs, char_bbs + + +def char_level_eval(training_pdf_names, total_math_char, gt_math_bbs, det_math_bbs, char_bbs): + + args = [] + + for key in det_math_bbs: + for page in det_math_bbs[key]: + if page not in gt_math_bbs[key]: + gt_math = [] + else: + gt_math = gt_math_bbs[key][page] + args.append([key, det_math_bbs[key][page], char_bbs[key][page], gt_math]) + + pool = Pool(processes=16) + ans = pool.map(character_level_score, args) + pool.close() + pool.join() + + detected_math_char = 0 + detected_text_char = 0 + + for math, text in ans: + detected_math_char = detected_math_char + math + detected_text_char = detected_text_char + text + + print('detected math chars ', detected_math_char) + print('detected text chars ', detected_text_char) + print('total math chars ', total_math_char) + + recall = detected_math_char / total_math_char + precision = detected_math_char / (detected_math_char + detected_text_char) + + fscore = 2 * recall * precision / (recall + precision) + + print('Char Recall\t', recall) + print('Char Precision\t', precision) + print('Char F-Score\t', fscore) + + +def character_level_score(args): + + filename, det_math_bbs, char_bbs, gt_math_bbs = args + detected_math_char_count = 0 + text_char_count = 0 + + for char_info in char_bbs: + char_bb = [float(char_info[1]), float(char_info[2]), float(char_info[3]), float(char_info[4])] + + for current_math_bb in det_math_bbs: + + math_bb = [float(current_math_bb[1]),float(current_math_bb[2]), + 
float(current_math_bb[3]),float(current_math_bb[4])] + + if box_utils.check_inside(char_bb, math_bb): #TODO + + if char_info[6] == 'MATH_SYMBOL': + detected_math_char_count = detected_math_char_count + 1 + break + else: + text_char_count = text_char_count + 1 + + return detected_math_char_count, text_char_count + +def box_level_granular_eval(training_pdf_names, total_math_char, gt_math_dir, det_math_dir, + gt_math_bbs, det_math_bbs, char_bbs, test_gt_math_dir): + + _, _, detailed_detections = IOUevaluater.IOUeval(test_gt_math_dir, det_math_dir) + assign_chars_to_math_boxes(gt_math_bbs, char_bbs) + assign_chars_to_math_boxes(det_math_bbs, char_bbs) + + single_char_det = [0, 0] + multi_char_det = [0, 0] + + total_single_char_det = 0 + total_multi_char_det = 0 + + single_char_gt = 0 + multi_char_gt = 0 + + for filename in training_pdf_names: + current_det = detailed_detections[filename] + + for page in current_det[0]: + coarse = current_det[0][page] + fine = current_det[1][page] + + #DET for recall + for det in coarse: + if gt_math_bbs[filename][str(int(float(page)))][int(det[3:])-1][5] > 1: + multi_char_det[0] = multi_char_det[0] + 1 + else: + single_char_det[0] = single_char_det[0] + 1 + + for det in fine: + if gt_math_bbs[filename][str(int(float(page)))][int(det[3:])-1][5] > 1: + multi_char_det[1] = multi_char_det[1] + 1 + else: + single_char_det[1] = single_char_det[1] + 1 + + # DET for precision + for det in det_math_bbs[filename][str(int(float(page)))]: + if det[5] > 1: + total_multi_char_det = total_multi_char_det + 1 + else: + total_single_char_det = total_single_char_det + 1 + + #TODO + # for gt in gt_math_bbs[filename][str(int(float(page)))]: + # if gt[5] == 1: + # single_char_gt = single_char_gt + 1 + # else: + # multi_char_gt = multi_char_gt + 1 + + # GT + for page in gt_math_bbs[filename]: + for gt in gt_math_bbs[filename][str(int(float(page)))]: + if gt[5] > 1: + multi_char_gt = multi_char_gt + 1 + else: + single_char_gt = single_char_gt + 1 + + # single 
char scores - coarse + # precision + print("Number of single character regions correctly detected IOU50, IOU75 ", single_char_det) + print("Total number of single character regions detected ", total_single_char_det) + print("Total number of single character regions GT ", single_char_gt) + + print("Number of multi character regions correctly detected IOU50, IOU75 ", multi_char_det) + print("Total number of multi character regions detected ", total_multi_char_det) + print("Total number of multi character regions GT ", multi_char_gt) + + # Single character regions + + print("***** Results : Single Character Regions ***** ") + prec_50 = single_char_det[0]/total_single_char_det + rec_50 = single_char_det[0] / single_char_gt + fscore_50 = 2*prec_50*rec_50/(prec_50 + rec_50) + + print("Precision IOU50 ", prec_50) + print("Recall IOU50 ", rec_50) + print("F-score IOU50 ", fscore_50) + + prec_75 = single_char_det[1] / total_single_char_det + rec_75 = single_char_det[1] / single_char_gt + fscore_75 = 2 * prec_75 * rec_75 / (prec_75 + rec_75) + + print("Precision IOU75 ", prec_75) + print("Recall IOU75 ", rec_75) + print("F-score IOU75 ", fscore_75) + + print("***** Results : Multi Character Regions ***** ") + prec_50 = multi_char_det[0] / total_multi_char_det + rec_50 = multi_char_det[0] / multi_char_gt + fscore_50 = 2 * prec_50 * rec_50 / (prec_50 + rec_50) + + print("Precision IOU50 ", prec_50) + print("Recall IOU50 ", rec_50) + print("F-score IOU50 ", fscore_50) + + prec_75 = multi_char_det[1] / total_multi_char_det + rec_75 = multi_char_det[1] / multi_char_gt + fscore_75 = 2 * prec_75 * rec_75 / (prec_75 + rec_75) + + print("Precision IOU75 ", prec_75) + print("Recall IOU75 ", rec_75) + print("F-score IOU75 ", fscore_75) + +def find_merged_regions(training_pdf_names, gt_math_boxes, det_math_boxes): + + det_regions_with_multi_math = 0 + + for pdf_name in training_pdf_names: + for page in det_math_boxes[pdf_name]: + + for det in det_math_boxes[pdf_name][page]: + + count = 
0 + + det_bb = [float(det[1]), float(det[2]), + float(det[3]), float(det[4])] + + if page not in gt_math_boxes[pdf_name]: + continue + + for gt in gt_math_boxes[pdf_name][page]: + + gt_bb = [float(gt[1]),float(gt[2]), + float(gt[3]),float(gt[4])] + + if box_utils.check_inside(gt_bb, det_bb): + count = count + 1 + + if count > 1: + det_regions_with_multi_math = \ + det_regions_with_multi_math + count + break + + print("Merged boxes ", det_regions_with_multi_math) + + +def assign_chars_to_math_boxes(all_math_boxes, all_char_bbs): + + for pdf_name in all_math_boxes: + for page in all_math_boxes[pdf_name]: + + #print('Assigning ', pdf_name, page) + math_boxes = all_math_boxes[pdf_name][page] + char_bbs = all_char_bbs[pdf_name][str(int(float(page)))] + + for math_box in math_boxes: + math_box.append(0) + + for char_info in char_bbs: + for math_bb in math_boxes: + + current_char_bb = [float(char_info[1]), float(char_info[2]), #TODO index from 1 + float(char_info[3]), float(char_info[4])] + + current_math_bb = [float(math_bb[1]),float(math_bb[2]), + float(math_bb[3]),float(math_bb[4])] + + if box_utils.check_inside(current_char_bb, current_math_bb): + math_bb[-1] = math_bb[-1] + 1 + + +if __name__ == '__main__': + + training_pdf_names = open(sys.argv[1], 'r') + + training_pdf_names_list = [] + + # for each training image pdf file + for pdf_name in training_pdf_names: + pdf_name = pdf_name.strip() + if pdf_name != '': + training_pdf_names_list.append(pdf_name) + training_pdf_names.close() + + detected_math_dir = sys.argv[2] #'/home/psm2208/code/eval/final_submission/Test/' + gt_math_dir = sys.argv[3] # '/home/psm2208/data/GTDB/annotations/' + + gt_char_dir = sys.argv[4]#'/home/psm2208/data/GTDB/char_annotations/' + test_gt_math_dir = sys.argv[5] #/home/psm2208/Workspace/Task3_Detection/Test/test_math/ + + image_dir = '/home/psm2208/data/GTDB/images/' + + training_pdf_names, total_math_char, gt_math_bbs, det_math_bbs, char_bbs = \ + read_data(training_pdf_names_list, 
gt_char_dir, gt_math_dir, detected_math_dir) + + char_level_eval(training_pdf_names, total_math_char, copy.deepcopy(gt_math_bbs), + copy.deepcopy(det_math_bbs), copy.deepcopy(char_bbs)) + + box_level_granular_eval(training_pdf_names, total_math_char, gt_math_dir, + detected_math_dir, gt_math_bbs, + det_math_bbs, char_bbs,test_gt_math_dir) + + find_merged_regions(training_pdf_names, copy.deepcopy(gt_math_bbs), copy.deepcopy(det_math_bbs)) diff --git a/ScanSSD/gtdb/feature_extractor.py b/ScanSSD/gtdb/feature_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..b07a4bfac967e55fecd2b0d2b7e937cf2bcb2d38 --- /dev/null +++ b/ScanSSD/gtdb/feature_extractor.py @@ -0,0 +1,92 @@ +# Author: Parag Mali +# This file contains functions to extract a set of features from two given +# bounding boxes (eg. Geometric features) + +import math + +def extract_features(box1, box2, label=1, test=False): + + features = [area(box1)/area(box2), + height(box1)/height(box2), + width(box1)/width(box2), + center_dist(box1, box2), + horizontal_dist_center(box1, box2), + vertical_dist_center(box1, box2), + vertical_dist_bb(box1, box2), + horizontal_indentation(box1, box2), + aspect_ratio(box1), + aspect_ratio(box2)] # bottom of box to top of box, min + + if not test: + features.append(label) + + return features + + +def intersection(boxA, boxB): + # determine the (x, y)-coordinates of the intersection rectangle + xA = max(boxA[0], boxB[0]) + yA = max(boxA[1], boxB[1]) + xB = min(boxA[2], boxB[2]) + yB = min(boxA[3], boxB[3]) + + # compute the area of intersection rectangle + interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1) + + return interArea + +def union(boxA, boxB): + # compute the area of both the prediction and ground-truth + # rectangles + boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1) + boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1) + + # compute the intersection over union by taking the intersection + # area and dividing it by the 
sum of prediction + ground-truth + # areas - the interesection area + return float(boxAArea + boxBArea - intersection(boxA, boxB)) + + +def iou(boxA, boxB): + return intersection(boxA, boxB) / union(boxA, boxB) + +def inclusion(box1, box2): + return intersection(box1, box2) / area(box1) + +def aspect_ratio(box): + return width(box)/height(box) + +def horizontal_indentation(box1, box2): + return abs(box1[0]-box2[0]) + +def width(box): + return box[2] - box[0] + +def height(box): + return box[3] - box[1] + +def area(box): + return width(box) * height(box) + +def center_dist(box1, box2): + x1 = box1[0] + (width(box1) / 2) + y1 = box1[1] + (height(box1) / 2) + x2 = box2[0] + (width(box2) / 2) + y2 = box2[1] + (height(box2) / 2) + + return math.sqrt((y2 - y1) * (y2 - y1) + (x2 - x1) * (x2 - x1)) + +def horizontal_dist_center(box1, box2): + x1 = box1[0] + (width(box1) / 2) + x2 = box2[0] + (width(box2) / 2) + + return abs(x2-x1) + +def vertical_dist_center(box1, box2): + y1 = box1[1] + (height(box1) / 2) + y2 = box2[1] + (height(box2) / 2) + + return abs(y2-y1) + +def vertical_dist_bb(box1, box2): + return min(abs(box1[3]-box2[1]), abs(box2[3]-box1[1])) diff --git a/ScanSSD/gtdb/fit_box.py b/ScanSSD/gtdb/fit_box.py new file mode 100644 index 0000000000000000000000000000000000000000..3561f4b1ad4e58fc73a98417318052c8ce4d41d6 --- /dev/null +++ b/ScanSSD/gtdb/fit_box.py @@ -0,0 +1,95 @@ +# Author: Parag Mali +# This file performs postprocessing on the detection results +# so that it perfectly contains the connected components + +import numpy as np +import cv2 + +def convert_to_binary(image): + # convert image to binary + + gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + + im_bw = np.zeros(gray_image.shape) + im_bw[gray_image > 127] = 0 + im_bw[gray_image <= 127] = 1 + + return im_bw + +def adjust_box_p(args): + im_bw, box = args + return adjust_box(im_bw, box) + +def adjust_box(im_bw, box): + # expand or contract the bounding box to fit the math expression + box = 
[int(np.round(x)) for x in box] + box = contract(im_bw, box) + box = expand(im_bw, box) + return box + +def contract(im_bw, box): + + # find first row with one pixel + rows_with_pixels = np.any(im_bw[box[1]:box[3], box[0]:box[2]], axis=1) + cols_with_pixels = np.any(im_bw[box[1]:box[3], box[0]:box[2]], axis=0) + + if len(rows_with_pixels==True) == 0 or len(cols_with_pixels==True) == 0: + box = [0,0,0,0] + return box + + left = box[0] + np.argmax(cols_with_pixels==True) + top = box[1] + np.argmax(rows_with_pixels==True) + right = box[0] + len(cols_with_pixels) - np.argmax(cols_with_pixels[::-1]==True) - 1 + bottom = box[1] + len(rows_with_pixels) - np.argmax(rows_with_pixels[::-1]==True) - 1 + + box[0] = left + box[1] = top + box[2] = right + box[3] = bottom + + return box + +def expand(im_bw, box): + + im_copy = np.copy(im_bw) + im_copy[box[1]:box[3], box[0]:box[2]] = 1 + + start = (box[1], box[0]) + queue = [start] + visited = set() + + while len(queue) != 0: + front = queue.pop(0) + if front not in visited: + for adjacent_space in get_adjacent_spaces(im_copy, front, visited): + queue.append(adjacent_space) + + box[0] = min(front[1], box[0]) #left + box[1] = min(front[0], box[1]) #top + box[2] = max(front[1], box[2]) # left + width + box[3] = max(front[0], box[3]) # top + height + + visited.add(front) + + return box + + +def get_adjacent_spaces(im_bw, space, visited): + + spaces = list() + dirs = [[1,0],[-1,0],[0,1],[0,-1]] + + for dir in dirs: + r = space[0] + dir[0] + c = space[1] + dir[1] + + if r < im_bw.shape[0] and c < im_bw.shape[1] and r >= 0 and c >= 0: + spaces.append((r, c)) + + final = list() + for i in spaces: + if im_bw[i[0]][i[1]] == 1 and i not in visited: + final.append(i) + + return final + diff --git a/ScanSSD/gtdb/gen_training_ids.py b/ScanSSD/gtdb/gen_training_ids.py new file mode 100644 index 0000000000000000000000000000000000000000..216dd931ac096afc2dfac1569dd87823afea5ed1 --- /dev/null +++ b/ScanSSD/gtdb/gen_training_ids.py @@ -0,0 +1,36 @@ 
+# Author: Parag Mali +# This script splits training data into 80% training and 20% validation + +def gen_training_ids(): + + dataset = [('Burstall77',24,[23,20,3,5,15]), ('BAMS_1998_123_143',21,[2,17,12,11]), + ('AIF_1999_375_404',30,[16,17,18,9,24,15]), ('ASENS_1997_367_384',18,[11,17,7,6]), + ('Brezis83',5,[5]), ('MA_1977_275_292',18,[4,2,5,7]), ('Borcherds86',4,[2]), + ('BAMS_1971_1974_1',3,[1]), ('BAMS_1971_1974_2',4,[4]), ('MA_1999_175_196',22,[19,1,10,16]), + ('JMS_1975_497_506',10,[6,3]), ('JMKU_1971_377_379',3,[3]), ('BAMS_1971_1974_3',4,[1]), + ('AnnM_1970_550_569',20,[7,18,8,17]), ('AIF_1970_493_498',6,[1]), ('JMS_1975_281_288',8,[8,7]), + ('TMJ_1990_163_193',32,[1,23,31,28,5,9]), ('TMJ_1973_317_331',16,[7,11,16]), + ('MA_1970_26_38',13,[7,1,12]), ('InvM_1999_163_181',19,[18,10,12,2]), + ('InvM_1970_121_134',14,[9,3,2]), ('BSMF_1970_165_192',28,[22,20,7,27,23,13]), + ('ActaM_1998_283_305',23,[15,17,20,6,10]), ('ASENS_1970_273_284',12,[5,7]), + ('TMJ_1973_333_338',6,[2]), ('Cline88',15,[14,9,11]), + ('ActaM_1970_37_63',27,[21,19,8,27,23]), ('JMS_1975_289_293',5,[4]), + ('BSMF_1998_245_271',27,[11,1,25,3,10]), ('Alford94',20,[10,11,5,13]), + ('KJM_1999_17_36',20,[11,6,3,10]), ('JMKU_1971_181_194',14,[8,14,1]), + ('Bergweiler83',37,[18,11,9,24,34,13,1]), ('Arkiv_1997_185_199',15,[9,6,2]), + ('Arkiv_1971_141_163',23,[16,7,1,10,5]), ('JMKU_1971_373_375',3,[2])] + + print("Training dataset.....") + for filename, page_count, val_pages in dataset: + for i in range(1,page_count+1): + if i not in val_pages: + print(filename + "/" + str(i)) + + print("Validation dataset.....") + for filename, page_count, val_pages in dataset: + for i in range(1, page_count + 1): + if i in val_pages: + print(filename + "/" + str(i)) + +if __name__ == '__main__': + gen_training_ids() diff --git a/ScanSSD/gtdb/generate_subimages.py b/ScanSSD/gtdb/generate_subimages.py new file mode 100644 index 0000000000000000000000000000000000000000..14b9ca5ada17a440bb5adcf70788f2f7545a36f2 --- 
/dev/null +++ b/ScanSSD/gtdb/generate_subimages.py @@ -0,0 +1,319 @@ +# Author: Parag Mali +# This script divides a big image into smaller images using a sliding window. +# It also divides corresponding bounding box annotations. +# NOTE: This file is no longer needed as sub-images are generated in data-loader + +# read the image +import numpy as np +import cv2 +import copy +import os +import sys +from multiprocessing import Pool + +# Default parameters for thr GTDB dataset +intermediate_width = 4800 +intermediate_height = 6000 +crop_size = 1200 +final_width = 0 +final_height = 0 +stride = 0.1 + +n_horizontal = int(intermediate_width / crop_size) # 4 +n_vertical = int(intermediate_height / crop_size) # 5 + +math_dir='/home/psm2208/data/GTDB/annotations/' +char_dir='/home/psm2208/data/GTDB/char_annotations/' +image_dir='/home/psm2208/data/GTDB/images/' +output_image_dir='/home/psm2208/data/GTDB/processed_images_/' +output_math_dir='/home/psm2208/data/GTDB/processed_annotations_/' +output_char_dir='/home/psm2208/data/GTDB/processed_char_annotations_/' + +# This function generates sub-images +def generate_subimages(pdf_name ='Alford94'): + + # find all the images + image_filelist = [file for file in os.listdir(os.path.join(image_dir, pdf_name)) if file.endswith('.png')] + + # math annotations + math_filepath = os.path.join(math_dir, pdf_name + ".math") + math_file_present = os.path.isfile(math_filepath) + + # char annotations + char_filepath = os.path.join(char_dir, pdf_name + ".char") + char_file_present = os.path.isfile(char_filepath) + + if math_file_present: + math_file = open(math_filepath, 'r') + + boxes = {} + + for line in math_file: + box = line.split(",") + idx = int(box[0]) + 1 + box = box[1:] + + box = list(map(float, box)) + box = list(map(int, box)) + + if idx not in boxes: + boxes[idx] = [] + + boxes[idx].append(box) + + if char_file_present: + char_file = open(char_filepath, 'r') + + char_boxes = {} + + for line in char_file: + char_box = 
line.split(",") + idx = int(char_box[0]) + 1 + char_box = char_box[2:] + + #box = list(map(int, box)) + + if idx not in char_boxes: + char_boxes[idx] = [] + + char_boxes[idx].append(char_box) + + + for image_filepath in image_filelist: + + image = cv2.imread(os.path.join(image_dir, pdf_name, image_filepath)) + basename = os.path.basename(image_filepath) + page_id = int(os.path.splitext(basename)[0]) + + original_width = image.shape[1] + original_height = image.shape[0] + + intermediate_width_ratio = intermediate_width / original_width + intermediate_height_ratio = intermediate_height / original_height + + image = cv2.resize(image, (intermediate_width, intermediate_height)) + + if math_file_present: + if page_id in boxes: + current_boxes = boxes[page_id] + else: + current_boxes = [] + + # preprocess the boxes + for box in current_boxes: + + box[0] = box[0] * intermediate_width_ratio + box[1] = box[1] * intermediate_height_ratio + box[2] = box[2] * intermediate_width_ratio + box[3] = box[3] * intermediate_height_ratio + + if char_file_present: + if page_id in char_boxes: + current_char_boxes = char_boxes[page_id] + else: + current_char_boxes = [] + + # preprocess the boxes + for box in current_char_boxes: + + box[0] = float(box[0]) * intermediate_width_ratio + box[1] = float(box[1]) * intermediate_height_ratio + box[2] = float(box[2]) * intermediate_width_ratio + box[3] = float(box[3]) * intermediate_height_ratio + + + subimg_id = 1 + + # create required dirs + if not os.path.exists(os.path.join(output_image_dir, pdf_name, str(page_id))): + os.makedirs(os.path.join(output_image_dir, pdf_name, str(page_id))) + + if not os.path.exists(os.path.join(output_math_dir, pdf_name, str(page_id))): + os.makedirs(os.path.join(output_math_dir, pdf_name, str(page_id))) + + if not os.path.exists(os.path.join(output_char_dir, pdf_name, str(page_id))): + os.makedirs(os.path.join(output_char_dir, pdf_name, str(page_id))) + + for i in np.arange(0, n_vertical-1+stride, stride): + for j 
in np.arange(0, n_horizontal-1+stride, stride): + + print('Processing sub image : ', subimg_id) + + if math_file_present: + out_math_file = os.path.join(output_math_dir, pdf_name, str(page_id), str(subimg_id) + ".pmath") + out_math = open(out_math_file, "w") + + if char_file_present: + out_char_file = os.path.join(output_char_dir, pdf_name, str(page_id), str(subimg_id) + ".pchar") + out_char = open(out_char_file, "w") + + x_l = int(np.round(crop_size * i)) + x_h = int(np.round(crop_size * (i + 1))) + + y_l = int(np.round(crop_size * j)) + y_h = int(np.round(crop_size * (j + 1))) + + cropped_image = image[x_l: x_h, y_l: y_h, :] + cropped_image = cv2.resize(cropped_image, (final_width, final_height)) + + # left, top, right, bottom + image_box = [y_l, x_l, y_h, x_h] + + # find scaling factors + final_width_ratio = final_width / (y_h - y_l) + final_height_ratio = final_height / (x_h - x_l) + + count = 0 + + if math_file_present: + if page_id in boxes: + current_page_boxes = copy.deepcopy(boxes[page_id]) + else: + current_page_boxes = [] + + # if math intersects only consider the region which + # is part of the current bounding box + for box in current_page_boxes: + if intersects(image_box, box): + #print('intersects ', box) + + # left, top, right, bottom + # y increases downwards + + #crop the boxes to fit into image region + box[0] = max(y_l, box[0]) + box[1] = max(x_l, box[1]) + box[2] = min(y_h, box[2]) + box[3] = min(x_h, box[3]) + + # Translate to origin + box[0] = box[0] - y_l + box[2] = box[2] - y_l + + box[1] = box[1] - x_l + box[3] = box[3] - x_l + + # scaling + box[2] = int(np.round(box[2] * final_width_ratio)) + box[0] = int(np.round(box[0] * final_width_ratio)) + + box[3] = int(np.round(box[3] * final_height_ratio)) + box[1] = int(np.round(box[1] * final_height_ratio)) + + count = count + 1 + out_math.write(','.join(str(x) for x in box)) + out_math.write("\n") + #cv2.rectangle(cropped_image, (box[0], box[1]), (box[2], box[3]), (255, 0, 0), 3) + + if 
char_file_present: + if page_id in char_boxes: + current_page_char_boxes = copy.deepcopy(char_boxes[page_id]) + else: + current_page_char_boxes = [] + + # if math intersects only consider the region which + # is part of the current bounding box + for box in current_page_char_boxes: + if intersects(image_box, box): + #print('intersects ', box) + + # left, top, right, bottom + # y increases downwards + + #crop the boxes to fit into image region + box[0] = max(y_l, box[0]) + box[1] = max(x_l, box[1]) + box[2] = min(y_h, box[2]) + box[3] = min(x_h, box[3]) + + # Translate to origin + box[0] = box[0] - y_l + box[2] = box[2] - y_l + + box[1] = box[1] - x_l + box[3] = box[3] - x_l + + # scaling + box[2] = int(np.round(box[2] * final_width_ratio)) + box[0] = int(np.round(box[0] * final_width_ratio)) + + box[3] = int(np.round(box[3] * final_height_ratio)) + box[1] = int(np.round(box[1] * final_height_ratio)) + + out_char.write(','.join(str(x) for x in box)) + #out_char.write("\n") + + fig_name = os.path.join(output_image_dir, pdf_name, str(page_id), str(subimg_id) + ".png") + print("Saving " + fig_name) + cv2.imwrite(fig_name, cropped_image) + subimg_id = subimg_id + 1 + + if math_file_present: + + out_math.close() + out_char.close() + if count == 0: + # if no math regions, delete the file + os.remove(out_math_file) + os.remove(out_char_file) + + +# check if two rectangles intersect +def intersects(first, other): + return not (first[2] < other[0] or + first[0] > other[2] or + first[1] > other[3] or + first[3] < other[1]) + +if __name__ == '__main__': + + training_pdf_names = open(sys.argv[1], 'r') + stride = float(sys.argv[2]) # 0.1 + + training_pdf_names_list = [] + + # for each training image pdf file + for pdf_name in training_pdf_names: + pdf_name = pdf_name.strip() + if pdf_name != '': + training_pdf_names_list.append(pdf_name) + + if not os.path.exists(os.path.join(output_image_dir, pdf_name)): + os.makedirs(os.path.join(output_image_dir, pdf_name)) + + if not 
os.path.exists(os.path.join(output_math_dir, pdf_name)): + os.makedirs(os.path.join(output_math_dir, pdf_name)) + + if not os.path.exists(os.path.join(output_char_dir, pdf_name)): + os.makedirs(os.path.join(output_char_dir, pdf_name)) + + training_pdf_names.close() + + suffix = sys.argv[3] #str(int(100 * stride)) + + final_height = int(suffix) + final_width = int(suffix) + + char_dir = '/home/psm2208/data/GTDB/char_annotations/' + math_dir = '/home/psm2208/code/eval/relations_adjust/' + #math_dir = '/home/psm2208/data/GTDB/annotations/' + image_dir = '/home/psm2208/data/GTDB/images/' + + output_image_dir = '/home/psm2208/data/GTDB/processed_images_' + suffix + output_math_dir = '/home/psm2208/data/GTDB/processed_annotations_' + suffix + output_char_dir = '/home/psm2208/data/GTDB/processed_char_annotations_' + suffix + + # create required dirs + if not os.path.exists(output_image_dir): + os.makedirs(output_image_dir) + + if not os.path.exists(output_math_dir): + os.makedirs(output_math_dir) + + if not os.path.exists(output_char_dir): + os.makedirs(output_char_dir) + + pool = Pool(processes=32) + pool.map(generate_subimages, training_pdf_names_list) + pool.close() + pool.join() diff --git a/ScanSSD/gtdb/remove_rect.py b/ScanSSD/gtdb/remove_rect.py new file mode 100644 index 0000000000000000000000000000000000000000..6be67c8c8afeea0505f3392964712e00ba6259da --- /dev/null +++ b/ScanSSD/gtdb/remove_rect.py @@ -0,0 +1,96 @@ +# Author: Parag Mali +# This file contains functions to remove rectangles +# that are inside other rectangles + +import sys +sys.path.extend(['/home/psm2208/code', '/home/psm2208/code']) +import cv2 +import os +import csv +import numpy as np +import utils.visualize as visualize +from multiprocessing import Pool +from cv2.dnn import NMSBoxes +from scipy.ndimage.measurements import label +import scipy.ndimage as ndimage +import copy +import shutil +from gtdb import box_utils + +def remove(args): + + try: + output_dir, pdf_name, page_num, page_math = 
args + + valid = [True] * page_math.shape[0] + + for i, m1 in enumerate(page_math): + for j, m2 in enumerate(page_math): + if i!=j and box_utils.check_inside(m1, m2): + valid[i] = False + break + + final_math = page_math[valid] + + math_file = open(os.path.join(output_dir, pdf_name + ".csv"), 'a') + writer = csv.writer(math_file, delimiter=",") + + for math_region in final_math: + math_region = math_region.tolist() + math_region.insert(0, page_num) + writer.writerow(math_region) + + print("Saved ", os.path.join(output_dir, pdf_name + ".csv"), " > ", page_num) + print('Before ', len(page_math), '--> after ', len(final_math)) + + except: + print("Exception while processing ", pdf_name, " ", page_num, " ", sys.exc_info()[0]) + + +def remove_rect(filename, math_dir, output_dir): + + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + + if not os.path.exists(output_dir): + os.mkdir(output_dir) + + pages_list = [] + pdf_names = open(filename, 'r') + + for pdf_name in pdf_names: + print('Processing-1', pdf_name) + pdf_name = pdf_name.strip() + + if pdf_name != '': + math_file = os.path.join(math_dir, pdf_name + ".csv") + math_regions = np.genfromtxt(math_file, delimiter=',', dtype=float) + + pages = np.unique(math_regions[:, 0]) + + for page_num in pages: + + page_math = math_regions[np.where(math_regions[:,0]==page_num)] + page_math = page_math[:,1:] + pages_list.append([output_dir, pdf_name, page_num, page_math]) + + pdf_names.close() + + pool = Pool(processes=4) + pool.map(remove, pages_list) + pool.close() + pool.join() + + +if __name__ == "__main__": + home_data = "/home/psm2208/data/GTDB/" + home_eval = "/home/psm2208/code/eval/" + home_images = "/home/psm2208/data/GTDB/images/" + home_anno = "/home/psm2208/data/GTDB/annotations/" + + math_dir = "/home/psm2208/code/eval/tt_samsung" #"/home/psm2208/data/GTDB/relations_train_adjust_csv" + output_dir = "/home/psm2208/code/eval/tt_samsung_removed" #"/home/psm2208/data/GTDB/relations_train_adjust_csv_removed" + 
+ type = sys.argv[1] + + remove_rect(home_data + type, math_dir, output_dir) diff --git a/ScanSSD/gtdb/resize_gt.py b/ScanSSD/gtdb/resize_gt.py new file mode 100644 index 0000000000000000000000000000000000000000..03acae7794a6e9e3a17431b7a55d55d5e5003d61 --- /dev/null +++ b/ScanSSD/gtdb/resize_gt.py @@ -0,0 +1,219 @@ +# Author: Parag Mali +# This script resizes ground truth based on the given final width and height +# NOTE: It is no longer needed, as ground truth modification is done in the data loader + +# read the image +import numpy as np +import cv2 +import copy +import os +import sys +from multiprocessing import Pool + +# Default parameters for thr GTDB dataset +final_width = 512 +final_height = 512 + +math_dir='/home/psm2208/data/GTDB/annotations/' +char_dir='/home/psm2208/data/GTDB/char_annotations/' +image_dir='/home/psm2208/data/GTDB/images/' +output_image_dir='/home/psm2208/data/GTDB/processed_images_/' +output_math_dir='/home/psm2208/data/GTDB/processed_annotations_/' +output_char_dir='/home/psm2208/data/GTDB/processed_char_annotations_/' + +# This function generates resized gt +def resize_gt(pdf_name ='Alford94'): + + # find all the images + image_filelist = [file for file in os.listdir(os.path.join(image_dir, pdf_name)) if file.endswith('.png')] + + # math annotations + math_filepath = os.path.join(math_dir, pdf_name + ".math") + math_file_present = os.path.isfile(math_filepath) + + # char annotations + char_filepath = os.path.join(char_dir, pdf_name + ".char") + char_file_present = os.path.isfile(char_filepath) + + if math_file_present: + math_file = open(math_filepath, 'r') + + boxes = {} + + for line in math_file: + box = line.split(",") + idx = int(box[0]) + 1 + box = box[1:] + + box = list(map(int, box)) + + if idx not in boxes: + boxes[idx] = [] + + boxes[idx].append(box) + + if char_file_present: + char_file = open(char_filepath, 'r') + + char_boxes = {} + + for line in char_file: + char_box = line.split(",") + idx = int(char_box[0]) + 1 + 
char_box = char_box[2:] + + #box = list(map(int, box)) + + if idx not in char_boxes: + char_boxes[idx] = [] + + char_boxes[idx].append(char_box) + + + for image_filepath in image_filelist: + + #os.path.basename + image = cv2.imread(os.path.join(image_dir, pdf_name, image_filepath)) + basename = os.path.basename(image_filepath) + page_id = int(os.path.splitext(basename)[0]) + + original_width = image.shape[1] + original_height = image.shape[0] + + resized_image = cv2.imread(os.path.join(resized_image_dir, pdf_name, image_filepath)) + intermediate_width = resized_image.shape[1] + intermediate_height = resized_image.shape[0] + + intermediate_width_ratio = intermediate_width / original_width + intermediate_height_ratio = intermediate_height / original_height + + final_width_ratio = final_width * intermediate_width_ratio / intermediate_width + final_height_ratio = final_height * intermediate_height_ratio / intermediate_height + + final_image = cv2.resize(resized_image, (final_height, final_width)) + + if math_file_present: + if page_id in boxes: + current_boxes = boxes[page_id] + else: + current_boxes = [] + + # preprocess the boxes + for box in current_boxes: + + box[0] = int(np.round(box[0] * final_width_ratio)) + box[1] = int(np.round(box[1] * final_height_ratio)) + box[2] = int(np.round(box[2] * final_width_ratio)) + box[3] = int(np.round(box[3] * final_height_ratio)) + + #cv2.rectangle(final_image, (box[0], box[1]), (box[2], box[3]), (255,0,0)) + + #cv2.imwrite("/home/psm2208/test.png", final_image) + + if char_file_present: + if page_id in char_boxes: + current_char_boxes = char_boxes[page_id] + else: + current_char_boxes = [] + + # preprocess the boxes + for box in current_char_boxes: + + box[0] = int(np.round(float(box[0]) * final_width_ratio)) + box[1] = int(np.round(float(box[1]) * final_height_ratio)) + box[2] = int(np.round(float(box[2]) * final_width_ratio)) + box[3] = int(np.round(float(box[3]) * final_height_ratio)) + + + # create required dirs + if not 
os.path.exists(os.path.join(output_image_dir, pdf_name)): + os.makedirs(os.path.join(output_image_dir, pdf_name)) + + if not os.path.exists(os.path.join(output_math_dir, pdf_name)): + os.makedirs(os.path.join(output_math_dir, pdf_name)) + + if not os.path.exists(os.path.join(output_char_dir, pdf_name)): + os.makedirs(os.path.join(output_char_dir, pdf_name)) + + print('Processing image : ', pdf_name, "/", page_id) + + # save the final image + cv2.imwrite(os.path.join(output_image_dir, pdf_name, str(page_id) + ".png"), final_image) + + if math_file_present: + out_math_file = os.path.join(output_math_dir, pdf_name, str(page_id) + ".pmath") + out_math = open(out_math_file, "w") + + for box in current_boxes: + out_math.write(','.join(str(x) for x in box) + "\n") + + out_math.close() + + if char_file_present: + out_char_file = os.path.join(output_char_dir, pdf_name, str(page_id) + ".pchar") + out_char = open(out_char_file, "w") + + for box in current_char_boxes: + out_char.write(','.join(str(x) for x in box) + "\n") + + out_char.close() + + +# check if two rectangles intersect +def intersects(first, other): + return not (first[2] < other[0] or + first[0] > other[2] or + first[1] > other[3] or + first[3] < other[1]) + +if __name__ == '__main__': + + training_pdf_names = open(sys.argv[1], 'r') # train_pdf + + training_pdf_names_list = [] + + # for each training image pdf file + for pdf_name in training_pdf_names: + pdf_name = pdf_name.strip() + if pdf_name != '': + training_pdf_names_list.append(pdf_name) + + if not os.path.exists(os.path.join(output_image_dir, pdf_name)): + os.makedirs(os.path.join(output_image_dir, pdf_name)) + + if not os.path.exists(os.path.join(output_math_dir, pdf_name)): + os.makedirs(os.path.join(output_math_dir, pdf_name)) + + if not os.path.exists(os.path.join(output_char_dir, pdf_name)): + os.makedirs(os.path.join(output_char_dir, pdf_name)) + + training_pdf_names.close() + + size = "512" + dpi = "150" + + suffix = dpi + + char_dir = 
'/home/psm2208/data/GTDB/char_annotations/' + math_dir = '/home/psm2208/data/GTDB/annotations/' + image_dir = '/home/psm2208/data/GTDB/images/' + + resized_image_dir = "/home/psm2208/data/GTDB/resized_images_" + suffix + output_image_dir = '/home/psm2208/data/GTDB/processed_images_' + suffix + output_math_dir = '/home/psm2208/data/GTDB/processed_annotations_' + suffix + output_char_dir = '/home/psm2208/data/GTDB/processed_char_annotations_' + suffix + + # create required dirs + if not os.path.exists(output_image_dir): + os.makedirs(output_image_dir) + + if not os.path.exists(output_math_dir): + os.makedirs(output_math_dir) + + if not os.path.exists(output_char_dir): + os.makedirs(output_char_dir) + + pool = Pool(processes=24) + pool.map(resize_gt, training_pdf_names_list) + pool.close() + pool.join() diff --git a/ScanSSD/gtdb/scale_boxes.py b/ScanSSD/gtdb/scale_boxes.py new file mode 100644 index 0000000000000000000000000000000000000000..7e669dc4807544eca1bd05f1ef172c85e4847042 --- /dev/null +++ b/ScanSSD/gtdb/scale_boxes.py @@ -0,0 +1,138 @@ +# Author: Parag Mali +# This script scales the detected boxes back to given width and height of the page + +# read the image +import sys +sys.path.extend(['/home/psm2208/code', '/home/psm2208/code']) +import cv2 +import os +import numpy as np +from multiprocessing import Pool +from gtdb import feature_extractor +import argparse + +def parse_args(): + # Default parameters + + parser = argparse.ArgumentParser( + description='Stitching method') + + parser.add_argument('--data_file', default='test', + type=str, help='choose one') + parser.add_argument('--output_dir', default='.', + help='Output directory path') + parser.add_argument('--math_dir', required=True, + type=str, help='detections dir') + parser.add_argument('--math_ext', default='.csv', + help='Extention of detection files') + parser.add_argument('--home_data', default='/home/psm2208/data/GTDB/', type = str, + help='database dir') + parser.add_argument('--home_eval', 
default='/home/psm2208/code/eval/', type = str, + help='Eval dir') + parser.add_argument('--home_images', default='/home/psm2208/data/GTDB/images/', type = str, + help='Images dir') + parser.add_argument('--home_anno', default='/home/psm2208/data/GTDB/annotations/', type = str, + help='Annotations dir') + parser.add_argument('--home_char', default='/home/psm2208/data/GTDB/char_annotations/', type = str, + help='Char anno dir') + parser.add_argument('--num_workers', default=4, type=int, help='Number of workers') + parser.add_argument('--type', default='math', type=str, help='Math or text') + + return parser.parse_args() + +def read_math(args, pdf_name): + + math_file = os.path.join(args.math_dir, pdf_name + args.math_ext) + data = np.array([]) + + if os.path.exists(math_file): + data = np.genfromtxt(math_file, delimiter=',') + + # if there is only one entry convert it to correct form required + if len(data.shape) == 1: + data = data.reshape(1, -1) + + return data.astype(int) + +def scale(params): + + try: + args, math_regions, pdf_name, page_num = params + print('Processing ', pdf_name, ' > ', page_num) + + image = cv2.imread(os.path.join(args.home_images, pdf_name, str(int(page_num + 1)) + ".png")) + + height = image.shape[0] + width = image.shape[1] + + new_math = [] + + for math in math_regions: + box = [0,0,0,0] + box[0] = math[0] * width / 512 # 512 is ip image size + box[1] = math[1] * height / 512 + box[2] = math[2] * width / 512 + box[3] = math[3] * height / 512 + + if feature_extractor.width(box) > 0 and feature_extractor.height(box) > 0: + new_math.append(box) + + return new_math + except Exception as e: + print('Error while processing ', pdf_name, ' > ', page_num, sys.exc_info()) + return [] + +def scale_boxes(args): + pdf_list = [] + pdf_names_file = open(args.data_file, 'r') + + for pdf_name in pdf_names_file: + pdf_name = pdf_name.strip() + + if pdf_name != '': + pdf_list.append(pdf_name) + + regions = {} + + for pdf_name in pdf_list: + 
regions[pdf_name] = read_math(args, pdf_name) + + voting_ip_list = [] + for pdf_name in pdf_list: + + pages = np.unique(regions[pdf_name][:, 0]) + + #args, math_regions, pdf_name, page_num + for page_num in pages: + current_math = regions[pdf_name][regions[pdf_name][:,0] == page_num] + + current_math = np.delete(current_math, 0, 1) + + voting_ip_list.append([args, current_math, pdf_name, page_num]) + + pool = Pool(processes=args.num_workers) + + out = pool.map(scale, voting_ip_list) + + for ip, final_math in zip(voting_ip_list, out): + pdf_name = ip[2] + page_num = ip[3] + + col = np.array([int(page_num)] * len(final_math)) + final_math = np.concatenate((col[:, np.newaxis], final_math), axis=1) + + math_file_path = os.path.join(args.output_dir, pdf_name + '.csv') + + if not os.path.exists(os.path.dirname(math_file_path)): + os.makedirs(os.path.dirname(math_file_path)) + + math_file = open(math_file_path, 'a') + + np.savetxt(math_file, final_math, fmt='%.2f', delimiter=',') + math_file.close() + + +if __name__ == '__main__': + + args = parse_args() + scale_boxes(args) diff --git a/ScanSSD/gtdb/split_annotations_per_page.py b/ScanSSD/gtdb/split_annotations_per_page.py new file mode 100644 index 0000000000000000000000000000000000000000..772cc9600ba477d8f31957369ab897ba8cddb615 --- /dev/null +++ b/ScanSSD/gtdb/split_annotations_per_page.py @@ -0,0 +1,112 @@ +# Author: Parag Mali +# This script generates page level annotations from the PDF level annotations +# provided in the dataset + +import sys +import os +from multiprocessing import Pool +import csv +import cv2 + +def split(args): + + gt_dir, pdf_name, out_dir, ext = args + + file_path = os.path.join(gt_dir, pdf_name + "." 
+ ext) + img_dir = '/home/psm2208/data/GTDB/images/' + + # create a map of page to list of math boxes + map = {} + + if ext == "math": + + file_ip = open(file_path, "r") + for line in file_ip: + entries = line.strip().split(",") + + # if entry is not in map + if entries[0] not in map: + map[entries[0]] = [] + + map[entries[0]].append(entries[1:]) + + for key in map: + + boxes = map[key] + key = float(key) + img_file = os.path.join(img_dir, pdf_name, str(int(key) + 1) + ".png") + img = cv2.imread(img_file) + + height, width, channels = img.shape + + #width_ratio = 512 / width + #height_ratio = 512 / height + + width_ratio = 1 + height_ratio = 1 + + # create processed math file + file_op = open(os.path.join(out_dir, pdf_name, str(int(key) + 1)) + ".p" + ext, "w") + + for box in boxes: + # xmin, ymin, xmax, ymax + + box[0] = float(box[0]) * width_ratio + box[1] = float(box[1]) * height_ratio + box[2] = float(box[2]) * width_ratio + box[3] = float(box[3]) * height_ratio + + file_op.write(','.join(str(e) for e in box) + "\n") + + file_op.close() + file_ip.close() + + elif ext == "char": + with open(file_path, 'r') as csvfile: + reader = csv.reader(csvfile, delimiter=',') + for row in reader: + # if entry is not in map + if row[0] not in map: + map[row[0]] = [] + + map[row[0]].append(row) + + for key in map: + + boxes = map[key] + + with open(os.path.join(out_dir, pdf_name, str(int(key) + 1)) + ".p" + ext, "w") as csvfile: + writer = csv.writer(csvfile, delimiter=',') + + for box in boxes: + writer.writerow(box) + +def test(): + + filename = sys.argv[1] # file names to be processed + out_dir = sys.argv[2] # output dir + gt_dir = sys.argv[3] # gt dir + ext = sys.argv[4] # file extension + + pdf_names_list = [] + pdf_names = open(filename, 'r') + + for pdf_name in pdf_names: + pdf_name = pdf_name.strip() + + if not os.path.exists(os.path.join(out_dir, pdf_name)): + os.mkdir(os.path.join(out_dir, pdf_name)) + + if pdf_name != '': + pdf_names_list.append((gt_dir, pdf_name, 
out_dir, ext)) + + pdf_names.close() + + pool = Pool(processes=32) + pool.map(split, pdf_names_list) + pool.close() + pool.join() + + +if __name__ == "__main__": + test() diff --git a/ScanSSD/gtdb/stitch_patches_page.py b/ScanSSD/gtdb/stitch_patches_page.py new file mode 100644 index 0000000000000000000000000000000000000000..68f973c6ca3c9ae5261f5fd07db23a85d675c1de --- /dev/null +++ b/ScanSSD/gtdb/stitch_patches_page.py @@ -0,0 +1,884 @@ +# Author: Parag Mali +# This script stitches back the output generated on the image patches (sub-images) +# Note: It works from page level detection results. +# read the image +import sys +sys.path.extend(['/home/psm2208/code', '/home/psm2208/code']) +import cv2 +import os +import csv +import numpy as np +import utils.visualize as visualize +from multiprocessing import Pool +from cv2.dnn import NMSBoxes +from scipy.ndimage.measurements import label +import scipy.ndimage as ndimage +import copy +from gtdb import fit_box +from gtdb import box_utils +from gtdb import feature_extractor +import shutil +import time +from sklearn.cluster import AgglomerativeClustering + +# Default parameters for thr GTDB dataset +intermediate_width = 4800 +intermediate_height = 6000 + +crop_size = 1800 #TODO + +final_width = -1 +final_height = -1 +if_visualize = -1 +projections = -1 + +stride = 0.1 + +n_horizontal = int(intermediate_width / crop_size) # 4 +n_vertical = int(intermediate_height / crop_size) # 5 +algorithm = 'equal' + +def read_math_regions(args): + + image, pdf_name, page_num, math_files_list = args + + original_width = image.shape[1] + original_height = image.shape[0] + + intermediate_width_ratio = original_width / intermediate_width + intermediate_height_ratio = original_height / intermediate_height + + annotations_map = {} + + for math_file in math_files_list: + + name = math_file.split(os.sep)[-1] + + if os.stat(math_file).st_size == 0: + continue + + data = np.genfromtxt(math_file, delimiter=',') + + # if there is only one entry 
convert it to correct form required + if len(data.shape) == 1: + data = data.reshape(1, -1) + + annotations_map[name] = data + + + h = np.arange(0, n_horizontal - 1 + stride, stride) + v = np.arange(0, n_vertical - 1 + stride, stride) + + for filename in annotations_map: + + data_arr = annotations_map[filename] + patch_num = int(filename.split("_")[-1].split(".csv")[0]) + + x_offset = h[(patch_num - 1) % len(h)] + y_offset = v[int((patch_num - 1) / len(h))] + + if data_arr is None: + continue + + # find scaling factors + final_width_ratio = crop_size / final_width + final_height_ratio = crop_size / final_height + + data_arr[:, 0] = data_arr[:, 0] * final_width_ratio + data_arr[:, 2] = data_arr[:, 2] * final_width_ratio + data_arr[:, 1] = data_arr[:, 1] * final_height_ratio + data_arr[:, 3] = data_arr[:, 3] * final_height_ratio + + data_arr[:, 0] = data_arr[:, 0] + x_offset * crop_size + data_arr[:, 2] = data_arr[:, 2] + x_offset * crop_size + data_arr[:, 1] = data_arr[:, 1] + y_offset * crop_size + data_arr[:, 3] = data_arr[:, 3] + y_offset * crop_size + + data_arr[:, 0] = data_arr[:, 0] * intermediate_width_ratio + data_arr[:, 2] = data_arr[:, 2] * intermediate_width_ratio + data_arr[:, 1] = data_arr[:, 1] * intermediate_height_ratio + data_arr[:, 3] = data_arr[:, 3] * intermediate_height_ratio + + # multiply score by 100. 
Because later we convert data_arr to int datatype + data_arr[:, 4] = data_arr[:, 4] * 100 + + annotations_map[filename] = data_arr + + math_regions = np.array([]) + + for key in annotations_map: + + if len(math_regions) == 0: + math_regions = annotations_map[key][:, :] + else: + math_regions = np.concatenate((math_regions, annotations_map[key]), axis=0) + + math_regions = math_regions.astype(int) + math_regions = math_regions[(-math_regions[:, 4]).argsort()] + + return math_regions + +def read_char_data(char_filepath): + + # Read char data + if char_filepath != "": + char_data = np.genfromtxt(char_filepath, delimiter=',') + char_data = char_data[:, 2:6] + + # if there is only one entry convert it to correct form required + if len(char_data.shape) == 1: + char_data = char_data.reshape(1, -1) + + else: + char_data = [] + + return char_data + +def read_gt_regions(gt_dir, pdf_name, page_num): + + gt_regions = None + + if os.path.isfile(os.path.join(gt_dir, pdf_name, page_num + ".pmath")): + gt_path = os.path.join(gt_dir, pdf_name, page_num + ".pmath") + + try: + gt_regions = np.genfromtxt(gt_path, delimiter=',') + gt_regions = gt_regions.astype(int) + + # if there is only one entry convert it to correct form required + if len(gt_regions.shape) == 1: + gt_regions = gt_regions.reshape(1, -1) + + gt_regions = gt_regions.tolist() + + except: + gt_regions = None + + return gt_regions + + +def combine_math_regions(args): + + """ + It is called for each page in the pdf + :param math_files_list: + :param image_path: + :param output_image: + :return: + """ + pdf_name, page_num, math_files_list, char_filepath, image_path, output_image, \ + gt_dir, thresh, output_dir = args + + try: + image = cv2.imread(image_path) + + math_regions = read_math_regions((image, pdf_name, page_num, math_files_list)) + char_data = read_char_data(char_filepath) + + # intital math regions + math_regions_initial = np.copy(math_regions) + + processed_math_regions = np.copy(math_regions) + + # This will 
give final math regions + math_regions = voting_algo(math_regions, char_data, image, pdf_name, page_num, + output_dir, algorithm=algorithm, thresh_votes=thresh) + math_regions = np.reshape(math_regions, (-1,4)) + gt_regions = read_gt_regions(gt_dir, pdf_name, page_num) + + if not os.path.exists(os.path.dirname(output_image)): + os.mkdir(os.path.dirname(output_image)) + + if if_visualize == 1: + visualize.draw_all_boxes(image, processed_math_regions, math_regions, gt_regions, output_image) + + col = np.array([int(page_num) - 1] * math_regions.shape[0]) + math_regions = np.concatenate((col[:, np.newaxis], math_regions), axis=1) + + math_file = open(os.path.join(output_dir, pdf_name + '.csv'), 'a') + np.savetxt(math_file, math_regions, fmt='%.2f', delimiter=',') + math_file.close() + + except: + print("Exception while processing ", pdf_name, " ", page_num, " ", sys.exc_info()) + + return math_regions + +def preprocess_math_regions(math_regions, image): + + im_bw = convert_to_binary(image) + + args = [] + + for box in math_regions: + args.append((im_bw, box)) + #preprocessed_math_regions.append(box) + + pool = Pool(processes=1) + preprocessed_math_regions = pool.map(fit_box.adjust_box_p, args) + pool.close() + pool.join() + + return preprocessed_math_regions + + +def fusion(args): + + pdf_name, page_num, output_dir, math_cache, alpha, beta, gamma = args + + #equal_votes = voting_equal(votes, math_regions) + math_regions = np.copy(math_cache) + + # get rid of all boxes which are less than alpha confident + #math_regions = math_regions[math_regions[:,-1]>(alpha*100)] + + #inter_math = box_utils.find_intersecting_boxes(math_regions) + +# math_regions = math_regions.tolist() + + # intersection of math boxes changes on the fly + # as they grow if fused with other boxes + + # iteratively fuse boxes + previous_len = len(math_regions) + + while True: + math_regions = fuse(math_regions, alpha, beta, gamma) + current_len = len(math_regions) + + if current_len == previous_len: + 
break + + previous_len = current_len + + op_dir = os.path.join(output_dir, 'fusion_' + str("{:.1f}".format(alpha)) + '_' + + str("{:.1f}".format(beta)) + '_' + str("{:.1f}".format(gamma))) + + if not os.path.exists(op_dir): + os.mkdir(op_dir) + + col = np.array([int(page_num) - 1] * math_regions.shape[0]) + math_regions = np.concatenate((col[:, np.newaxis], math_regions), axis=1) + + math_file = open(os.path.join(op_dir, pdf_name + '.csv'), 'a') + np.savetxt(math_file, math_regions, fmt='%.2f', delimiter=',') + math_file.close() + + + #TODO: Remove the last column from math_regions i.e confidence column + return math_regions + + +def fuse(math_regions, alpha, beta, gamma): + + final_math = [] + removed = set(np.argwhere(math_regions[:,-1]<(alpha*100)).flatten()) + + for key in range(len(math_regions)): + + if key not in removed: + box1 = math_regions[key] + + for j in range(len(math_regions[key+1:])): + v = key+1+j + if key not in removed and v not in removed: + box2 = math_regions[v] + + # if IOU > beta, merge + if feature_extractor.iou(box1, box2) > beta: + box1 = box_utils.merge(box1, box2) + removed.add(v) + + # if inclusion > gamma, remove + elif feature_extractor.inclusion(box1, box2) > gamma: + removed.add(key) + elif feature_extractor.inclusion(box2, box1) > gamma: + removed.add(v) + + if key not in removed: + math_regions[key][:4] = box1[:4] + + + #writer = csv.writer(math_file, delimiter=",") + count = 0 + keep = [] + + for math_region in math_regions: + if count not in removed: + keep.append(True) + else: + keep.append(False) + count = count + 1 + + math_regions = math_regions[keep] + #col = np.full((1, math_regions.shape[0]), ) + return math_regions + +def voting_equal(votes, math_regions): + # cast votes for the regions + for box in math_regions: + votes[int(box[1]):int(box[3]), int(box[0]):int(box[2])] = \ + votes[int(box[1]):int(box[3]), int(box[0]):int(box[2])] + 1 + + return votes + +def voting_avg_score(votes, math_regions): + + counts = 
np.zeros(shape=votes.shape) + + # cast votes for the regions + for box in math_regions: + votes[int(box[1]):int(box[3]), int(box[0]):int(box[2])] = \ + votes[int(box[1]):int(box[3]), int(box[0]):int(box[2])] + box[4] + + # count the regions + for box in math_regions: + counts[int(box[1]):int(box[3]), int(box[0]):int(box[2])] = \ + counts[int(box[1]):int(box[3]), int(box[0]):int(box[2])] + 1 + + # To avoid divide by zero + # When counts is zero, votes will be zero + # So this should not affect the calculations and results + counts[counts == 0] = 1 + + votes = votes / counts + + return votes + +def voting_sum_score(votes, math_regions): + + # cast votes for the regions + for box in math_regions: + votes[int(box[1]):int(box[3]), int(box[0]):int(box[2])] = \ + votes[int(box[1]):int(box[3]), int(box[0]):int(box[2])] + box[4] + + return votes + + +# def voting_sum_score(votes, boundary_scores, math_regions): +# +# # cast votes for the regions +# for box in math_regions: +# votes[int(box[1]):int(box[3]), int(box[0]):int(box[2])] = \ +# votes[int(box[1]):int(box[3]), int(box[0]):int(box[2])] + box[4] +# +# boundary_scores[int(box[1]), int(box[0]):int(box[2])] = \ +# boundary_scores[int(box[1]), int(box[0]):int(box[2])] + box[4] +# boundary_scores[int(box[3]), int(box[0]):int(box[2])] = \ +# boundary_scores[int(box[3]), int(box[0]):int(box[2])] + box[4] +# +# return votes, boundary_scores + +def voting_heuristic_score(votes, math_regions): + + # All the connected components should have equal score + + # Threshold on the score + + # Threshold on the votes + pass + +def voting_max_score(votes, math_regions): + + # sort based on the confs. 
Confs is column 4 + data = math_regions[math_regions[:, 4].argsort()] + + for box in data: + votes[int(box[1]):int(box[3]), int(box[0]):int(box[2])] = box[4] + + return votes + +def vote_for_regions(math_regions, image, algorithm, thresh_votes): + + original_width = image.shape[1] + original_height = image.shape[0] + + votes = np.zeros(shape=(original_height, original_width)) + #boundary_scores = np.zeros(shape=(original_height, original_width)) + + if algorithm == 'sum_score': + thresh_votes = thresh_votes * 100 + votes = voting_sum_score(votes, math_regions) + elif algorithm == 'max_score': + votes = voting_max_score(votes, math_regions) + elif algorithm == 'avg_score': + thresh_votes = thresh_votes * 100 + votes = voting_avg_score(votes, math_regions) + else: # algorithm='equal' + votes = voting_equal(votes, math_regions) + + #cv2.imwrite('/home/psm2208/votes.png', votes*255/np.max(votes)) + + # find the regions with higher than the threshold votes + # change all the values less than thresh_votes to 0 + votes[votes < thresh_votes] = 0 + votes[votes >= thresh_votes] = 1 + + #cv2.imwrite('/home/psm2208/votes_bw.png', votes*255) + + return votes + +def label_regions(math_regions, image): + labeled = np.zeros(image.shape[:2]) + + math_regions = math_regions[math_regions[:, 4].argsort()] + + for label, math_region in enumerate(math_regions): + labeled[math_region[1]:math_region[3], math_region[0]:math_region[2]] = label + + #uniq_labels = np.unique(labeled) + + return labeled + +def area(box): + + w = box[3] - box[1] + h = box[2] - box[0] + + return w*h + +def char_algo(math_regions, char_data, image, algorithm='equal', thresh_votes=20): + + if len(char_data) == 0: + return [] + + # vote for the regions + votes = vote_for_regions(math_regions, image, algorithm, thresh_votes) + + # Check if character is math or not + char_data = char_data.tolist() + + for char_box in char_data: + #print('nz ', np.count_nonzero(votes[int(char_box[1]):int(char_box[3]), 
int(char_box[0]):int(char_box[2])])) + + if np.count_nonzero(votes[int(char_box[1]):int(char_box[3]), int(char_box[0]):int(char_box[2])]) > 100: + char_box.append(1) # APPEND 1 to indicate that it is a math character + else: + char_box.append(0) + + # TODO Find the regions + + boxes = [] + + box = [] + + for char_box in char_data: + + if char_box[-1] == 1: + if len(box) == 0: + box = copy.deepcopy(char_box[:4]) + continue + + nbox = copy.deepcopy(box) + nbox[0] = min(char_box[0], box[0]) # left + nbox[1] = min(char_box[1], box[1]) # top + nbox[2] = max(char_box[2], box[2]) # left + width + nbox[3] = max(char_box[3], box[3]) # top + height + + if area(nbox) > 4 * area(box): + boxes.append(box) + box = copy.deepcopy(char_box[:4]) + else: + box = nbox + else: + if len(box) != 0: + boxes.append(box) + box = [] + + if len(box) != 0: + boxes.append(box) + + return boxes + + +def clustering(math_regions, char_data, image, algorithm, thresh_votes): + + centers = [] + for math_region in math_regions: + center = [(math_region[0]+math_region[2])/2, (math_region[1]+math_region[3])/2] + centers.append(center) + + clustering = AgglomerativeClustering().fit(centers) + + labels = np.unique(clustering.labels_) + + for label in labels: + regions = math_regions[labels==label] + + pass + + +def voting_algo(math_regions, char_data, image, pdf_name, page_num, + output_dir, algorithm='equal', thresh_votes=20): + + if algorithm == 'char_algo': + return char_algo(math_regions, char_data, image, algorithm, thresh_votes) + + if algorithm == 'clustering': + return clustering(math_regions, char_data, image, algorithm, thresh_votes) + + # vote for the regions + votes = vote_for_regions(math_regions, image, algorithm, thresh_votes) + + if projections == 1: + votes[rows_with_at_least_k_black_pixels(image)] = 0 + + im_bw = convert_to_binary(image) + structure = np.ones((3, 3), dtype=np.int) + labeled, ncomponents = label(votes, structure) + + # found the boxes. 
Now extract the co-ordinates left,top,right,bottom + boxes = [] + indices = np.indices(votes.shape).T[:, :, [1, 0]] + + for i in range(ncomponents): + + labels = (labeled == (i+1)) + pixels = indices[labels.T] + + if len(pixels) < 1: + continue + + box = [min(pixels[:, 0]), min(pixels[:, 1]), max(pixels[:, 0]), max(pixels[:, 1])] + + # expansion to correctly fit the region + box = fit_box.adjust_box(im_bw, box) + + # if box has 0 width or height, do not add it in the final detections + if feature_extractor.width(box) < 1 or feature_extractor.height(box) < 1: + continue + + boxes.append(box) + + return boxes + + +def convert_to_binary(image): + + gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + + im_bw = np.zeros(gray_image.shape) + im_bw[gray_image > 127] = 0 + im_bw[gray_image <= 127] = 1 + + return im_bw + +def find_blank_rows_h(image): + + gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + + im_bw = np.zeros(gray_image.shape) + im_bw[gray_image > 127] = 0 + im_bw[gray_image <= 127] = 1 + + row_sum = np.sum(im_bw, axis=1) + + cum_sum = np.zeros(row_sum.shape) + + cum_sum[0] = row_sum[0] + + for i, sum in enumerate(row_sum[1:]): + cum_sum[i+1] = cum_sum[i] + sum + + blank_rows = [] + for i, sum in enumerate(cum_sum): + if is_blank(cum_sum, i): + blank_rows.append(i) + + return blank_rows + +# check n last rows +def is_blank(cum_sum, current, n=30, thresh=3000): + + # It is not a blank row + ret = False + + # check below + if (current < len(cum_sum)) and (cum_sum[current] - cum_sum[current-1]) == 0: + + b_thresh = thresh + + if current + n >= len(cum_sum): + val = cum_sum[len(cum_sum)-1] - cum_sum[current] + b_thresh = (thresh/n) * (len(cum_sum) - current) + else: + val = cum_sum[current + n] - cum_sum[current] + + # It is a blank row + if val >= b_thresh: + ret = True + + return ret + +def rows_with_at_least_k_black_pixels(image, k=10): + + im_bw = convert_to_binary(image) # characters are black + rows = im_bw.sum(axis=1) + return np.where(rows<=k)[0] + 
+ +def find_blank_rows(image, line_spacing=1): + + gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + blank_rows = np.all(gray_image == 255, axis=1) + + im_bw = np.zeros(gray_image.shape) + im_bw[blank_rows] = 255 + #gray_image[~blank_rows] = 0 + + #cv2.imwrite("/home/psm2208/code/eval/test.png", im_bw) + + labeled, ncomponents = ndimage.label(im_bw) + rows = [] + + indices = np.indices(im_bw.shape).T[:, :, [1, 0]] + + line_bbs = ndimage.find_objects(labeled) + sizes = np.array([[bb.stop - bb.start for bb in line_bb] + for line_bb in line_bbs]) + + sizes = sizes[:,0] + mask = (sizes > line_spacing) + + idx = np.flatnonzero(mask) + + for i in idx: + labels = (labeled == (i+1)) + pixels = indices[labels.T] + box = [min(pixels[:, 0]), min(pixels[:, 1]), max(pixels[:, 0]), max(pixels[:, 1])] + rows.append(box) + + return rows + +def perform_nms(math_regions): + + # convert from x1,y1,x2,y2 to x,y,w,h + math_regions[:, 2] = math_regions[:, 2] - math_regions[:, 0] + math_regions[:, 3] = math_regions[:, 3] - math_regions[:, 1] + + scores = math_regions[:, 4] + math_regions = np.delete(math_regions, 4, 1) + + math_regions = math_regions.tolist() + scores = scores.tolist() + + indices = NMSBoxes(math_regions, scores, 0.2, 0.5) + + indices = [item for sublist in indices for item in sublist] + math_regions = [math_regions[i] for i in indices] + + math_regions = np.array(math_regions) + + # restore to x1,y1,x2,y2 + math_regions[:, 2] = math_regions[:, 2] + math_regions[:, 0] + math_regions[:, 3] = math_regions[:, 3] + math_regions[:, 1] + + return math_regions.tolist() + +def overlap_expand(math_regions): + + print('Number of math regions ', len(math_regions)) + + if type(math_regions) != type([]): + math_regions = math_regions.tolist() + + obsolete = [] + + for i in range(len(math_regions)): + for j in range(i+1, len(math_regions)): + # print(i,j) + if box_utils.intersects(math_regions[i], math_regions[j]): + math_regions[i][0] = min(math_regions[i][0], 
math_regions[j][0]) + math_regions[i][1] = min(math_regions[i][1], math_regions[j][1]) + math_regions[i][2] = max(math_regions[i][2], math_regions[j][2]) + math_regions[i][3] = max(math_regions[i][3], math_regions[j][3]) + obsolete.append(j) + + math_regions = [i for j, i in enumerate(math_regions) if j not in obsolete] + + return math_regions + + +def read_page_info(filename, annotations_dir, image_dir, gt_dir, char_gt): + + #annotations_dir is dir of detections for sub-images + + pages_list = [] + pdf_names = open(filename, 'r') + + annotations_map = {} + char_annotations_map = {} + + for pdf_name in pdf_names: + pdf_name = pdf_name.strip() + + if pdf_name != '': + + if pdf_name not in annotations_map: + annotations_map[pdf_name] = {} + + for root, dirs, _ in os.walk(os.path.join(annotations_dir, pdf_name), topdown=False): + + for dir in dirs: + for filename in os.listdir(os.path.join(annotations_dir, pdf_name, dir)): + + if filename.endswith(".csv") or filename.endswith(".pmath"): + patch_num = os.path.splitext(filename)[0] + page_num = os.path.basename(os.path.join(annotations_dir, pdf_name, dir)) + + if page_num not in annotations_map[pdf_name]: + annotations_map[pdf_name][page_num] = [] + + annotations_map[pdf_name][page_num].append( + os.path.join(annotations_dir, pdf_name, dir, filename)) + + if pdf_name not in char_annotations_map: + char_annotations_map[pdf_name] = {} + + for filename in os.listdir(os.path.join(char_gt, pdf_name)): + + if filename.endswith(".csv") or filename.endswith(".pchar"): + page_num = os.path.splitext(filename)[0] + + char_annotations_map[pdf_name][page_num] = \ + os.path.join(char_gt, pdf_name, filename) + + for root, dirs, files in os.walk(os.path.join(char_gt, pdf_name)): + for name in files: + if name.endswith(".pchar"): + page_num = os.path.splitext(name)[0] + if page_num in annotations_map[pdf_name]: + image = cv2.imread(os.path.join(image_dir, pdf_name, page_num + '.png')) + pages_list.append((image, pdf_name, page_num, 
annotations_map[pdf_name][page_num])) + + pdf_names.close() + return pages_list, annotations_map, char_annotations_map + + +def stitch_patches(filename, annotations_dir, output_dir, + image_dir='/home/psm2208/data/GTDB/images/', + gt_dir="/home/psm2208/data/GTDB/", char_gt="", thresh=20): + + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + + if not os.path.exists(output_dir): + os.mkdir(output_dir) + + pages_list, annotations_map, char_annotations_map = \ + read_page_info(filename, annotations_dir, image_dir, gt_dir, char_gt) + + pooling_list = [] + + for i, page in enumerate(pages_list): + pdf_name = page[1] + page_num = page[2] + pooling_list.append(( + pdf_name, + page_num, + annotations_map[pdf_name][page_num], + char_annotations_map[pdf_name][page_num], + os.path.join(image_dir, pdf_name, page_num + '.png'), + os.path.join(output_dir, pdf_name, page_num + '.png'), + gt_dir, + thresh, + output_dir)) + + pool = Pool(processes=32) + total = str(len(pooling_list)) + + start = time.time() + init = start + + for i, _ in enumerate(pool.imap_unordered(combine_math_regions, pooling_list), 1): + print('\nprogress: ' + str(i) + '/' + total) + if i%100==0: + current = time.time() + print('\nTime taken for last 100, total time:', current-start, current-init) + start = time.time() + + pool.close() + pool.join() + + +def fusion_stitch_grid(filename, annotations_dir, output_dir, + image_dir='/home/psm2208/data/GTDB/images/', + gt_dir="/home/psm2208/data/GTDB/", char_gt="", thresh=20): + + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + + if not os.path.exists(output_dir): + os.mkdir(output_dir) + + pages_list = read_page_info(filename, annotations_dir, image_dir, gt_dir, char_gt) + + # find math regions + pool = Pool(processes=32) + total = str(len(pages_list)) + math_cache = pool.map(read_math_regions, pages_list) + + pool.close() + pool.join() + + fusion_list = [] + + for i, page in enumerate(pages_list): + pdf_name = page[1] + page_num = page[2] + 
for a in np.arange(0.3, 1.1, 0.1): + for b in np.arange(0.0, 1.1, 0.1): + for c in np.arange(0.0, 1.1, 0.1): + fusion_list.append((pdf_name, + page_num, + output_dir, + #inter_math[i], + math_cache[i], + #0.7,0.2,0.2)) + a,b,c)) + + pool = Pool(processes=32) + total = str(len(fusion_list)) + #pool.map(fusion, fusion_list) + start = time.time() + init = start + + for i, _ in enumerate(pool.imap_unordered(fusion, fusion_list), 1): + print('\nprogress: ' + str(i) + '/' + total) + if i%100==0: + current = time.time() + print('\nTime taken for last 100, total time:', current-start, current-init) + start = time.time() + + pool.close() + pool.join() + + +if __name__ == '__main__': + + # TODO: use argparser + stride = 0.1 + thresh = float(sys.argv[1]) # 30 + algorithm = sys.argv[2] # equal + type = sys.argv[3] # train_pdf + dir_to_eval = sys.argv[4] # Test3_Focal_10_25 + + if len(sys.argv) > 5: + if_visualize = int(sys.argv[5]) # visualize + projections = int(sys.argv[6]) # projections + else: + visualize = 0 + projections = 0 + + final_width = 512 + final_height = 512 + + home_data = "/home/psm2208/data/GTDB/" + home_eval = "/home/psm2208/code/eval/" + home_images = "/home/psm2208/data/GTDB/images/" + home_anno = "/home/psm2208/data/GTDB/annotations/" + home_char = "/home/psm2208/data/GTDB/char_annotations/" + + stitch_patches(home_data + type, home_eval + dir_to_eval, + home_eval + dir_to_eval + "/" + algorithm + "_" + str(thresh), + home_images, home_anno, home_char, thresh) diff --git a/ScanSSD/gtdb/stitch_patches_pdf.py b/ScanSSD/gtdb/stitch_patches_pdf.py new file mode 100644 index 0000000000000000000000000000000000000000..48beabfc69a68208416a22c114751beadb9a1b5e --- /dev/null +++ b/ScanSSD/gtdb/stitch_patches_pdf.py @@ -0,0 +1,294 @@ +# Author: Parag Mali +# This script stitches back the output generated on the image patches (sub-images) +# This works from PDF level detection results + +# read the image +import sys +sys.path.extend(['/home/psm2208/code', 
'/home/psm2208/code']) +#sys.path.extend(['/media/psm2208/Workspace/ssd', '/media/psm2208/Workspace/ssd']) +import cv2 +import os +import numpy as np +from multiprocessing import Pool +from cv2.dnn import NMSBoxes +from scipy.ndimage.measurements import label +from gtdb import fit_box +from gtdb import feature_extractor +import argparse +import shutil + +# Default parameters for thr GTDB dataset +def parse_args(): + + parser = argparse.ArgumentParser( + description='Stitching method') + + parser.add_argument('--data_file', default='test', + type=str, help='choose one') + parser.add_argument('--output_dir', default='.', + help='Output directory path') + parser.add_argument('--math_dir', required=True, + type=str, help='detections dir') + parser.add_argument('--math_ext', default='.csv', + help='Extention of detection files') + parser.add_argument('--home_data', default='/home/psm2208/data/GTDB/', type = str, + help='database dir') + parser.add_argument('--home_eval', default='/home/psm2208/code/eval/', type = str, + help='Eval dir') + parser.add_argument('--home_images', default='/home/psm2208/data/GTDB/images/', type = str, + help='Images dir') + parser.add_argument('--home_anno', default='/home/psm2208/data/GTDB/annotations/', type = str, + help='Annotations dir') + parser.add_argument('--home_char', default='/home/psm2208/data/GTDB/char_annotations/', type = str, + help='Char anno dir') + parser.add_argument('--stitching_algo', default='equal', type=str, help='Stitching algo to use') + parser.add_argument('--algo_threshold', default=30, type=int, help='Stitching algo threshold') + parser.add_argument('--num_workers', default=4, type=int, help='Number of workers') + parser.add_argument('--preprocess', type=bool, help='Whether to fit math regions before pooling') + parser.add_argument('--postprocess', type=bool, help='Whether to fit math regions after pooling') + + return parser.parse_args() + +def read_math(args, pdf_name): + + math_file = 
os.path.join(args.math_dir, pdf_name + args.math_ext) + data = np.array([]) + + if os.path.exists(math_file): + data = np.genfromtxt(math_file, delimiter=',') + + # if there is only one entry convert it to correct form required + if len(data.shape) == 1: + data = data.reshape(1, -1) + + return data + +def vote_for_regions(args, math_regions, image): + + original_width = image.shape[1] + original_height = image.shape[0] + thresh_votes = args.algo_threshold + + votes = np.zeros(shape=(original_height, original_width)) + + if args.stitching_algo == 'sum_score': + votes = voting_sum_score(votes, math_regions) + elif args.stitching_algo == 'max_score': + votes = voting_max_score(votes, math_regions) + elif args.stitching_algo == 'avg_score': + votes = voting_avg_score(votes, math_regions) + else: # algorithm='equal' + votes = voting_equal(votes, math_regions) + + # find the regions with higher than the threshold votes + # change all the values less than thresh_votes to 0 + votes[votes < thresh_votes] = 0 + votes[votes >= thresh_votes] = 1 + + return votes + +def voting_sum_score(votes, math_regions): + + # cast votes for the regions + for box in math_regions: + votes[int(box[1]):int(box[3]), int(box[0]):int(box[2])] = \ + votes[int(box[1]):int(box[3]), int(box[0]):int(box[2])] + box[4] + + return votes + +def voting_max_score(votes, math_regions): + + # sort based on the confs. 
Confs is column 4 + data = math_regions[math_regions[:, 4].argsort()] + + for box in data: + votes[int(box[1]):int(box[3]), int(box[0]):int(box[2])] = box[4] + + return votes + +def voting_equal(votes, math_regions): + # cast votes for the regions + for box in math_regions: + votes[int(box[1]):int(box[3]), int(box[0]):int(box[2])] = \ + votes[int(box[1]):int(box[3]), int(box[0]):int(box[2])] + 1 + + return votes + +def voting_avg_score(votes, math_regions): + + counts = np.zeros(shape=votes.shape) + + # cast votes for the regions + for box in math_regions: + votes[int(box[1]):int(box[3]), int(box[0]):int(box[2])] = \ + votes[int(box[1]):int(box[3]), int(box[0]):int(box[2])] + box[4] + + # count the regions + for box in math_regions: + counts[int(box[1]):int(box[3]), int(box[0]):int(box[2])] = \ + counts[int(box[1]):int(box[3]), int(box[0]):int(box[2])] + 1 + + # To avoid divide by zero + # When counts is zero, votes will be zero + # So this should not affect the calculations and results + counts[counts == 0] = 1 + + votes = votes / counts + + return votes + + +def perform_nms(math_regions): + + # convert from x1,y1,x2,y2 to x,y,w,h + math_regions[:, 2] = math_regions[:, 2] - math_regions[:, 0] + math_regions[:, 3] = math_regions[:, 3] - math_regions[:, 1] + + scores = math_regions[:, 4] + math_regions = np.delete(math_regions, 4, 1) + + math_regions = math_regions.tolist() + scores = scores.tolist() + + indices = NMSBoxes(math_regions, scores, 0.2, 0.5) + + indices = [item for sublist in indices for item in sublist] + math_regions = [math_regions[i] for i in indices] + + math_regions = np.array(math_regions) + + # restore to x1,y1,x2,y2 + math_regions[:, 2] = math_regions[:, 2] + math_regions[:, 0] + math_regions[:, 3] = math_regions[:, 3] + math_regions[:, 1] + + return math_regions.tolist() + +def preprocess_math_regions(math_regions, image): + + im_bw = convert_to_binary(image) + + preprocessed_math_regions = [] + + for box in math_regions: + box = 
def voting_algo(params):
    """Stitch window-level detections for one page via pixel voting.

    Args:
        params: [args, math_regions, pdf_name, page_num] — packed into one
            list so the function can be used with ``Pool.map``.
    Returns:
        List of [x1, y1, x2, y2] boxes, one per connected component of
        pixels that received enough votes.
    """
    args, math_regions, pdf_name, page_num = params
    print('Processing ', pdf_name, ' > ', page_num)

    # page images on disk are 1-indexed
    image = cv2.imread(os.path.join(args.home_images, pdf_name,
                                    str(int(page_num + 1)) + ".png"))

    if args.preprocess:
        math_regions = preprocess_math_regions(math_regions, image)

    # vote for the regions
    votes = vote_for_regions(args, math_regions, image)

    im_bw = convert_to_binary(image)

    # BUG FIX: np.int was deprecated in NumPy 1.20 and removed in 1.24;
    # the builtin int is the documented replacement.
    structure = np.ones((3, 3), dtype=int)
    labeled, ncomponents = label(votes, structure)

    # found the boxes. Now extract the co-ordinates left,top,right,bottom
    boxes = []
    indices = np.indices(votes.shape).T[:, :, [1, 0]]

    for i in range(ncomponents):

        labels = (labeled == (i + 1))
        pixels = indices[labels.T]

        if len(pixels) < 1:
            continue

        box = [min(pixels[:, 0]), min(pixels[:, 1]),
               max(pixels[:, 0]), max(pixels[:, 1])]

        if args.postprocess:
            # expansion to correctly fit the region
            box = fit_box.adjust_box(im_bw, box)

        # if box has 0 width or height, do not add it in the final detections
        if feature_extractor.width(box) < 1 or feature_extractor.height(box) < 1:
            continue

        boxes.append(box)

    return boxes
def convert_to_binary(image):
    """Binarize a BGR page image: dark pixels (ink) become 1, light pixels 0.

    Args:
        image: BGR image array as returned by ``cv2.imread``.
    Returns:
        Float array of the same height/width with values in {0.0, 1.0}.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # threshold at the midpoint of the 8-bit range; ink is dark
    return np.where(gray <= 127, 1.0, 0.0)
+ """ + return torch.cat((boxes[:, :2] - boxes[:, 2:]/2, # xmin, ymin + boxes[:, :2] + boxes[:, 2:]/2), 1) # xmax, ymax + + +def center_size(boxes): + """ Convert prior_boxes to (cx, cy, w, h) + representation for comparison to center-size form ground truth data. + Args: + boxes: (tensor) point_form boxes + Return: + boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. + """ + return torch.cat((boxes[:, 2:] + boxes[:, :2])/2, # cx, cy + boxes[:, 2:] - boxes[:, :2], 1) # w, h + + +def intersect(box_a, box_b): + """ We resize both tensors to [A,B,2] without new malloc: + [A,2] -> [A,1,2] -> [A,B,2] + [B,2] -> [1,B,2] -> [A,B,2] + Then we compute the area of intersect between box_a and box_b. + Args: + box_a: (tensor) bounding boxes, Shape: [A,4]. + box_b: (tensor) bounding boxes, Shape: [B,4]. + Return: + (tensor) intersection area, Shape: [A,B]. + """ + A = box_a.size(0) + B = box_b.size(0) + max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), + box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) + min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), + box_b[:, :2].unsqueeze(0).expand(A, B, 2)) + inter = torch.clamp((max_xy - min_xy), min=0) + return inter[:, :, 0] * inter[:, :, 1] + + +def jaccard(box_a, box_b): + """Compute the jaccard overlap of two sets of boxes. The jaccard overlap + is simply the intersection over union of two boxes. Here we operate on + ground truth boxes and default boxes. 
def jaccard(box_a, box_b):
    """Compute the pairwise jaccard overlap (IoU) of two sets of boxes.

    A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)

    Args:
        box_a: (tensor) Ground truth boxes in point form, Shape: [A, 4]
        box_b: (tensor) Prior boxes in point form, Shape: [B, 4]
    Return:
        jaccard overlap: (tensor) Shape: [A, B]
    """
    A, B = box_a.size(0), box_b.size(0)
    # pairwise intersection rectangle, broadcast to [A, B, 2]
    lo = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2),
                   box_b[:, :2].unsqueeze(0).expand(A, B, 2))
    hi = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2),
                   box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
    # disjoint boxes clamp to zero width/height
    wh = torch.clamp(hi - lo, min=0)
    inter = wh[:, :, 0] * wh[:, :, 1]

    area_a = ((box_a[:, 2] - box_a[:, 0]) *
              (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter)
    area_b = ((box_b[:, 2] - box_b[:, 0]) *
              (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter)
    return inter / (area_a + area_b - inter)  # [A, B]
+ """ + # jaccard index + overlaps = jaccard( + truths, + point_form(priors) + ) + # (Bipartite Matching) + # [1,num_objects] best prior for each ground truth + best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True) + # [1,num_priors] best ground truth for each prior + best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True) + best_truth_idx.squeeze_(0) + best_truth_overlap.squeeze_(0) + best_prior_idx.squeeze_(1) + best_prior_overlap.squeeze_(1) + best_truth_overlap.index_fill_(0, best_prior_idx, 2) # ensure best prior + # TODO refactor: index best_prior_idx with long tensor + # ensure every gt matches with its prior of max overlap + for j in range(best_prior_idx.size(0)): + best_truth_idx[best_prior_idx[j]] = j + matches = truths[best_truth_idx] # Shape: [num_priors,4] + conf = labels[best_truth_idx] + 1 # Shape: [num_priors] + conf[best_truth_overlap < threshold] = 0 # label as background + loc = encode(matches, priors, variances) + loc_t[idx] = loc # [num_priors,4] encoded offsets to learn + conf_t[idx] = conf # [num_priors] top class label for each prior + + +def encode(matched, priors, variances): + """Encode the variances from the priorbox layers into the ground truth boxes + we have matched (based on jaccard overlap) with the prior boxes. + Args: + matched: (tensor) Coords of ground truth for each prior in point-form + Shape: [num_priors, 4]. + priors: (tensor) Prior boxes in center-offset form + Shape: [num_priors,4]. 
def decode(loc, priors, variances):
    """Undo the SSD offset encoding, turning location predictions back into
    point-form boxes (inverse of `encode`).

    Args:
        loc (tensor): location predictions for loc layers,
            Shape: [num_priors, 4]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors, 4]
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded (xmin, ymin, xmax, ymax) predictions, Shape: [num_priors, 4]
    """
    # recover center and size from the encoded offsets
    centers = priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:]
    sizes = priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])
    # convert (cx, cy, w, h) to corner form
    mins = centers - sizes / 2
    maxs = mins + sizes
    return torch.cat((mins, maxs), 1)
def nms(boxes, scores, overlap=0.45, top_k=200):
    """Apply non-maximum suppression at test time to avoid detecting too many
    overlapping bounding boxes for a given object.
    Args:
        boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
        scores: (tensor) The class pred scores for the img, Shape:[num_priors].
        overlap: (float) The overlap thresh for suppressing unnecessary boxes.
        top_k: (int) The maximum number of box preds to consider.
    Return:
        (keep, count): keep holds indices of the retained boxes (w.r.t.
        num_priors); only the first `count` entries are valid.
    """
    keep = scores.new(scores.size(0)).zero_().long()
    count = 0

    if boxes.numel() == 0:
        return keep, count

    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    area = torch.mul(x2 - x1, y2 - y1)

    _, idx = scores.sort(0)  # ascending: highest score is last
    idx = idx[-top_k:]       # indices of the top-k largest vals

    while idx.numel() > 0:
        i = idx[-1]          # index of current largest score
        keep[count] = i
        count += 1
        if idx.size(0) == 1:
            break
        idx = idx[:-1]       # remove kept element from view

        # Clip the remaining boxes against the kept box to get the
        # intersection rectangle.  (The original used deprecated
        # Tensor.resize_(0)/resize_as_ scratch buffers that were immediately
        # overwritten — dead code removed.)
        xx1 = torch.index_select(x1, 0, idx).clamp(min=x1[i])
        yy1 = torch.index_select(y1, 0, idx).clamp(min=y1[i])
        xx2 = torch.index_select(x2, 0, idx).clamp(max=x2[i])
        yy2 = torch.index_select(y2, 0, idx).clamp(max=y2[i])

        w = torch.clamp(xx2 - xx1, min=0.0)
        h = torch.clamp(yy2 - yy1, min=0.0)
        inter = w * h

        # IoU = i / (area(a) + area(b) - i)
        rem_areas = torch.index_select(area, 0, idx)  # remaining areas
        union = (rem_areas - inter) + area[i]
        IoU = inter / union
        # keep only elements with an IoU <= overlap
        idx = idx[IoU.le(overlap)]
    return keep, count


# Smoke test.  BUG FIX: the original guard compared the string literal
# "__name__" to "__main__" (always False, so it never ran) and referenced
# `np`, which is not imported in this module.
if __name__ == "__main__":
    _boxes = torch.tensor([[0.0, 0.0, 1.0, 1.0],
                           [0.0, 0.0, 1.0, 1.0],
                           [5.0, 5.0, 6.0, 6.0]])
    _scores = torch.tensor([0.9, 0.8, 0.7])
    _keep, _count = nms(_boxes, _scores)
    assert _count == 2
+ """ + def __init__(self, cfg, num_classes, bkg_label, top_k, conf_thresh, nms_thresh): + self.num_classes = num_classes + self.background_label = bkg_label + self.top_k = top_k + # Parameters used in nms. + self.nms_thresh = nms_thresh + if nms_thresh <= 0: + raise ValueError('nms_threshold must be non negative.') + self.conf_thresh = conf_thresh + self.variance = cfg['variance'] + + # @staticmethod + def forward(self, loc_data, conf_data, prior_data): + """ + Args: + loc_data: (tensor) Loc preds from loc layers + Shape: [batch,num_priors*4] + conf_data: (tensor) Shape: Conf preds from conf layers + Shape: [batch*num_priors,num_classes] + prior_data: (tensor) Prior boxes and variances from priorbox layers + Shape: [1,num_priors,4] + """ + # move to CPU + loc_data = loc_data.cpu() + conf_data = conf_data.cpu() + prior_data = prior_data.cpu() + + num = loc_data.size(0) # batch size + num_priors = prior_data.size(0) + output = torch.zeros(num, self.num_classes, self.top_k, 5) + conf_preds = conf_data.view(num, num_priors, + self.num_classes).transpose(2, 1) + + # Decode predictions into bboxes. 
+ for i in range(num): + decoded_boxes = decode(loc_data[i], prior_data, self.variance) + # For each class, perform nms + conf_scores = conf_preds[i].clone() + #print('decoded boxes ', decoded_boxes) + #print('conf scores', conf_scores) + for cl in range(1, self.num_classes): + c_mask = conf_scores[cl].gt(self.conf_thresh) + scores = conf_scores[cl][c_mask] + if scores.dim() == 0: + continue + l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes) + boxes = decoded_boxes[l_mask].view(-1, 4) + # idx of highest scoring and non-overlapping boxes per class + + ids, count = nms(boxes, scores, self.nms_thresh, self.top_k) + output[i, cl, :count] = \ + torch.cat((scores[ids[:count]].unsqueeze(1), + boxes[ids[:count]]), 1) + flt = output.contiguous().view(num, -1, 5) + _, idx = flt[:, :, 0].sort(1, descending=True) + _, rank = idx.sort(1) + flt[(rank < self.top_k).unsqueeze(-1).expand_as(flt)].fill_(0) + return output, boxes, scores diff --git a/ScanSSD/layers/functions/prior_box.py b/ScanSSD/layers/functions/prior_box.py new file mode 100644 index 0000000000000000000000000000000000000000..47e4c4eff036724211ae56aea558ae0e49704b52 --- /dev/null +++ b/ScanSSD/layers/functions/prior_box.py @@ -0,0 +1,62 @@ +from __future__ import division +from math import sqrt as sqrt +from itertools import product as product +import torch + + +class PriorBox(object): + """Compute priorbox coordinates in center-offset form for each source + feature map. 
+ """ + def __init__(self, args, cfg): + super().__init__() + self.image_size = cfg['min_dim'] + # number of priors for feature map location (either 4 or 6) + self.num_priors = len(cfg['aspect_ratios']) + self.variance = cfg['variance'] or [0.1] + self.feature_maps = cfg['feature_maps'] + self.min_sizes = cfg['min_sizes'] + self.max_sizes = cfg['max_sizes'] + self.steps = cfg['steps'] + self.aspect_ratios = cfg['aspect_ratios'] + self.clip = cfg['clip'] + self.version = cfg['name'] + self.is_vertical_prior_boxes_enabled = cfg['is_vertical_prior_boxes_enabled'] + self.args = args + for v in self.variance: + if v <= 0: + raise ValueError('Variances must be greater than 0') + + def forward(self): + mean = [] + + for k, f in enumerate(self.feature_maps): + for i, j in product(range(f), repeat=2): + f_k = self.image_size / self.steps[k] + + # unit center x,y + cx = (j + 0.5) / f_k + cy = (i + 0.5) / f_k + + # aspect_ratio: 1 + # rel size: min_size + s_k = self.min_sizes[k]/self.image_size + mean += [cx, cy, s_k, s_k] + + # aspect_ratio: 1 + # rel size: sqrt(s_k * s_(k+1)) + s_k_prime = sqrt(s_k * (self.max_sizes[k]/self.image_size)) + mean += [cx, cy, s_k_prime, s_k_prime] + + # rest of aspect ratios + for ar in self.aspect_ratios[k]: + mean += [cx, cy, s_k*sqrt(ar), s_k/sqrt(ar)] + + if self.is_vertical_prior_boxes_enabled: + mean += [cx, cy, s_k/sqrt(ar), s_k*sqrt(ar)] + + # back to torch land + output = torch.Tensor(mean).view(-1, 4) + if self.clip: + output.clamp_(max=1, min=0) + return output diff --git a/ScanSSD/layers/modules/__init__.py b/ScanSSD/layers/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..529c3c7956e9c576bd41122cd85fbe0cfaa1478a --- /dev/null +++ b/ScanSSD/layers/modules/__init__.py @@ -0,0 +1,5 @@ +from .l2norm import L2Norm +from .multibox_loss import MultiBoxLoss +from .focal_loss import FocalLoss + +__all__ = ['L2Norm', 'MultiBoxLoss', 'FocalLoss'] diff --git a/ScanSSD/layers/modules/focal_loss.py 
class FocalLoss(nn.Module):
    """Focal loss (Lin et al., "Focal Loss for Dense Object Detection").

    NOTE(review): `logpt` below is the SUMMED cross entropy over all
    examples, so the modulating factor (1 - pt) ** gamma is applied to the
    batch total rather than per example — confirm this is intended.
    """

    def __init__(self, focusing_param=2, balance_param=0.25):  # TODO try changing balance_param
        super(FocalLoss, self).__init__()

        # gamma: down-weights easy examples; alpha: class balance weight
        self.focusing_param = focusing_param
        self.balance_param = balance_param

    def compute(self, output, target):
        """Return the balanced focal loss.

        Args:
            output: raw class logits, Shape: [N, num_classes].
            target: integer class labels, Shape: [N].
        """
        logpt = - F.cross_entropy(output, target, reduction='sum')
        pt = torch.exp(logpt)

        focal_loss = -((1 - pt) ** self.focusing_param) * logpt

        balanced_focal_loss = self.balance_param * focal_loss

        return balanced_focal_loss


if __name__ == '__main__':
    # BUG FIX: the original guard called test_focal_loss(), which is
    # commented out above, so running this module raised NameError.
    # Run a minimal smoke check instead.
    _loss = FocalLoss()
    _logits = torch.zeros(1, 2)
    _target = torch.zeros(1, dtype=torch.long)
    print(_loss.compute(_logits, _target))
class L2Norm(nn.Module):
    """Channel-wise L2 normalization with a learnable per-channel scale,
    applied to the conv4_3 feature map in SSD (ParseNet-style)."""

    def __init__(self, n_channels, scale):
        super(L2Norm, self).__init__()
        self.n_channels = n_channels
        self.gamma = scale or None
        self.eps = 1e-10
        self.weight = nn.Parameter(torch.Tensor(self.n_channels))
        self.reset_parameters()

    def reset_parameters(self):
        # every channel scale starts at gamma
        init.constant_(self.weight, self.gamma)

    def forward(self, x):
        """Normalize `x` (N, C, H, W) to unit L2 norm over channels and
        rescale each channel by its learned weight."""
        norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
        normalized = torch.div(x, norm)
        scale = self.weight.view(1, -1, 1, 1).expand_as(normalized)
        return scale * normalized
+ """ + + def __init__(self, args, cfg, overlap_thresh, bkg_label, neg_pos): + + super(MultiBoxLoss, self).__init__() + self.args = args + self.num_classes = cfg['num_classes'] + self.threshold = overlap_thresh + self.background_label = bkg_label + self.negpos_ratio = neg_pos + self.variance = cfg['variance'] + self.focal_loss = FocalLoss() + # self.neg_overlap = neg_overlap + # self.encode_target = encode_target + # self.use_prior_for_matching = prior_for_matching + # self.do_neg_mining = args.neg_mining + + def forward(self, predictions, targets): + """Multibox Loss + Args: + predictions (tuple): A tuple containing loc preds, conf preds, + and prior boxes from SSD net. + conf shape: torch.size(batch_size,num_priors,num_classes) + loc shape: torch.size(batch_size,num_priors,4) + priors shape: torch.size(num_priors,4) + + targets (tensor): Ground truth boxes and labels for a batch, + shape: [batch_size,num_objs,5] (last idx is the label). + """ + loc_data, conf_data, priors = predictions + num = loc_data.size(0) + priors = priors[:loc_data.size(1), :] + num_priors = (priors.size(0)) + num_classes = self.num_classes + + # match priors (default boxes) and ground truth boxes + loc_t = torch.Tensor(num, num_priors, 4) + conf_t = torch.LongTensor(num, num_priors) + + for idx in range(num): + truths = targets[idx][:, :-1].data + labels = targets[idx][:, -1].data + defaults = priors.data + match(self.threshold, truths, defaults, self.variance, labels, + loc_t, conf_t, idx) + + if self.args.cuda: + loc_t = loc_t.cuda() + conf_t = conf_t.cuda() + + # wrap targets + loc_t = Variable(loc_t, requires_grad=False) + conf_t = Variable(conf_t, requires_grad=False) + + pos = conf_t > 0 + num_pos = pos.sum(dim=1, keepdim=True) + + # Localization Loss (Smooth L1) + # Shape: [batch,num_priors,4] + pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data) + loc_p = loc_data[pos_idx].view(-1, 4) + loc_t = loc_t[pos_idx].view(-1, 4) + loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum') 
+ + # Compute max conf across batch for hard negative mining + batch_conf = conf_data.view(-1, self.num_classes) + + #print('conf_t view ', conf_t.view(-1, 1)) + #print('conf_t ' + conf_t.view(-1, 1)) + + loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1)) + + # Hard Negative Mining + if self.args.neg_mining: + loss_c = loss_c.view(pos.size()[0], pos.size()[1]) + loss_c = loss_c.view(num, -1) + loss_c[pos] = 0 # filter out pos boxes for now + _, loss_idx = loss_c.sort(1, descending=True) + _, idx_rank = loss_idx.sort(1) + num_pos = pos.long().sum(1, keepdim=True) + + num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1) + neg = idx_rank < num_neg.expand_as(idx_rank) + else: + #num_neg = torch.tensor(0).expand_as(idx_rank) + #num_neg[idx_rank] = 1 + neg = conf_t == 0 + + # Confidence Loss Including Positive and Negative Example + pos_idx = pos.unsqueeze(2).expand_as(conf_data) + neg_idx = neg.unsqueeze(2).expand_as(conf_data) + conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1, self.num_classes) + targets_weighted = conf_t[(pos+neg).gt(0)] + + if self.args.loss_fun == 'ce': + loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum') + else: + loss_c = self.focal_loss.compute(conf_p, targets_weighted) + + # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N + + N = num_pos.data.sum() + #loss_l = loss_l.double() + #loss_c = loss_c.double() + loss_l /= N + loss_c /= N + return loss_l, loss_c diff --git a/ScanSSD/ssd.py b/ScanSSD/ssd.py new file mode 100644 index 0000000000000000000000000000000000000000..12bf4af4a49a4c3b1d0a02a1f3ebcf4dc1fe62f4 --- /dev/null +++ b/ScanSSD/ssd.py @@ -0,0 +1,220 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +# from torch.autograd import Variable +from .layers import * +import os +import numpy as np + +base = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M', 512, 512, 512] + +class SSD(nn.Module): + """Single Shot Multibox Architecture 
class SSD(nn.Module):
    """Single Shot Multibox Architecture
    The network is composed of a base VGG network followed by the
    added multibox conv layers.  Each multibox layer branches into
        1) conv2d for class conf scores
        2) conv2d for localization predictions
        3) associated priorbox layer to produce default bounding
           boxes specific to the layer's feature map size.
    See: https://arxiv.org/pdf/1512.02325.pdf for more details.

    Args:
        phase: (string) Can be "test" or "train"
        size: input image size
        base: VGG16 layers for input, size of either 300 or 500
        extras: extra layers that feed to multibox loc and conf layers
        head: "multibox head" consists of loc and conf conv layers
    """

    def __init__(self, args, phase, cfg, size, base, extras, head, num_classes, gpu_id):
        super(SSD, self).__init__()
        self.phase = phase
        self.num_classes = num_classes
        self.cfg = cfg
        self.priorbox = PriorBox(args, self.cfg)
        # priors are fixed; keep them out of the autograd graph
        with torch.no_grad():
            self.priors = self.priorbox.forward()
        if torch.cuda.is_available():
            # BUG FIX: Tensor.to() is out-of-place; the original discarded
            # its result, so the priors silently stayed on the CPU.
            self.priors = self.priors.to(gpu_id)

        self.size = size

        # SSD network
        self.vgg = nn.ModuleList(base)
        # Layer learns to scale the l2 normalized features from conv4_3
        self.L2Norm = L2Norm(512, 20)
        self.extras = nn.ModuleList(extras)

        self.loc = nn.ModuleList(head[0])
        self.conf = nn.ModuleList(head[1])

        if phase == 'test':
            self.softmax = nn.Softmax(dim=-1)
            self.detect = Detect(cfg, num_classes, 0, 200, 0.01, 0.45)

    def forward(self, x):
        """Applies network layers and ops on input image(s) x.

        Args:
            x: input image or batch of images. Shape: [batch,3,300,300].

        Return:
            Depending on phase:
            test:
                Variable(tensor) of output class label predictions,
                confidence score, and corresponding location predictions for
                each object detected. Shape: [batch,topk,7]

            train:
                list of concat outputs from:
                    1: confidence layers, Shape: [batch*num_priors,num_classes]
                    2: localization layers, Shape: [batch,num_priors*4]
                    3: priorbox layers, Shape: [2,num_priors*4]
        """
        sources = list()
        loc = list()
        conf = list()

        # apply vgg up to conv4_3 relu
        for k in range(23):
            x = self.vgg[k](x)

        s = self.L2Norm(x)
        sources.append(s)

        # apply vgg up to fc7
        for k in range(23, len(self.vgg)):
            x = self.vgg[k](x)

        sources.append(x)

        # apply extra layers and cache source layer outputs
        for k, v in enumerate(self.extras):
            x = F.relu(v(x), inplace=True)
            if k % 2 == 1:
                sources.append(x)

        # apply multibox head to source layers
        for (x, l, c) in zip(sources, self.loc, self.conf):
            loc.append(l(x).permute(0, 2, 3, 1).contiguous())
            conf.append(c(x).permute(0, 2, 3, 1).contiguous())

        loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
        conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)

        if self.phase == "test":
            output, boxes, scores = self.detect.forward(
                loc.view(loc.size(0), -1, 4),                  # loc preds
                self.softmax(conf.view(conf.size(0), -1,
                                       self.num_classes)),     # conf preds
                self.priors.type(type(x.data))                 # default boxes
            )

            return output, boxes.detach(), scores.detach()
        else:
            output = (
                loc.view(loc.size(0), -1, 4),
                conf.view(conf.size(0), -1, self.num_classes),
                self.priors
            )

            return output

    def load_weights(self, base_file):
        """Load a serialized state dict (.pth/.pkl file) into this network."""
        other, ext = os.path.splitext(base_file)
        # BUG FIX: the original tested `ext == '.pkl' or '.pth'`, which is
        # always true because the non-empty string '.pth' is truthy, so the
        # unsupported-extension branch was unreachable.
        if ext in ('.pkl', '.pth'):
            print('Loading weights into state dict...')
            self.load_state_dict(torch.load(base_file,
                                            map_location=lambda storage, loc: storage))
            print('Finished!')
        else:
            print('Sorry only .pth and .pkl files supported.')
def vgg(cfg, i, batch_norm=False):
    """Build the VGG-16 backbone as a flat list of layers.

    Derived from torchvision's VGG make_layers().

    Args:
        cfg: sequence of channel counts interleaved with 'M' (maxpool) and
            'C' (ceil-mode maxpool) markers.
        i: number of input channels.
        batch_norm: insert BatchNorm2d after every conv when True.
    Returns:
        list of nn.Module layers (not yet wrapped in a container).
    """
    layers = []
    channels = i
    for entry in cfg:
        if entry == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        elif entry == 'C':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True))
        else:
            conv = nn.Conv2d(channels, entry, kernel_size=3, padding=1)
            if batch_norm:
                layers.extend([conv, nn.BatchNorm2d(entry), nn.ReLU(inplace=True)])
            else:
                layers.extend([conv, nn.ReLU(inplace=True)])
            channels = entry
    # fc6/fc7 converted to (dilated) convolutions, DeepLab-style
    layers.extend([
        nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
        nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
        nn.ReLU(inplace=True),
        nn.Conv2d(1024, 1024, kernel_size=1),
        nn.ReLU(inplace=True),
    ])
    return layers
def build_ssd(args, phase, cfg, gpu_id, size=300, num_classes=21):
    """Assemble a complete SSD model for the given phase.

    Args:
        args: parsed command-line options (kernel/padding etc.).
        phase: "train" or "test"; any other value prints an error and
            returns None.
        cfg: network configuration dict.
        gpu_id: device to place the priors on.
        size: input resolution (300 or 512).
        num_classes: number of classes including background.
    """
    if phase not in ("test", "train"):
        print("ERROR: Phase: " + phase + " not recognized")
        return

    base_, extras_, head_ = multibox(args, vgg(base, 3, False),
                                     add_extras(cfg, size, 1024),
                                     cfg['mbox'][str(size)], size, num_classes)

    return SSD(args, phase, cfg, size, base_, extras_, head_, num_classes, gpu_id)
forward pass + detections = y.data + + k = 0 + for img, meta in zip(images, metadata): + + img_id = meta[0] + x_l = meta[1] + y_l = meta[2] + + img = img.permute(1,2,0) + # scale each detection back up to the image + scale = torch.Tensor([img.shape[1], img.shape[0], + img.shape[1], img.shape[0]]) + + recognized_boxes = [] + recognized_scores = [] + + # [1,2,200,5] + # we only care about math class + # hence select detections[image_id, class, detection_id, detection_score] + # class=1 for math + i = 1 + j = 0 + + while j < detections.size(2) and detections[k, i, j, 0] >= thresh: # TODO it was 0.6 + + score = detections[k, i, j, 0] + pt = (detections[k, i, j, 1:] * args.window).cpu().numpy() + coords = (pt[0] + x_l, pt[1] + y_l, pt[2] + x_l, pt[3] + y_l) + #coords = (pt[0], pt[1], pt[2], pt[3]) + recognized_boxes.append(coords) + recognized_scores.append(score.cpu().numpy()) + + j += 1 + + save_boxes(args, recognized_boxes, recognized_scores, img_id) + k = k + 1 + + if args.verbose: + draw_boxes(args, img.cpu().numpy(), recognized_boxes, recognized_scores, + debug_boxes, debug_scores, scale, img_id) + +def test_gtdb(args): + + gpu_id = 0 + if args.cuda: + gpu_id = helpers.get_freer_gpu() + torch.cuda.set_device(gpu_id) + + # load net + num_classes = 2 # +1 background + + # initialize SSD + net = build_ssd(args, 'test', exp_cfg[args.cfg], gpu_id, args.model_type, num_classes) + + logging.debug(net) + net.to(gpu_id) + net = nn.DataParallel(net) + net.load_state_dict(torch.load(args.trained_model, map_location={'cuda:1':'cuda:0'})) + net.eval() + logging.debug('Finished loading model!') + + dataset = GTDBDetection(args, args.test_data, split='test', + transform=BaseTransform(args.model_type, (246,246,246)), + target_transform=GTDBAnnotationTransform()) + + if args.cuda: + net = net.to(gpu_id) + cudnn.benchmark = True + + # evaluation + test_net_batch(args, net, gpu_id, dataset, + BaseTransform(args.model_type, (246,246,246)), + thresh=args.visual_threshold) + +def 
parse_args(): + parser = argparse.ArgumentParser(description='Single Shot MultiBox Detection') + parser.add_argument('--trained_model', default='weights/ssd300_GTDB_990.pth', + type=str, help='Trained state_dict file path to open') + parser.add_argument('--save_folder', default='eval/', type=str, + help='Dir to save results') + parser.add_argument('--visual_threshold', default=0.6, type=float, + help='Final confidence threshold') + parser.add_argument('--cuda', default=False, type=bool, + help='Use cuda to train model') + parser.add_argument('--dataset_root', default=GTDB_ROOT, help='Location of VOC root directory') + parser.add_argument('--test_data', default="testing_data", help='testing data file') + parser.add_argument('--verbose', default=False, type=bool, help='plot output') + parser.add_argument('--suffix', default="_10", type=str, help='suffix of directory of images for testing') + parser.add_argument('--exp_name', default="SSD", help='Name of the experiment. Will be used to generate output') + parser.add_argument('--model_type', default=300, type=int, + help='Type of ssd model, ssd300 or ssd512') + parser.add_argument('--use_char_info', default=False, type=bool, help='Whether or not to use char info') + parser.add_argument('--limit', default=-1, type=int, help='limit on number of test examples') + parser.add_argument('--cfg', default="gtdb", type=str, + help='Type of network: either gtdb or math_gtdb_512') + parser.add_argument('--batch_size', default=16, type=int, + help='Batch size for training') + parser.add_argument('--num_workers', default=4, type=int, + help='Number of workers used in data loading') + parser.add_argument('--kernel', default="3 3", type=int, nargs='+', + help='Kernel size for feature layers: 3 3 or 1 5') + parser.add_argument('--padding', default="1 1", type=int, nargs='+', + help='Padding for feature layers: 1 1 or 0 2') + parser.add_argument('--neg_mining', default=True, type=bool, + help='Whether or not to use hard negative mining 
with ratio 1:3') + parser.add_argument('--log_dir', default="logs", type=str, + help='dir to save the logs') + parser.add_argument('--stride', default=0.1, type=float, + help='Stride to use for sliding window') + parser.add_argument('--window', default=1200, type=int, + help='Sliding window size') + + parser.add_argument('-f', default=None, type=str, help="Dummy arg so we can load in Jupyter Notebooks") + + args = parser.parse_args() + + if args.cuda and torch.cuda.is_available(): + torch.set_default_tensor_type('torch.cuda.FloatTensor') + else: + torch.set_default_tensor_type('torch.FloatTensor') + + if not os.path.exists(args.save_folder): + os.mkdir(args.save_folder) + + if os.path.exists(os.path.join(args.save_folder, args.exp_name)): + shutil.rmtree(os.path.join(args.save_folder, args.exp_name)) + + return args + +if __name__ == '__main__': + + args = parse_args() + start = time.time() + try: + filepath=os.path.join(args.log_dir, args.exp_name + "_" + str(round(time.time())) + ".log") + print('Logging to ' + filepath) + logging.basicConfig(filename=filepath, + filemode='w', format='%(process)d - %(asctime)s - %(message)s', + datefmt='%d-%b-%y %H:%M:%S', level=logging.DEBUG) + + test_gtdb(args) + except Exception as e: + logging.error("Exception occurred", exc_info=True) + + end = time.time() + logging.debug('Toal time taken ' + str(datetime.timedelta(seconds=end-start))) + logging.debug("Testing done!") + diff --git a/ScanSSD/train.py b/ScanSSD/train.py new file mode 100644 index 0000000000000000000000000000000000000000..65c1b7d9288f9a64c4031fb2717f0a4ba442de56 --- /dev/null +++ b/ScanSSD/train.py @@ -0,0 +1,470 @@ +# Sample command +# python3 train.py --dataset GTDB --dataset_root /home/psm2208/data/GTDB/ +# --cuda True --visdom True --batch_size 16 --num_workers 8 --layers_to_freeze 0 +# --exp_name weights_1 --model_type 512 --suffix _512 --type processed_train_512 +# --cfg math_gtdb_512 --loss_fun fl --kernel 1 5 --padding 0 2 --neg_mining False + +from 
data import * +from utils.augmentations import SSDAugmentation +from layers.modules import MultiBoxLoss +from ssd import build_ssd +import os +import sys +import torch +from torch.autograd import Variable +import torch.nn as nn +import torch.optim as optim +import torch.backends.cudnn as cudnn +import torch.nn.init as init +import torch.utils.data as data +import argparse +from utils import helpers +import logging +import time +import datetime +from torchviz import make_dot + +def str2bool(v): + return v.lower() in ("yes", "true", "t", "1") + +def train(args): + + cfg = exp_cfg[args.cfg] + dataset = GTDBDetection(args, args.training_data, split='train', + transform=SSDAugmentation(cfg['min_dim'], mean=MEANS)) + + if args.visdom: + import visdom + viz = visdom.Visdom() + + gpu_id = 0 + + if args.cuda: + gpu_id = helpers.get_freer_gpu() + logging.debug('Using GPU with id ' + str(gpu_id)) + torch.cuda.set_device(gpu_id) + + ssd_net = build_ssd(args, 'train', cfg, gpu_id, cfg['min_dim'], cfg['num_classes']) + + logging.debug(ssd_net) + + ct = 0 + # freeze first few layers + for child in ssd_net.vgg.children(): + if ct >= args.layers_to_freeze: + break + + child.requires_grad = False + ct += 1 + + + if args.resume: + logging.debug('Resuming training, loading {}...'.format(args.resume)) + ssd_net.load_state_dict(torch.load(args.resume)) + else: + vgg_weights = torch.load("base_weights/" + args.basenet) + logging.debug('Loading base network...') + ssd_net.vgg.load_state_dict(vgg_weights) + + #visualize(ssd_net, gpu_id) + + # if args.cuda: + # net = net.cuda() + step_index = 0 + + if not args.resume: + logging.debug('Initializing weights...') + # initialize newly added layers' weights with xavier method + ssd_net.extras.apply(weights_init) + ssd_net.loc.apply(weights_init) + ssd_net.conf.apply(weights_init) + + for val in cfg['lr_steps']: + if args.start_iter > val: + step_index = step_index + 1 + + # Saving random initialized weights + torch.save(ssd_net.state_dict(), + 
os.path.join( + 'weights_' + args.exp_name, 'initial_' + str(args.model_type) + args.dataset + '.pth')) + + optimizer = optim.SGD(ssd_net.parameters(), lr=args.lr, momentum=args.momentum, + weight_decay=args.weight_decay) + + adjust_learning_rate(args, optimizer, args.gamma, step_index) + + #args, cfg, overlap_thresh, bkg_label, neg_pos + #criterion = MultiBoxLoss(args, cfg, 0.5, 0, 3) + criterion = MultiBoxLoss(args, cfg, args.pos_thresh, 0, 3) + + if args.cuda: + ssd_net = torch.nn.DataParallel(ssd_net) + # ssd_net = ssd_net.to(gpu_id) + cudnn.benchmark = True + + ssd_net.train() + + # loss counters + loc_loss = 0 + conf_loss = 0 + min_total_loss = float('inf') + epoch = 0 + logging.debug('Loading the dataset...') + + epoch_size = len(dataset) // args.batch_size + + logging.debug('Training SSD on:' + dataset.name) + logging.debug('Using the specified args:') + logging.debug(args) + + + if args.visdom: + vis_title = args.exp_name + vis_legend = ['Location Loss', 'Confidence Loss', 'Total Loss'] + iter_plot = create_vis_plot('Iteration', 'Loss', viz, 'Training ' + vis_title, vis_legend) + epoch_plot = create_vis_plot('Epoch', 'Loss', viz, 'Training ' + vis_title, vis_legend) + + data_loader = data.DataLoader(dataset, args.batch_size, + num_workers=args.num_workers, + shuffle=True, collate_fn=detection_collate, + pin_memory=True) + + logging.debug('Training set size is ' + str(len(dataset))) + + # create batch iterator + batch_iterator = iter(data_loader) + for iteration in range(args.start_iter, cfg['max_iter']): + + # resume training + ssd_net.train() + + t0 = time.time() + + if iteration in cfg['lr_steps']: + step_index += 1 + adjust_learning_rate(args, optimizer, args.gamma, step_index) + + # load train data + try: + images, targets, _ = next(batch_iterator) + except StopIteration: + batch_iterator = iter(data_loader) + images, targets, _ = next(batch_iterator) + + if args.cuda: + images = images.cuda() + targets = [ann.cuda() for ann in targets] + else: + 
images = Variable(images) + targets = [Variable(ann, volatile=True) for ann in targets] + + # forward + out = ssd_net(images) + + # backprop + optimizer.zero_grad() + loss_l, loss_c = criterion(out, targets) + loss = args.alpha * loss_l + loss_c #TODO. For now alpha should be 1. While plotting alpha is assumed to be 1 + loss.backward() + optimizer.step() + + loc_loss += loss_l.item() + conf_loss += loss_c.item() + + t1 = time.time() + + # Log progress + if iteration % 10 == 0: + logging.debug('timer: %.4f sec.' % (t1 - t0)) + logging.debug('iter ' + repr(iteration) + ' || Loss: %.4f ||' % (loss.item())) + + if args.visdom: + update_vis_plot(iteration, loss_l.item(), viz, loss_c.item(), + iter_plot, epoch_plot, 'append') + + if iteration != 0 and iteration % 1000 == 0: + logging.debug('Saving state, iter:' + str(iteration)) + torch.save(ssd_net.state_dict(), + os.path.join( + 'weights_' + args.exp_name, 'ssd' + str(args.model_type) + args.dataset + + repr(iteration) + '.pth')) + + + if iteration!=0 and (iteration % epoch_size == 0): + epoch += 1 + + torch.save(ssd_net.state_dict(), + os.path.join( + 'weights_' + args.exp_name, 'epoch_ssd' + str(args.model_type) + args.dataset + + repr(epoch) + '.pth')) + + + train_loss = loc_loss + conf_loss + update_vis_plot(epoch, loc_loss, viz, conf_loss, epoch_plot, None, + 'append', epoch_size) + + if args.validation_data != '': + # Validate data + validation_loss = validate(args, ssd_net, criterion, cfg) + + if epoch == 1: + validation_plot = create_validation_plot(epoch, validation_loss, + 'Epoch', 'Loss', viz, 'Validating ' + vis_title, + ['Validation']) + else: + update_validation_plot(epoch, validation_loss, viz, + validation_plot, 'append') + + if validation_loss < min_total_loss: + min_total_loss = validation_loss + torch.save(ssd_net.state_dict(), + os.path.join( + 'weights_' + args.exp_name, 'best_ssd' + str(args.model_type) + args.dataset + + repr(iteration) + '.pth')) + + # reset epoch loss counters + loc_loss = 0 + 
conf_loss = 0 + + torch.save(ssd_net.state_dict(), + args.exp_name + '' + args.dataset + '.pth') + + logging.debug("Final weights are saved at " + args.exp_name + '' + args.dataset + '.pth') + +def validate(args, net, criterion, cfg): + + validation_batch_size = 1 + try: + # Turn off learning. Go to testing phase + net.eval() + + dataset = GTDBDetection(args, args.validation_data, split='validate', + transform=SSDAugmentation(cfg['min_dim'], mean=MEANS)) + + data_loader = data.DataLoader(dataset, validation_batch_size, + num_workers=args.num_workers, + shuffle=False, collate_fn=detection_collate, + pin_memory=True) + + total = len(dataset) + done = 0 + loc_loss = 0 + conf_loss = 0 + + start = time.time() + + for batch_idx, (images, targets, ids) in enumerate(data_loader): + + done = done + len(images) + logging.debug('processing {}/{}'.format(done, total)) + + if args.cuda: + images = images.cuda() + targets = [ann.cuda() for ann in targets] + else: + images = Variable(images) + targets = [Variable(ann, volatile=True) for ann in targets] + + y = net(images) # forward pass + + loss_l, loss_c = criterion(y, targets) + loc_loss += loss_l.item() # data[0] + conf_loss += loss_c.item() # data[0] + + end = time.time() + logging.debug('Time taken for validation ' + str(datetime.timedelta(seconds=end - start))) + + return (loc_loss + conf_loss) / (total/validation_batch_size) + except Exception as e: + logging.error("Could not validate", exc_info=True) + return 0 + +def adjust_learning_rate(args, optimizer, gamma, step): + """Sets the learning rate to the initial LR decayed by 10 at every + specified step + # Adapted from PyTorch Imagenet example: + # https://github.com/pytorch/examples/blob/master/imagenet/main.py + """ + lr = args.lr * (gamma ** (step)) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + +def visualize(ssd_net, gpu_id): + x = np.zeros((300,300,3)) + x = torch.from_numpy(x).permute(2, 0, 1) + x = x.unsqueeze(0) + x = x.to(gpu_id, 
dtype=torch.float) + ssd_net.eval() + y = ssd_net(x) + make_dot(y[0], params=dict(ssd_net.named_parameters())).render(filename='ssd_net') + + +def xavier(param): + init.xavier_uniform_(param) + + +def weights_init(m): + if isinstance(m, nn.Conv2d): + xavier(m.weight.data) + m.bias.data.zero_() + +def create_validation_plot(epoch, validation_loss, _xlabel, _ylabel, viz, _title, _legend): + return viz.line( + X=torch.ones((1, 1)).cpu() * epoch, + Y=torch.Tensor([validation_loss]).unsqueeze(0).cpu(), + opts=dict( + xlabel=_xlabel, + ylabel=_ylabel, + title=_title, + legend=_legend + ) + ) + + +def create_vis_plot(_xlabel, _ylabel, viz, _title, _legend): + return viz.line( + X=torch.zeros((1,)).cpu(), + Y=torch.zeros((1, len(_legend))).cpu(), + opts=dict( + xlabel=_xlabel, + ylabel=_ylabel, + title=_title, + legend=_legend + ) + ) + + +def update_validation_plot(epoch, validation_loss, + viz, window, update_type): + + viz.line( + X=torch.ones((1, 1)).cpu() * epoch, + Y=torch.Tensor([validation_loss]).unsqueeze(0).cpu(), + win=window, + update=update_type + ) + + +def update_vis_plot(iteration, loc, viz, conf, window1, window2, update_type, + epoch_size=1): + viz.line( + X=torch.ones((1, 3)).cpu() * iteration, + Y=torch.Tensor([loc, conf, loc + conf]).unsqueeze(0).cpu() / epoch_size, + win=window1, + update=update_type + ) + # initialize epoch plot on first iteration + if iteration == 0: + viz.line( + X=torch.zeros((1, 3)).cpu(), + Y=torch.Tensor([loc, conf, loc + conf]).unsqueeze(0).cpu(), + win=window2, + update=True + ) + +def init_args(): + ''' + Read arguments and initialize directories + :return: args + ''' + + parser = argparse.ArgumentParser( + description='Single Shot MultiBox Detector Training With Pytorch') + train_set = parser.add_mutually_exclusive_group() + parser.add_argument('--dataset', default='GTDB', choices=['GTDB'], + type=str, help='choose GTDB') + parser.add_argument('--dataset_root', default=GTDB_ROOT, + help='Dataset root directory path') + 
parser.add_argument('--basenet', default='vgg16_reducedfc.pth', + help='Pretrained base model') + parser.add_argument('--batch_size', default=32, type=int, + help='Batch size for training') + parser.add_argument('--resume', default=None, type=str, + help='Checkpoint state_dict file to resume training from') + parser.add_argument('--start_iter', default=0, type=int, + help='Resume training at this iter') + parser.add_argument('--num_workers', default=4, type=int, + help='Number of workers used in data loading') + parser.add_argument('--cuda', default=False, type=bool, + help='Use CUDA to train model') + parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float, + help='initial learning rate') + parser.add_argument('--momentum', default=0.9, type=float, + help='Momentum value for optim') + parser.add_argument('--weight_decay', default=5e-4, type=float, + help='Weight decay for SGD') + parser.add_argument('--alpha', default=1.0, type=float, + help='Alpha for the multibox loss') + parser.add_argument('--gamma', default=0.1, type=float, + help='Gamma update for SGD') + parser.add_argument('--visdom', default=False, type=bool, + help='Use visdom for loss visualization') + parser.add_argument('--exp_name', default='math_detector', # changed to exp_name from --save_folder + help='It is the name of the experiment. Weights are saved in the directory with same name.') + parser.add_argument('--layers_to_freeze', default=20, type=float, + help='Number of VGG16 layers to freeze') + parser.add_argument('--model_type', default=300, type=int, + help='Type of ssd model, ssd300 or ssd512') + parser.add_argument('--suffix', default="_10", type=str, + help='Stride % used while generating images or dpi from which images was generated or some other identifier') + parser.add_argument('--training_data', default="training_data", type=str, + help='Training data to use. 
This is list of file names, one per line') + parser.add_argument('--validation_data', default="", type=str, + help='Validation data to use. This is list of file names, one per line') + parser.add_argument('--use_char_info', default=False, type=bool, + help='Whether to use char position info and labels') + parser.add_argument('--cfg', default="ssd512", type=str, + help='Type of network: either gtdb or math_gtdb_512') + parser.add_argument('--loss_fun', default="fl", type=str, + help='Type of loss: either fl (focal loss) or ce (cross entropy)') + parser.add_argument('--kernel', default="3 3", type=int, nargs='+', + help='Kernel size for feature layers: 3 3 or 1 5') + parser.add_argument('--padding', default="1 1", type=int, nargs='+', + help='Padding for feature layers: 1 1 or 0 2') + parser.add_argument('--neg_mining', default=False, type=bool, + help='Whether or not to use hard negative mining with ratio 1:3') + parser.add_argument('--log_dir', default="logs", type=str, + help='dir to save the logs') + parser.add_argument('--stride', default=0.1, type=float, + help='Stride to use for sliding window') + parser.add_argument('--window', default=1200, type=int, + help='Sliding window size') + parser.add_argument('--pos_thresh', default=0.5, type=float, + help='All default boxes with iou>pos_thresh are considered as positive examples') + + args = parser.parse_args() + + if torch.cuda.is_available(): + if args.cuda: + torch.set_default_tensor_type('torch.cuda.FloatTensor') + if not args.cuda: + logging.warning("WARNING: It looks like you have a CUDA device, but aren't " + + "using CUDA.\nRun with --cuda for optimal training speed.") + torch.set_default_tensor_type('torch.FloatTensor') + else: + torch.set_default_tensor_type('torch.FloatTensor') + + if not os.path.exists("weights_" + args.exp_name): + os.mkdir("weights_" + args.exp_name) + + return args + +if __name__ == '__main__': + + args = init_args() + start = time.time() + + try: + 
filepath=os.path.join(args.log_dir, args.exp_name + "_" + str(round(time.time())) + ".log") + print('Logging to ' + filepath) + logging.basicConfig(filename=filepath, + filemode='w', format='%(process)d - %(asctime)s - %(message)s', + datefmt='%d-%b-%y %H:%M:%S', level=logging.DEBUG) + + train(args) + except Exception as e: + logging.error("Exception occurred", exc_info=True) + + end = time.time() + logging.debug('Total time taken ' + str(datetime.timedelta(seconds=end - start))) + logging.debug("Training done!") diff --git a/ScanSSD/utils/__init__.py b/ScanSSD/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..66e4f7d3b7294736aa99dece42e0d6f973aa4a92 --- /dev/null +++ b/ScanSSD/utils/__init__.py @@ -0,0 +1,2 @@ +from .augmentations import SSDAugmentation +from .visualize import draw_boxes, save_boxes diff --git a/ScanSSD/utils/augmentations.py b/ScanSSD/utils/augmentations.py new file mode 100644 index 0000000000000000000000000000000000000000..67992ad00e50093366c9ea1adb0320f7c14ee56f --- /dev/null +++ b/ScanSSD/utils/augmentations.py @@ -0,0 +1,424 @@ +import torch +from torchvision import transforms +import cv2 +import numpy as np +import types +from numpy import random +from matplotlib import pyplot as plt +from PIL import Image, ImageOps + +def intersect(box_a, box_b): + max_xy = np.minimum(box_a[:, 2:], box_b[2:]) + min_xy = np.maximum(box_a[:, :2], box_b[:2]) + inter = np.clip((max_xy - min_xy), a_min=0, a_max=np.inf) + return inter[:, 0] * inter[:, 1] + + +def jaccard_numpy(box_a, box_b): + """Compute the jaccard overlap of two sets of boxes. The jaccard overlap + is simply the intersection over union of two boxes. 
+ E.g.: + A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) + Args: + box_a: Multiple bounding boxes, Shape: [num_boxes,4] + box_b: Single bounding box, Shape: [4] + Return: + jaccard overlap: Shape: [box_a.shape[0], box_a.shape[1]] + """ + inter = intersect(box_a, box_b) + area_a = ((box_a[:, 2]-box_a[:, 0]) * + (box_a[:, 3]-box_a[:, 1])) # [A,B] + area_b = ((box_b[2]-box_b[0]) * + (box_b[3]-box_b[1])) # [A,B] + union = area_a + area_b - inter + return inter / union # [A,B] + + +class Compose(object): + """Composes several augmentations together. + Args: + transforms (List[Transform]): list of transforms to compose. + Example: + >>> augmentations.Compose([ + >>> transforms.CenterCrop(10), + >>> transforms.ToTensor(), + >>> ]) + """ + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, img, boxes=None, labels=None): + for t in self.transforms: + img, boxes, labels = t(img, boxes, labels) + return img, boxes, labels + + +class Lambda(object): + """Applies a lambda as a transform.""" + + def __init__(self, lambd): + assert isinstance(lambd, types.LambdaType) + self.lambd = lambd + + def __call__(self, img, boxes=None, labels=None): + return self.lambd(img, boxes, labels) + + +class ConvertFromInts(object): + def __call__(self, image, boxes=None, labels=None): + return image.astype(np.float32), boxes, labels + + +class SubtractMeans(object): + def __init__(self, mean): + self.mean = np.array(mean, dtype=np.float32) + + def __call__(self, image, boxes=None, labels=None): + image = image.astype(np.float32) + image -= self.mean + return image.astype(np.float32), boxes, labels + + +class ToAbsoluteCoords(object): + def __call__(self, image, boxes=None, labels=None): + height, width, channels = image.shape + boxes[:, 0] *= width + boxes[:, 2] *= width + boxes[:, 1] *= height + boxes[:, 3] *= height + + return image, boxes, labels + + +class ToPercentCoords(object): + def __call__(self, image, boxes=None, labels=None): + height, width, 
channels = image.shape + boxes[:, 0] /= width + boxes[:, 2] /= width + boxes[:, 1] /= height + boxes[:, 3] /= height + + return image, boxes, labels + + +class Resize(object): + def __init__(self, size=300): + self.size = size + + def __call__(self, image, boxes=None, labels=None): + # plt.imshow(image) + # plt.savefig("eval/bface.png", dpi=600) + # plt.close() + + image = cv2.resize(image, (self.size,self.size), interpolation=cv2.INTER_AREA) + # plt.imshow(image) + # plt.savefig("eval/face.png", dpi=600) + # plt.close() + return image, boxes, labels + + +class RandomSaturation(object): + def __init__(self, lower=0.5, upper=1.5): + self.lower = lower + self.upper = upper + assert self.upper >= self.lower, "contrast upper must be >= lower." + assert self.lower >= 0, "contrast lower must be non-negative." + + def __call__(self, image, boxes=None, labels=None): + if random.randint(2): + image[:, :, 1] *= random.uniform(self.lower, self.upper) + + return image, boxes, labels + + +class RandomHue(object): + def __init__(self, delta=18.0): + assert delta >= 0.0 and delta <= 360.0 + self.delta = delta + + def __call__(self, image, boxes=None, labels=None): + if random.randint(2): + image[:, :, 0] += random.uniform(-self.delta, self.delta) + image[:, :, 0][image[:, :, 0] > 360.0] -= 360.0 + image[:, :, 0][image[:, :, 0] < 0.0] += 360.0 + return image, boxes, labels + + +class RandomLightingNoise(object): + def __init__(self): + self.perms = ((0, 1, 2), (0, 2, 1), + (1, 0, 2), (1, 2, 0), + (2, 0, 1), (2, 1, 0)) + + def __call__(self, image, boxes=None, labels=None): + if random.randint(2): + swap = self.perms[random.randint(len(self.perms))] + shuffle = SwapChannels(swap) # shuffle channels + image = shuffle(image) + return image, boxes, labels + + +class ConvertColor(object): + def __init__(self, current='BGR', transform='HSV'): + self.transform = transform + self.current = current + + def __call__(self, image, boxes=None, labels=None): + if self.current == 'BGR' and 
self.transform == 'HSV': + image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) + elif self.current == 'HSV' and self.transform == 'BGR': + image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) + else: + raise NotImplementedError + return image, boxes, labels + + +class RandomContrast(object): + def __init__(self, lower=0.5, upper=1.5): + self.lower = lower + self.upper = upper + assert self.upper >= self.lower, "contrast upper must be >= lower." + assert self.lower >= 0, "contrast lower must be non-negative." + + # expects float image + def __call__(self, image, boxes=None, labels=None): + if random.randint(2): + alpha = random.uniform(self.lower, self.upper) + image *= alpha + return image, boxes, labels + + +class RandomBrightness(object): + def __init__(self, delta=32): + assert delta >= 0.0 + assert delta <= 255.0 + self.delta = delta + + def __call__(self, image, boxes=None, labels=None): + if random.randint(2): + delta = random.uniform(-self.delta, self.delta) + image += delta + return image, boxes, labels + + +class ToCV2Image(object): + def __call__(self, tensor, boxes=None, labels=None): + return tensor.cpu().numpy().astype(np.float32).transpose((1, 2, 0)), boxes, labels + + +class ToTensor(object): + def __call__(self, cvimage, boxes=None, labels=None): + return torch.from_numpy(cvimage.astype(np.float32)).permute(2, 0, 1), boxes, labels + + +class RandomSampleCrop(object): + """Crop + Arguments: + img (Image): the image being input during training + boxes (Tensor): the original bounding boxes in pt form + labels (Tensor): the class labels for each bbox + mode (float tuple): the min and max jaccard overlaps + Return: + (img, boxes, classes) + img (Image): the cropped image + boxes (Tensor): the adjusted bounding boxes in pt form + labels (Tensor): the class labels for each bbox + """ + def __init__(self): + self.sample_options = ( + # using entire original input image + None, + # sample a patch s.t. 
MIN jaccard w/ obj in .1,.3,.4,.7,.9 + (0.1, None), + (0.3, None), + (0.5, None), + (0.7, None), + (0.9, None), + # randomly sample a patch + (None, None), + ) + + def __call__(self, image, boxes=None, labels=None): + height, width, _ = image.shape + while True: + # randomly choose a mode + mode = random.choice(self.sample_options) + if mode is None: + return image, boxes, labels + + min_iou, max_iou = mode + if min_iou is None: + min_iou = float('-inf') + if max_iou is None: + max_iou = float('inf') + + # max trails (50) + for _ in range(50): + current_image = image + + w = random.uniform(0.3 * width, width) + h = random.uniform(0.3 * height, height) + + # aspect ratio constraint b/t .5 & 2 + if h / w < 0.5 or h / w > 2: + continue + + left = random.uniform(width - w) + top = random.uniform(height - h) + + # convert to integer rect x1,y1,x2,y2 + rect = np.array([int(left), int(top), int(left+w), int(top+h)]) + + # calculate IoU (jaccard overlap) b/t the cropped and gt boxes + overlap = jaccard_numpy(boxes, rect) + + # is min and max overlap constraint satisfied? if not try again + if overlap.max() < min_iou or overlap.min() > max_iou: + continue + + # cut the crop from the image + current_image = current_image[rect[1]:rect[3], rect[0]:rect[2],:] + + # keep overlap with gt box IF center in sampled patch + centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0 + + # mask in all gt boxes that above and to the left of centers + m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1]) + + # mask in all gt boxes that under and to the right of centers + m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1]) + + # mask in that both m1 and m2 are true + mask = m1 * m2 + + # have any valid boxes? 
try again if not + if not mask.any(): + continue + + # take only matching gt boxes + current_boxes = boxes[mask, :].copy() + + # take only matching gt labels + current_labels = labels[mask] + + # should we use the box left and top corner or the crop's + current_boxes[:, :2] = np.maximum(current_boxes[:, :2], + rect[:2]) + # adjust to crop (by substracting crop's left,top) + current_boxes[:, :2] -= rect[:2] + + current_boxes[:, 2:] = np.minimum(current_boxes[:, 2:], + rect[2:]) + # adjust to crop (by substracting crop's left,top) + current_boxes[:, 2:] -= rect[:2] + + return current_image, current_boxes, current_labels + + +class Expand(object): + def __init__(self, mean): + self.mean = mean + + def __call__(self, image, boxes, labels): + if random.randint(2): + return image, boxes, labels + + height, width, depth = image.shape + ratio = random.uniform(1, 4) + left = random.uniform(0, width*ratio - width) + top = random.uniform(0, height*ratio - height) + + expand_image = np.zeros( + (int(height*ratio), int(width*ratio), depth), + dtype=image.dtype) + expand_image[:, :, :] = self.mean + expand_image[int(top):int(top + height), + int(left):int(left + width)] = image + image = expand_image + + boxes = boxes.copy() + boxes[:, :2] += (int(left), int(top)) + boxes[:, 2:] += (int(left), int(top)) + + return image, boxes, labels + + +class RandomMirror(object): + def __call__(self, image, boxes, classes): + _, width, _ = image.shape + if random.randint(2): + image = image[:, ::-1] + boxes = boxes.copy() + boxes[:, 0::2] = width - boxes[:, 2::-2] + return image, boxes, classes + + +class SwapChannels(object): + """Transforms a tensorized image by swapping the channels in the order + specified in the swap tuple. 
+ Args: + swaps (int triple): final order of channels + eg: (2, 1, 0) + """ + + def __init__(self, swaps): + self.swaps = swaps + + def __call__(self, image): + """ + Args: + image (Tensor): image tensor to be transformed + Return: + a tensor with channels swapped according to swap + """ + # if torch.is_tensor(image): + # image = image.data.cpu().numpy() + # else: + # image = np.array(image) + image = image[:, :, self.swaps] + return image + + +class PhotometricDistort(object): + def __init__(self): + self.pd = [ + RandomContrast(), + ConvertColor(transform='HSV'), + RandomSaturation(), + RandomHue(), + ConvertColor(current='HSV', transform='BGR'), + RandomContrast() + ] + self.rand_brightness = RandomBrightness() + self.rand_light_noise = RandomLightingNoise() + + def __call__(self, image, boxes, labels): + im = image.copy() + im, boxes, labels = self.rand_brightness(im, boxes, labels) + if random.randint(2): + distort = Compose(self.pd[:-1]) + else: + distort = Compose(self.pd[1:]) + im, boxes, labels = distort(im, boxes, labels) + return self.rand_light_noise(im, boxes, labels) + + +class SSDAugmentation(object): + def __init__(self, size=300, mean=(104, 117, 123)): + self.mean = mean + self.size = size + self.augment = Compose([ + ConvertFromInts(), + ToAbsoluteCoords(), + PhotometricDistort(), + Expand(self.mean), + RandomSampleCrop(), + #RandomMirror(), + ToPercentCoords(), + Resize(self.size), + SubtractMeans(self.mean) + ]) + + def __call__(self, img, boxes, labels): + return self.augment(img, boxes, labels) diff --git a/ScanSSD/utils/helpers.py b/ScanSSD/utils/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..5a8daa571af7ed1f9298868e740da8a88a69881c --- /dev/null +++ b/ScanSSD/utils/helpers.py @@ -0,0 +1,10 @@ +import os +import numpy as np + +def get_freer_gpu(): + ''' + Find which gpu is free + ''' + os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp') + memory_available = [int(x.split()[2]) for x in open('tmp', 
'r').readlines()] + return int(np.argmax(memory_available)) \ No newline at end of file diff --git a/ScanSSD/utils/visualize.py b/ScanSSD/utils/visualize.py new file mode 100644 index 0000000000000000000000000000000000000000..a3735116cadc85123677ef04e8102bdca3b330e2 --- /dev/null +++ b/ScanSSD/utils/visualize.py @@ -0,0 +1,202 @@ +''' +This file contains functions to visualize the heatmap and detected bounding boxes +''' + +import matplotlib +matplotlib.use('Agg') + +import matplotlib.pyplot as plt +import matplotlib.patches as patches +import os +import numpy as np +import cv2 + +def draw_stitched_boxes(im, data, outpath): + + # Create figure and axes + fig, ax = plt.subplots(1) + + # sort based on the confs. Confs is column 4 + data = data[data[:, 4].argsort()] + + # Display the image + ax.imshow(im) + + width, height, channels = im.shape + heatmap = np.zeros([width, height]) + + for box in data: + heatmap[int(box[1]):int(box[3]), int(box[0]):int(box[2])] = box[4] + + # Following line makes sure that all the heatmaps are in the scale, 0 to 1 + # So color assigned to different scores are consistent across heatmaps for + # different images + heatmap[0:1, 0:1] = 1 + heatmap[0:1, 1:2] = 0 + + plt.imshow(heatmap, alpha=0.4, cmap='hot', interpolation='nearest') + plt.colorbar() + + plt.title("Stitching visualization") + plt.show() + plt.savefig(outpath, dpi=600) + plt.close() + + +def draw_all_boxes(im, data, recognized_boxes, gt_boxes, outpath): + + if len(data) == 0: + return + + # Create figure and axes + fig, ax = plt.subplots(1) + + # sort based on the confs. 
Confs is column 4 + data = data[data[:, 4].argsort()] + + # Display the image + ax.imshow(im) + + width, height, channels = im.shape + heatmap = np.zeros([width, height]) + + if data is not None: + for box in data: + heatmap[int(box[1]):int(box[3]), int(box[0]):int(box[2])] = box[4] + #rect = patches.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1], + # linewidth=0.25, edgecolor='m', facecolor='none') + #Add the patch to the Axes + #ax.add_patch(rect) + + if recognized_boxes is not None: + # recognized boxes are green + for box in recognized_boxes: + rect = patches.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1], + linewidth=1, edgecolor='g', facecolor='none') + # Add the patch to the Axes + ax.add_patch(rect) + + + if gt_boxes is not None: + # ground truth are red + for box in gt_boxes: + rect = patches.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1], + linewidth=0.25, edgecolor='b', facecolor='none') + # Add the patch to the Axes + ax.add_patch(rect) + + # Following line makes sure that all the heatmaps are in the scale, 0 to 1 + # So color assigned to different scores are consistent across heatmaps for + # different images + heatmap[0:1, 0:1] = 1 + heatmap[0:1, 1:2] = 0 + + plt.imshow(heatmap, alpha=0.4, cmap='hot', interpolation='nearest') + plt.colorbar() + + plt.title("Stitching visualization") + plt.show() + plt.savefig(outpath, dpi=600) + plt.close() + + +def draw_boxes_cv(image, recognized_boxes, gt_boxes, outpath): + + ''' + :param image + :param recognized_boxes + :param outpath: save as outpath. 
Should be complete image path with extension + :return: + ''' + + #(BGR) + # detected is green + for box in recognized_boxes: + cv2.rectangle(image, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 3) + + # ground truth is blue + for box in gt_boxes: + cv2.rectangle(image, (box[0], box[1]), (box[2], box[3]), (255, 0, 0), 3) + + cv2.imwrite(outpath, image) + + +def save_boxes(args, recognized_boxes, recognized_scores, img_id): + + if len(recognized_scores) < 1 and len(recognized_boxes) < 1: + return + + pdf_name = img_id.split("/")[0] + math_csv_path = os.path.join(args.save_folder, args.exp_name, pdf_name + ".csv") + + if not os.path.exists(os.path.dirname(math_csv_path)): + os.makedirs(os.path.dirname(math_csv_path)) + + math_output = open(math_csv_path, 'a') + + recognized_boxes = np.concatenate((recognized_boxes,np.transpose([recognized_scores])),axis=1) + + page_num = int(img_id.split("/")[-1]) + + col = np.array([int(page_num) - 1] * recognized_boxes.shape[0]) + math_regions = np.concatenate((col[:, np.newaxis], recognized_boxes), axis=1) + + np.savetxt(math_output, math_regions, fmt='%.2f', delimiter=',') + math_output.close() + + # + # + # for i, box in enumerate(recognized_boxes): + # math_output.write(str(box[0]) + ',' + str(box[1]) + ',' + str(box[2]) + ',' + + # str(box[3]) + ',' + str(recognized_scores[i]) + '\n') + # + +def draw_boxes(args, im, recognized_boxes, recognized_scores, boxes, confs, scale, img_id): + + path = os.path.join("eval", args.exp_name, img_id + ".png") + + if not os.path.exists(os.path.dirname(path)): + os.makedirs(os.path.dirname(path)) + + # Create figure and axes + fig,ax = plt.subplots(1) + scale = scale.cpu().numpy() + + # Display the image + ax.imshow(im) + + width, height, channels = im.shape + heatmap = np.zeros([width, height]) + + if len(recognized_scores) > 1 and len(recognized_boxes) > 1: + + # Recognition heatmap + data = np.concatenate((recognized_boxes,np.transpose([recognized_scores])),axis=1) + data = data[data[:, 
4].argsort()] + + for box in data: + heatmap[int(box[1]):int(box[3]), int(box[0]):int(box[2])] = box[4] + + for box in recognized_boxes: + rect = patches.Rectangle((box[0], box[1]), box[2]-box[0], box[3] - box[1], + linewidth=1, edgecolor='g', facecolor='none') + #Add the patch to the Axes + ax.add_patch(rect) + + # Following line makes sure that all the heatmaps are in the scale, 0 to 1 + # So color assigned to different scores are consistent across heatmaps for + # different images + heatmap[0:1, 0:1] = 1 + heatmap[0:1, 1:2] = 0 + + plt.imshow(heatmap, alpha=0.4, cmap='hot', interpolation='nearest') + plt.colorbar() + + plt.title(args.exp_name) + plt.show() + plt.savefig(path, dpi=600) + plt.close() + + +if __name__ == "__main__": + draw_boxes() \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..45d99ac83d9a42f2f26ad8960e51fe8458563fce --- /dev/null +++ b/app.py @@ -0,0 +1,293 @@ +import yaml +from typing import List +import numpy as np +from PIL import Image +from pathlib import Path +from collections import defaultdict + +import cv2 +import torch +from torchvision.ops import nms +from timm.models.resnetv2 import ResNetV2 +from timm.models.layers import StdConv2dSame + +from pdf2image import convert_from_bytes + +from ScanSSD.detect_flow import MathDetector +from HybridViT.recog_flow import MathRecognition +from utils.p2l_utils import get_rolling_crops, postprocess + +import streamlit + + +class DetectCfg(): + def __init__ (self): + self.cuda = True if torch.cuda.is_available() else False + self.kernel = (1, 5) + self.padding = (0, 2) + self.phase = 'test' + self.visual_threshold = 0.8 + self.verbose = False + self.exp_name = 'SSD' + self.model_type = 512 + self.use_char_info = False + self.limit = -1 + self.cfg = 'hboxes512' + self.batch_size = 32 + self.num_workers = 4 + self.neg_mining = True + self.log_dir = 'logs' + self.stride = 0.1 + self.window = 1200 + +class App: + title = 'Math 
Expression Recognition Demo \n\n Note: For Math Detection, we reuse the model from this repo [ScanSSD: Scanning Single Shot Detector for Math in Document Images](https://github.com/MaliParag/ScanSSD).\n\nThis demo aim to present the effciency of our method [A Hybrid Vision Transformer Approach for Mathematical Expression Recognition](https://ieeexplore.ieee.org/document/10034626) in recognizing math expression in document images.' + + def __init__(self): + self._model_cache = {} + self.detect_model = MathDetector('saved_models/math_detect/AMATH512_e1GTDB.pth', DetectCfg()) + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.image_resizer = ResNetV2(layers=[2, 3, 3], num_classes=max((672, 192))//32, global_pool='avg', in_chans=1, drop_rate=.05, + preact=True, stem_type='same', conv_layer=StdConv2dSame).to(device) + self.image_resizer.load_state_dict(torch.load('saved_models/resizer/image_resizer.pth', map_location=device)) + self.image_resizer.eval() + + def detect_preprocess(self, img_list): + if isinstance(img_list, Image.Image): + img_list = [img_list] + + new_images = [] + + for temp_image in img_list: + img_size = 1280 + # convert image to numpy array + temp_image = np.array(temp_image) + img = cv2.resize(temp_image, (img_size, int(img_size * temp_image.shape[0] / temp_image.shape[1]))) + new_images.append(img) + + return new_images + + def _get_model(self, name): + if name in self._model_cache: + return self._model_cache[name] + + with open('recog_cfg.yaml', 'r') as f: + recog_cfg = yaml.safe_load(f) + + model_cfg = {} + model_cfg.update(recog_cfg['common']) + model_cfg.update(recog_cfg[name]) + recog_model = MathRecognition(model_cfg, self.image_resizer if model_cfg['resizer'] else None + ) + self._model_cache[name] = recog_model + + return recog_model + + def _get_boxes(self, img, temp_bb): + temp_bb[0] = max(0, temp_bb[0] - int(0.05 * (temp_bb[2] - temp_bb[0]))) + temp_bb[1] = max(0, temp_bb[1] - int(0.05 * (temp_bb[3] - temp_bb[1]))) + 
temp_bb[2] = min(img.shape[1], temp_bb[2] + int(0.05 * (temp_bb[2] - temp_bb[0]))) + temp_bb[3] = min(img.shape[0], temp_bb[3] + int(0.05 * (temp_bb[3] - temp_bb[1]))) + + # convert to int + temp_bb = [int(x) for x in temp_bb] + + return temp_bb + + @torch.inference_mode() + def math_detection(self, page_lst: List[np.ndarray]): + res = [] + + batch_size = 32 + threshold = 0.9 + iou = 0.1 + + for idx, temp_image in enumerate(page_lst): + crops_list, padded_crops_list, crops_info_list = get_rolling_crops(temp_image, stride=[128, 128]) + + scores_list = [] + wb_list = [] + for i in range(0, len(padded_crops_list), batch_size): + batch = padded_crops_list[i:i+batch_size] + window_borders, scores = self.detect_model.DetectAny(batch, threshold) + scores_list.extend(scores) + wb_list.extend(window_borders) + + # change crops to original image coordinates + bb_list, s_list = postprocess(wb_list, scores_list, crops_info_list) + + # convert to torch tensors + bb_torch = torch.tensor(bb_list).float() + scores_torch = torch.tensor(s_list) + + # perform non-maximum suppression + # check if bb_torch is empty + if bb_torch.shape[0] == 0: + res.append(([], [])) + continue + + indices = nms(bb_torch, scores_torch, iou) + + bb_torch = bb_torch[indices] + new_bb_list = bb_torch.int().tolist() + + for i in range(len(new_bb_list)): + save_name = 'Page ' + str(idx) + '-Expr ' + str(i) if len(page_lst) > 1 else 'Expr ' + str(i) + temp_bb = self._get_boxes(temp_image, new_bb_list[i][:]) + crop_expr = temp_image[temp_bb[1]:temp_bb[3], temp_bb[0]:temp_bb[2]] + crop_expr = Image.fromarray(crop_expr) + res.append((save_name, crop_expr)) + + return res + + def math_recognition(self, model_name, res: List): + model = self._get_model(model_name) + final_res = [] + for item in res: + name, crop_expr = item + if isinstance(crop_expr, list): + continue + latex_str = model(crop_expr, name=name) + final_res.append((name, crop_expr, latex_str)) + + return final_res + + def __call__(self, model_name, 
image_list, use_detect): + #Detect + if use_detect: + new_images = self.detect_preprocess(image_list) + res = self.math_detection(page_lst=new_images) + else: + res = [('latex_pred', image_list[0])] + #Recog + final_res = self.math_recognition(model_name, res) + display_name, origin_img, latex_pred = tuple([list(item) for item in zip(*final_res)]) + return display_name, origin_img, latex_pred + + +def api(): + app = App() + streamlit.set_page_config(page_title='Thesis Demo', layout='wide') + streamlit.title(f'{app.title}') + streamlit.markdown(f""" + To use this interactive demo and reproduced models: + 1. Select what type of input data you want to get prediction. + 2. Upload your own image or pdf file (or select from the given examples). + 3. If input file is in pdf format, choose start page and end page. + 4. Click **Extract**. + + **Note: Current version of this demo only support single file upload for both Image and PDF option.** + """ + ) + + # model_name = streamlit.radio( + # label='The Math Recognition model to use', + # options=app.models + # ) + + extract_option = streamlit.radio( + label='Select type of input for prediction', + options=('Math expression image only', 'Full document image'), + + ) + + uploaded_file = streamlit.file_uploader( + 'Upload an image/pdf file', + type=['png', 'jpg', 'pdf'], + accept_multiple_files=False + ) + + if uploaded_file is not None: + if Path(uploaded_file.name).suffix == '.pdf': + bytes_data = uploaded_file.read() + + image_lst = convert_from_bytes(bytes_data, dpi=160, grayscale=True) + image_lst = [img.convert('RGB') for img in image_lst] + + container = streamlit.container() + range_cols = container.columns(2) + start_page = range_cols[0].number_input(label='Start page', min_value=0, max_value=len(image_lst)-2) + end_page = range_cols[1].number_input(label='End page', min_value=1, max_value=len(image_lst)-1) + + if start_page <= end_page: + image_lst = image_lst[start_page:end_page+1] + cols = 
streamlit.columns(len(image_lst)) + for i in range(len(cols)): + with cols[i]: + img_shape = image_lst[i].size + streamlit.image(image_lst[i], width=1024, caption=f'Page: {str(i)} Image shape: {str(img_shape)}', use_column_width='auto') + else: + image = Image.open(uploaded_file).convert('RGB') + image_lst = [image] + img_shape = image.size + streamlit.image(image, width=1024, caption='Image shape: ' + str(img_shape)) + else: + streamlit.text('\n') + + if streamlit.button('Extract'): + if uploaded_file is not None and image_lst is not None: + with streamlit.spinner('Computing'): + try: + use_detect = True + if extract_option == 'Math expression image only': + use_detect = False + model_name = 'version1' + else: + model_name = 'version2' + + display_name, origin_img, latex_code = app(model_name, image_lst, use_detect) + + if Path(uploaded_file.name).suffix == '.pdf': + page_dict = defaultdict(list) + for name, img, pred in zip(display_name, origin_img, latex_code): + name_components = name.split('-') + if len(name_components) <= 1: + page_name = 'Page0' + else: + page_name = name_components[0] + page_dict[page_name].append((img, pred)) + + tab_lst = streamlit.tabs(list(page_dict.keys())) + + for tab, page_name in zip(tab_lst, list(page_dict.keys())): + for idx, item in enumerate(page_dict[page_name]): + container = tab.container() + col_latex, col_render, col_org = container.columns(3, gap='large') + + if idx == 0: + col_latex.header('Predicted LaTeX') + col_render.header('Rendered Image') + col_org.header('Cropped Image') + + render_latex = f'$\\displaystyle {item[-1]}$' + col_latex.code(item[-1], language='latex') + col_render.markdown(render_latex) + img = np.asarray(item[0]) + col_org.image(img) + else: + for idx, (name, org, latex) in enumerate(zip(display_name, origin_img, latex_code)): + container = streamlit.container() + col_latex, col_render, col_org = container.columns(3, gap='large') + + if idx == 0: + col_latex.header('Predicted LaTeX') + 
col_render.header('Rendered Image') + col_org.header('Cropped Image') + + render_latex = f'$\\displaystyle {latex}$' + col_latex.code(latex, language='latex') + col_render.markdown(render_latex) + org = np.asarray(org) + col_org.image(org) + + except Exception as e: + streamlit.error(e) + else: + streamlit.error('Please upload an image.') + +if __name__ == '__main__': + print(f"Is CUDA available: {torch.cuda.is_available()}") + # True + print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}") + # Tesla T4 + api() diff --git a/recog_cfg.yaml b/recog_cfg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2b97557a6571a4812b0112294ad9c07ef07e24d4 --- /dev/null +++ b/recog_cfg.yaml @@ -0,0 +1,61 @@ +common: + imgH: null + imgW: null + + FeatureExtraction: + name: 'None' + + SequenceModeling: + name: 'ViT' + params: + backbone: + name: 'resnet' + input_channel: 1 # gray scale + output_channel: 512 + gcb: False + fix_embed: True + input_channel: 1 + patching_style: '2d' + patch_size: [2,2] + depth: 6 + num_heads: 8 + hidden_size: 256 + + Prediction: + name: 'Attnv2' + params: + seqmodel: 'TFM' + input_size: 256 + hidden_size: 256 + kernel_size: 2 + kernel_dim: 128 + embed_target: True + enc_init: True + attn_type: 'coverage' + method: 'concat' + teacher_forcing: 1.0 + droprate: 0.25 + + mean: 0.5 + std: 0.5 + pad: True + beam_size: 10 + rgb: False + character: [] + resizer: False + + +version1: + max_dimension: [800, 800] + min_dimension: [32, 32] + batch_max_length: 500 + vocab: 'saved_models/math_recog/version1/vocab.txt' + weight_path: 'saved_models/math_recog/version1/best_weights.pth' + + +version2: + max_dimension: [224, 960] + min_dimension: [32, 32] + batch_max_length: 200 + vocab: 'saved_models/math_recog/version2/vocab.txt' + weight_path: 'saved_models/math_recog/version2/best_weights.pth' diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 
0000000000000000000000000000000000000000..b964467b48e1a2e0cca0d3588aa9d06eabfb12b0 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,23 @@ +einops==0.6.0 +nltk==3.5 +numpy==1.24.3 +python-levenshtein==0.20.9 +PyYAML==6.0 +timm==0.6.13 +--extra-index-url https://download.pytorch.org/whl/cu113 +torch==2.0.0 +# torch==2.0.0 +--extra-index-url https://download.pytorch.org/whl/cu113 +torchvision==0.15.1 +# torchvision==0.15.1 +pdf2image +scipy==1.10.1 +open-fonts==0.1 +natsort==8.2.0 +levenshtein==0.20.9 +distance==0.1.3 +opencv-python-headless==4.7.0.72 +pillow==9.4.0 +scikit-image==0.19.3 +albumentations==1.3.0 +streamlit==1.22.0 \ No newline at end of file diff --git a/saved_models/math_detect/AMATH512_e1GTDB.pth b/saved_models/math_detect/AMATH512_e1GTDB.pth new file mode 100644 index 0000000000000000000000000000000000000000..c8e4e83b0d2b622cfb5d5dd45294baedef5656f6 --- /dev/null +++ b/saved_models/math_detect/AMATH512_e1GTDB.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a15cb764a0d281fa678bf256d258499759e7fcc0fb055bff76889336962b4b59 +size 96601505 diff --git a/saved_models/math_recog/version1/best_weights.pth b/saved_models/math_recog/version1/best_weights.pth new file mode 100644 index 0000000000000000000000000000000000000000..f42c404afa19532b9f426aec670f18c5894cdbcc --- /dev/null +++ b/saved_models/math_recog/version1/best_weights.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:621966468160fe8ed39731d7729810d95cbf777ba4068a2d409a5eefe0b1fb39 +size 206280682 diff --git a/saved_models/math_recog/version1/vocab.txt b/saved_models/math_recog/version1/vocab.txt new file mode 100644 index 0000000000000000000000000000000000000000..38e69c18593879d42ac45a3e16a5204e0b967886 --- /dev/null +++ b/saved_models/math_recog/version1/vocab.txt @@ -0,0 +1,496 @@ +} +{ +_ +^ +2 +( +) += +1 +- +, +\frac ++ +i +0 +x +n +. 
+\, +d +a +\mu +e +k +m +r +c +p +\partial +\alpha +t +A +~ +\; +3 +j +s +l +\left( +\right) +g +4 +\ +\nu +\prime +\pi +z +b +\phi +| +\mathrm +\cal +\delta +f +N +q +\lambda +T +S +\beta +] +R +[ +\bar +\int +D +M +L +\operatorname +B +F +\sigma +y +& +\\ +\theta +\gamma +\psi +h +/ +\hat +\sqrt +H +\sum +u +\tilde +\rho +o +\tau +C +G +P +V +I +E +\omega +X +\epsilon +J +\bf +\eta +Q +\Phi +\xi +v +\quad +\vec +\Gamma +K +\infty +5 +\right] +\left[ +U +\Lambda +\pm +\dot +W +Z +\begin{array} +\end{array} +\varphi +* +\Delta +\rangle +6 +w +\chi +\Omega +; +\kappa +\qquad +\} +\{ +\Psi +\equiv +8 +\cdot +\overline +\! +\langle +\rightarrow +> +\dagger +\varepsilon +\zeta +\nabla +< +O +Y +: +\Sigma +\cdots +\mathcal +\ldots +\ell +\left\{ +\: +\sim +\otimes +\wedge +\Pi +! +\operatorname* +7 +\prod +\hspace +\hbar +\in +\vert +9 +\widetilde +\right\} +\to +\Big +\Theta +\mid +\times +\right| +\mathbf +\underline +\ast +\dots +\leq +\left| +\approx +\star +\widehat +\stackrel +\right. +\displaystyle +\big +\perp +\left. 
+\geq +\mp +\simeq +\dag +\vartheta +\Bigr +\right\rangle +' +\neq +\Bigl +\circ +\longrightarrow +\oint +\biggl +\biggr +\bigg +\textstyle +\ddot +\left\langle +\not +\bigl +\oplus +\bigr +\boldmath +\Xi +\propto +\check +\nonumber +\triangle +\le +\varrho +\ge +\forall +\scriptscriptstyle +\imath +\right> +\| +-- +\lbrack +\sp +\bot +\it +\leftrightarrow +\Rightarrow +\parallel +\mapsto +\subset +\textrm +\iota +\l +\scriptsize +\Bigg +\left< +\binom +\overrightarrow +\ll +\jmath +\phantom +\sf +\cong +\tiny +\ne +\gg +\Biggr +\d +\Biggl +\o +\Upsilon +\breve +\L +\vee +\bigoplus +\bullet +\small +\wp +\scriptstyle +\atop +\varpi +\downarrow +\kern +\# +\vdots +\uparrow +\cap +\rbrack +\Im +\supset +\sb +\slash +\hline +\cup +\Re +\Longrightarrow +\mit +\upsilon +\underbrace +\acute +\varsigma +\lbrace +\protect +\rbrace +\O +\vspace +\bigtriangleup +\Leftrightarrow +\S +\mathsf +` +\longleftrightarrow +\i +\leftarrow +\Vert +\footnotesize +\ddots +\rightharpoonup +\Large +\Longleftrightarrow +\enspace +\right\vert +mm +\left\vert +\raisebox +\cdotp +ule +\bigotimes +\put +\makebox +\tt +\emptyset +\doteq +\hfill +\P +\overleftarrow +\large +\left\| +\right\| +\textbf +\mathop +\vphantom +\llap +\backslash +" +\sharp +\buildrel +\raise +\sl +\flat +\ref +\odot +\noalign +\mathit +\label +\textup +cm +\bigcup +\strut +\/ +\longmapsto +\rfloor +\unitlength +\overbrace +\thinspace +\colon +\subseteq +\setlength +\ni +\pounds +\diamond +\_ +\fbox +\ominus +\line +\enskip +[object +Object] +\bigwedge +\aleph +\circle +? 
+\bigtriangledown +\lfloor +\bigcap +\vrule +\smallskip +\b +\land +\bmod +\space +\left\lbrack +\right\rbrack +\vskip +\hookrightarrow +\rlap +\diamondsuit +\hrule +\natural +\pmod +\setminus +\ddagger +\vline +\textit +\- +\texttt +\lower +\longleftarrow +\c +pt +\right\rfloor +\left\lbrace +\right\rbrace +\relax +\& +\normalsize +\bigm +\thicklines +0.14 +8.5 +\framebox +\sc +\hfil +\top +0.4 +\vdash +\j +\textsf +\mkern +\textnormal +\supseteq +\medskip +\exists +\smash +\surd +\m +\Biggm +\sqcup +\null +\special +0.1 +\itshape +\lceil +\do +\lefteqn +\Huge +--- +\renewcommand +\arraystretch +\unboldmath +\prec +\LARGE +\cite +\Longleftarrow +\triangleright +\ss +\ensuremath +\amalg +\rightleftharpoons +\grave +\' +\hphantom +\protectu +\asymp +\oslash +\setcounter +\smile +\ae +\arraycolsep +\vcenter +\Bigm +\ooalign +\crcr +\skew +\* +0.5 +\AA +\sqcap +\vss +\mathbin +\left\lfloor +\em +\succeq +\bigsqcup +\lq +\nolinebreak +\multicolumn +\parbox +\multiput +\SS +\notin +\fboxsep +\mathversion +\ddag +\lgroup +\rgroup +\arrowvert +\mathrel +3.1 +3.2 +\Downarrow +\smallint +\mskip +in +\hss diff --git a/saved_models/math_recog/version2/best_weights.pth b/saved_models/math_recog/version2/best_weights.pth new file mode 100644 index 0000000000000000000000000000000000000000..c2714a8195c713b2ea63043b8188122cc458d66c --- /dev/null +++ b/saved_models/math_recog/version2/best_weights.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:98e839849086568985b55b37512aee2ea1cf6d922f872c91f61d0565065c415e +size 204731124 diff --git a/saved_models/math_recog/version2/vocab.txt b/saved_models/math_recog/version2/vocab.txt new file mode 100644 index 0000000000000000000000000000000000000000..78af42f7ee1c981b553c8c136128a32f0c53fcaf --- /dev/null +++ b/saved_models/math_recog/version2/vocab.txt @@ -0,0 +1,578 @@ +{ +} +_ +^ +2 +( +) +1 += +- +, +\frac ++ +0 +x +i +. 
+n +\, +d +a +\mu +m +k +t +e +3 +r +p +\mathrm +A +\right) +\left( +c +4 +~ +s +\alpha +\; +f +\pi +| +\partial +j +b +g +\prime +l +q +\nu +\ +z +M +T +y +R +N +B +L +S +\phi +\lambda +\int +F +/ +\beta +D +] +u +\delta +\bar +[ +P +\sqrt +V +\gamma +\cal +\sigma +\theta +\mathbf +H +E +C +\sum +v +Q +h +G +\rho +& +X +\psi +\hat +5 +\omega +I +\tau +\tilde +\vec +\cdot +6 +\epsilon +\Delta +o +\right] +\left[ +\infty +\Gamma +\eta +\rangle +K +\bf +U +J +\quad +W +\xi +Z +* +8 +\Phi +\\ +\begin{array} +\end{array} +\varphi +w +\Lambda +\pm +; +\langle +\} +\{ +\sin +\chi +\dot +\cos +\equiv +\qquad +\! +\to +\Omega +7 +\overline +\ln +\rightarrow +Y +O +\kappa +\times +< +\mathcal +\Psi +: +9 +\in +> +\varepsilon +\cdots +\dagger +\ldots +\exp +\nabla +\ell +\operatorname +\Sigma +\left\{ +\zeta +\sim +\leq +\mathbb +! +\hbar +\: +\approx +\Pi +\right| +\log +\right\} +\otimes +\mid +\Big +\left| +\simeq +\prod +\vert +\perp +\wedge +\dots +\widetilde +\boldsymbol +\Theta +\lim +\geq +\big +\circ +\ast +\widehat +\| +\stackrel +\tan +\right. +\bigg +\right\rangle +\neq +\sinh +\displaystyle +\cosh +\det +\star +\underline +\left. 
+\mathfrak +\biggr +\biggl +\textstyle +\Bigr +\Bigl +\vartheta +\dag +\mp +\left\langle +\oplus +\forall +\propto +\longrightarrow +\uparrow +\boldmath +\oint +\bigl +\mapsto +\bigr +\binom +\ddot +' +\Xi +\cong +\Bigg +\mathsf +\le +\cot +\bot +\scriptstyle +\Rightarrow +\cap +\subset +\ll +\leftrightarrow +\nonumber +\arctan +\parallel +\tanh +\textrm +\triangle +\subseteq +\varrho +\max +\cup +\scriptscriptstyle +-- +\textbf +\land +\right> +[object +Object] +\ge +\check +\scriptsize +\imath +\gg +\cfrac +\min +\not +\colon +\bullet +\exists +\Biggl +\left< +\downarrow +\Biggr +\overrightarrow +\right\| +\left\| +\iota +\Upsilon +\neg +\lbrack +\it +\l +\jmath +\lor +\sp +\d +\bigcup +\vee +\setminus +\Im +\sup +\sf +\dim +\tiny +\underbrace +\wp +\Re +\sec +\Pr +\phantom +\overset +\rfloor +\atop +\bigoplus +\# +\L +\pmod +\supset +\coth +\ne +\Leftrightarrow +\textsf +\odot +\aleph +\csc +\slash +\o +\leftarrow +\bigwedge +\right\rfloor +\right\vert +\mathit +\lfloor +\small +\left\lfloor +\upsilon +\gcd +\arcsin +\bmod +\left\vert +\iint +\iff +\lnot +\mit +\varpi +\arg +\deg +\Longrightarrow +\iiint +\vdots +\sb +\arccos +\rbrack +\breve +\Vert +\emptyset +\rbrace +\underset +\lbrace +\rlap +\inf +\implies +\S +\ker +\angle +" +\xrightarrow +\backslash +\longleftrightarrow +\acute +\footnotesize +\top +\vdash +\P +\doteq +\mathop +\bigcap +\Longleftrightarrow +\triangleq +\i +\limsup +\bigstar +\ddots +\leqslant +\varsigma +\bigtriangleup +\lVert +\rVert +\enspace +\O +\tt +\ominus +\vphantom +\bigotimes +\large +\protect +\lceil +\overleftarrow +\overbrace +\ddagger +\varnothing +\sharp +\rceil +\div +\/ +\left\lceil +\right\rceil +\geqslant +\cdotp +\rightharpoonup +\flat +\subsetneq +\supseteq +\dotsc +\notin +\ni +\liminf +\dotsb +\bigvee +\triangleleft +\_ +\bigtriangledown +\Vdash +? 
+\sl +\enskip +\hfill +\Large +\strut +\hookrightarrow +\left\lbrace +\right\rbrace +\b +\left\lbrack +\right\rbrack +\models +\ref +\textup +\smile +\longmapsto +\diamond +\* +\& +\smallsetminus +\mod +\right/ +\makebox +\thinspace +\intercal +\Vec +\sqcup +\llap +\surd +\pounds +\c +\varprojlim +\texttt +\left\Vert +\hfil +\- +� +\rtimes +\right\Vert +\oslash +\bigm +\j +\mathrel +\triangleright +\SS +\rightleftharpoons +\twoheadrightarrow +\frown +\rightharpoondown +\textit +\nearrow +\prec +\therefore +\diamondsuit +\bigsqcup +\left/ +\varinjlim +\mathord +\searrow +\longleftarrow +\lg +--- +\lesssim +\sc +\complement +\nleftrightarrow +\space +\bowtie +\vDash +\beth +\normalsize +\lrcorner +\overleftrightarrow +\natural +\eqno +\succ +\smallskip +\Biggm +\smash +\textnormal +\upharpoonright +\Uparrow +\gtrsim +\rightleftarrows +\sqsubseteq +\null +\preceq +\allowbreak +\Hat +\leftrightarrows +\ss +\Longleftarrow +@ +\rightrightarrows +\supsetneq +\sqcap +\AA +\pmb +\bumpeq +\mathnormal +\hphantom +\barwedge +\Bigm +\substack +\LARGE +` +\fbox +\ensuremath +\amalg +\asymp +\right[ +\mathstrut +\Tilde +\Bar +\digamma +\right( +\sideset +\a +\dotsm +\lgroup +\rgroup +\Downarrow +\ointop +\coprod +\succeq +\framebox +\vline +\smallint +\gets +\Leftarrow +\cite +\varliminf +\doublebarwedge +\textsc +\newline +\nleq +\footnote +\leqq +\atopwithdelims +\medskip +\pitchfork +\left] +\kern +\nulldelimiterspace +\ddag +\hom +\varkappa +\gimel +\ltimes +\ae +\notag +\lefteqn +\nrightarrow +\subsetneqq +\mathtt +\hspace +\fill +\do diff --git a/saved_models/resizer/image_resizer.pth b/saved_models/resizer/image_resizer.pth new file mode 100644 index 0000000000000000000000000000000000000000..3a0a1abe8e49c7af3e4c504909a875cadd564d34 --- /dev/null +++ b/saved_models/resizer/image_resizer.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1c3820659985ad142b526490bb25c23d977176ac2073591b3bddada692718458 +size 19441973 diff --git a/utils/p2l_utils.py 
import math
import numpy as np


def get_rolling_crops(image, stride=(128, 128), window_size=512):
    """Slide a square window over `image` and return the resulting crops.

    Args:
        image: H x W x C uint8 array (page image).
        stride: (vertical, horizontal) step in pixels between windows.
        window_size: side length of the square sliding window.

    Returns:
        Tuple ``(crops_list, padded_crops_list, crops_info_list)``:
        - crops_list: raw crops (edge crops may be smaller than the window);
        - padded_crops_list: each crop centered on a white
          window_size x window_size canvas (fixed-shape model input);
        - crops_info_list: ``(x_start, y_start, width, height)`` of each
          crop in original-image coordinates.
    """
    image_height, image_width, channels = image.shape

    # Number of window positions per axis. max(..., 1) guarantees at least
    # one window even when the image is smaller than window_size — without
    # it the formula yields zero (or negative) counts for small images and
    # the function silently returns no crops.
    nwindows_vertical = max(1, math.ceil((image_height - window_size) / stride[0]) + 1)
    nwindows_horizontal = max(1, math.ceil((image_width - window_size) / stride[1]) + 1)

    print(f"Number of windows: {nwindows_vertical} x {nwindows_horizontal}")

    crops_list = []
    padded_crops_list = []
    crops_info_list = []

    background = (255, 255, 255)  # pad with white so detector sees blank page

    for i in range(nwindows_vertical):
        for j in range(nwindows_horizontal):
            window_x_start = j * stride[1]
            window_x_end = min(window_x_start + window_size, image_width)
            window_y_start = i * stride[0]
            window_y_end = min(window_y_start + window_size, image_height)
            window_width = window_x_end - window_x_start
            window_height = window_y_end - window_y_start

            rolling_window = image[window_y_start:window_y_end,
                                   window_x_start:window_x_end]

            # Center the (possibly smaller) crop on a white square so every
            # model input has the same window_size x window_size shape.
            padded_window = np.full((window_size, window_size, channels),
                                    background, dtype=np.uint8)
            x_center = (window_size - window_width) // 2
            y_center = (window_size - window_height) // 2
            padded_window[y_center:y_center + window_height,
                          x_center:x_center + window_width] = rolling_window

            crops_list.append(rolling_window)
            padded_crops_list.append(padded_window)
            crops_info_list.append(
                (window_x_start, window_y_start, window_width, window_height))

    return crops_list, padded_crops_list, crops_info_list


def postprocess(window_borders, scores, crops_info_list, window_size=512):
    """Map per-crop detection boxes back to original-image coordinates.

    Args:
        window_borders: per-crop list of boxes ``[x0, y0, x1, y1]`` in
            padded-crop coordinates.
        scores: per-crop list of confidence scores, parallel to
            ``window_borders``.
        crops_info_list: ``(x_start, y_start, width, height)`` per crop, as
            produced by :func:`get_rolling_crops`.
        window_size: padded canvas side length used during cropping.

    Returns:
        ``(bb_list, scores_list)`` — flat lists of boxes in image
        coordinates and their scores.
    """
    bb_list = []
    scores_list = []

    for i in range(len(window_borders)):
        window_border = window_borders[i]
        score = scores[i]
        window_x_start, window_y_start, window_width, window_height = crops_info_list[i]

        # Each crop was centered on the padded canvas; undo that centering
        # offset once per crop (loop-invariant), then translate by the
        # window origin.
        x_offset = (window_size - window_width) // 2
        y_offset = (window_size - window_height) // 2

        for k in range(len(window_border)):
            x0 = window_x_start + (window_border[k][0] - x_offset)
            y0 = window_y_start + (window_border[k][1] - y_offset)
            x1 = window_x_start + (window_border[k][2] - x_offset)
            y1 = window_y_start + (window_border[k][3] - y_offset)

            bb_list.append([x0, y0, x1, y1])
            scores_list.append(score[k])

    return bb_list, scores_list


if __name__ == "__main__":
    print("hello world")