diff --git a/HybridViT/beam.py b/HybridViT/beam.py
new file mode 100644
index 0000000000000000000000000000000000000000..c91a9d64105d6b2eebee9b4fd10542c2cc8d8ee1
--- /dev/null
+++ b/HybridViT/beam.py
@@ -0,0 +1,131 @@
+import torch
+from typing import List
+from einops import rearrange, repeat
+from typing import Optional
+
class Hypothesis:
    """A single (partial or finished) beam-search hypothesis.

    Attributes:
        seq: decoded token ids (start token already stripped by the caller).
        score: cumulative log-probability of the sequence.
        attn_weights: per-token attention weights, or None when not tracked.
    """

    seq: List[int]
    score: float
    attn_weights: Optional[List[float]]

    def __init__(
        self,
        seq_tensor: torch.LongTensor,
        score: float,
        weights: Optional[torch.FloatTensor] = None
    ) -> None:
        self.seq = seq_tensor.tolist()
        self.score = score
        # BUG FIX: `if weights:` invokes Tensor.__bool__, which raises
        # RuntimeError for tensors with more than one element — the check
        # must be an identity test against None.
        if weights is not None:
            self.attn_weights = weights.tolist()
            # One attention weight per decoded token.
            assert len(self.attn_weights) == len(self.seq)
        else:
            self.attn_weights = None

    def __len__(self):
        # Report at least 1 so length-normalized scoring never divides by zero.
        return len(self.seq) if self.seq else 1

    def __str__(self):
        return f"seq: {self.seq}, score: {self.score}, weight: {self.attn_weights}"
+
+
class Beam:
    """State container for beam-search decoding.

    Holds the live hypotheses as a (num_beams, max_len + 2) LongTensor
    (slot 0 is the start token), their cumulative log-prob scores, and the
    list of finished `Hypothesis` objects.
    """

    def __init__(self,
                 start_w=1,
                 stop_w=2,
                 ignore_w=0,
                 max_len=150,
                 viz_attn=False,
                 device='cuda'
                 ):
        # Token ids marking sequence start/stop; `ignore_w` pads unused slots.
        self.stop_w = stop_w
        self.start_w = start_w

        # Single initial hypothesis; +2 slots leave room for start/stop tokens.
        self.hypotheses = torch.full(
            (1, max_len + 2),
            fill_value=ignore_w,
            dtype=torch.long,
            device=device,
        )
        if viz_attn:
            # Per-slot attention weights kept in lockstep with `hypotheses`.
            # NOTE(review): `advance` below never updates this — confirm the
            # caller maintains it when viz_attn is on.
            self.hyp_alpha = torch.ones(1, max_len + 2, dtype=torch.float, device=device)

        self.hypotheses[:, 0] = start_w
        self.hyp_scores = torch.zeros(1, dtype=torch.float, device=device)
        self.completed_hypotheses: List[Hypothesis] = []
        self.device = device
        self.viz_attn = viz_attn

    def advance(self, next_log_probs, step, beam_size):
        """Expand every live hypothesis by one token.

        Args:
            next_log_probs: (num_live, vocab_size) log-probabilities.
            step: index of the last filled slot in `self.hypotheses`.
            beam_size: total beam width.

        Returns:
            (new_hypotheses, new_hyp_scores) for candidates that did not emit
            the stop token; finished ones move to `self.completed_hypotheses`.
        """
        vocab_size = next_log_probs.shape[1]
        # Only expand as many candidates as beams still unfinished.
        live_hyp_num = beam_size - len(self.completed_hypotheses)
        # Broadcast each running score over the vocabulary, then flatten so a
        # single top-k ranges over all (hypothesis, next-token) pairs.
        exp_hyp_scores = repeat(self.hyp_scores, "b -> b e", e=vocab_size)
        continuous_hyp_scores = rearrange(exp_hyp_scores + next_log_probs, "b e -> (b e)")
        top_cand_hyp_scores, top_cand_hyp_pos = torch.topk(
            continuous_hyp_scores, k=live_hyp_num
        )

        # Recover (source beam index, token id) from flattened positions.
        prev_hyp_ids = top_cand_hyp_pos // vocab_size
        hyp_word_ids = top_cand_hyp_pos % vocab_size

        step += 1
        new_hypotheses = []
        new_hyp_scores = []

        for prev_hyp_id, hyp_word_id, cand_new_hyp_score in zip(
            prev_hyp_ids, hyp_word_ids, top_cand_hyp_scores
        ):
            cand_new_hyp_score = cand_new_hyp_score.detach().item()
            self.hypotheses[prev_hyp_id, step] = hyp_word_id

            if hyp_word_id == self.stop_w:
                # Beam finished: store it. The slice drops the start token;
                # the stop token at `step` is kept in the stored sequence.
                self.completed_hypotheses.append(
                    Hypothesis(
                        seq_tensor=self.hypotheses[prev_hyp_id, 1:step+1]
                        .detach()
                        .clone(),  # remove START_W at first
                        score=cand_new_hyp_score,
                    )
                )
            else:
                # Clone so later in-place writes don't alias surviving beams.
                new_hypotheses.append(self.hypotheses[prev_hyp_id].detach().clone())
                new_hyp_scores.append(cand_new_hyp_score)

        return new_hypotheses, new_hyp_scores

    def get_incomplete_inds(self, hyp_word_ids):
        # Indices of candidates whose next token is not the stop token.
        return [ind for ind, next_word in enumerate(hyp_word_ids) if
                next_word != self.stop_w]

    def get_complete_inds(self, hyp_word_ids, incomplete_inds):
        # Complement of `incomplete_inds` over all candidate positions.
        return list(set(range(len(hyp_word_ids))) - set(incomplete_inds))

    def set_current_state(self, hypotheses):
        "Set the outputs for the current timestep."
        self.hypotheses = torch.stack(hypotheses, dim=0)
        return

    def set_current_score(self, hyp_scores):
        "Set the scores for the current timestep."
        self.hyp_scores = torch.tensor(
            hyp_scores, dtype=torch.float, device=self.device
        )
        return

    def done(self, beam_size):
        # Search terminates once every beam has produced a stop token.
        return len(self.completed_hypotheses) == beam_size

    def set_hypothesis(self):
        # Fallback when no beam finished within max_len: emit the best live
        # hypothesis (start token stripped; trailing ignore_w padding kept).
        if len(self.completed_hypotheses) == 0:
            self.completed_hypotheses.append(
                Hypothesis(
                    seq_tensor=self.hypotheses[0, 1:].detach().clone(),
                    score=self.hyp_scores[0].detach().item(),
                )
            )
        return
\ No newline at end of file
diff --git a/HybridViT/helper.py b/HybridViT/helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba5cec3e1b2ca917cadbd914fbe45c0a6095f6b2
--- /dev/null
+++ b/HybridViT/helper.py
@@ -0,0 +1,182 @@
+import torch
+import random
+import numpy as np
+from PIL import Image
+from typing import Dict
+import torch.nn.functional as F
+import numpy as np
+from PIL import Image
+import cv2
+import math
+import albumentations as alb
+from albumentations.pytorch.transforms import ToTensorV2
+from collections import OrderedDict
+from itertools import repeat
+import collections.abc
+
+
+# From PyTorch internals
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return tuple(repeat(x, n))
+ return parse
+
+to_3tuple = _ntuple(3)
+
def clean_state_dict(state_dict):
    """Strip the 'module.' prefix that DataParallel training adds to keys."""
    prefix = 'module.'
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )
+
+
def math_transform(mean, std, is_gray: bool):
    """Build the evaluation transform: optional grayscale, CLAHE, normalize, to-tensor."""
    normalize_steps = [
        alb.CLAHE(clip_limit=2, tile_grid_size=(2, 2), always_apply=True),
        alb.Normalize(to_3tuple(mean), to_3tuple(std)),
        ToTensorV2()
    ]
    steps = [alb.ToGray(always_apply=True)] if is_gray else []
    steps += normalize_steps
    return alb.Compose(steps)
+
+
def pad(img: Image.Image, divable=32):
    """Pad an Image to the next full divisible value of `divable`. Also normalizes the PIL.image and invert if needed.

    Args:
        img (PIL.Image): input PIL.image
        divable (int, optional): padding multiple. Defaults to 32.

    Returns:
        PIL.Image: grayscale image cropped to the detected content box and
        padded with white so both sides are multiples of `divable`.
    """
    # (H, W, 2) luminance + alpha array.
    data = np.array(img.convert('LA'))

    # Min-max normalize to [0, 255].
    # NOTE(review): divides by (max - min) — a constant-color image would
    # divide by zero; confirm callers never pass one.
    data = (data-data.min())/(data.max()-data.min())*255
    if data[..., 0].mean() > 128:
        # Bright background: treat dark pixels as content.
        gray = 255*(data[..., 0] < 128).astype(np.uint8)  # To invert the text to white
    else:
        # Dark background: treat bright pixels as content, then invert.
        gray = 255*(data[..., 0] > 128).astype(np.uint8)
        data[..., 0] = 255-data[..., 0]

    coords = cv2.findNonZero(gray)  # Find all non-zero points (text)
    a, b, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box
    rect = data[b:b+h, a:a+w]

    # Use the alpha channel when it carries information, else the luminance.
    if rect[..., -1].var() == 0:
        im = Image.fromarray((rect[..., 0]).astype(np.uint8)).convert('L')
    else:
        im = Image.fromarray((255-rect[..., -1]).astype(np.uint8)).convert('L')
    dims = []

    # Round each side up to the next multiple of `divable`.
    for x in [w, h]:
        div, mod = divmod(x, divable)
        dims.append(divable*(div + (1 if mod > 0 else 0)))

    # White canvas; content pasted at its own bounding box offset.
    padded = Image.new('L', dims, 255)
    padded.paste(im, im.getbbox())

    return padded
+
def get_divisible_size(ori_h, ori_w, max_dimension=None, scale_factor=32):
    """Round (h, w) up to the nearest multiple of `scale_factor`.

    If rounding up would exceed `max_dimension` (a (max_h, max_w) pair),
    round down instead. Already-divisible sides are returned unchanged.

    Args:
        ori_h, ori_w: input height / width.
        max_dimension: optional (max_h, max_w) cap; None disables the cap.
            (BUG FIX: the default was None but the value was indexed
            unconditionally, so omitting it raised TypeError.)
        scale_factor: divisibility requirement.

    Returns:
        (new_h, new_w) as ints.
    """
    new_h, new_w = ori_h, ori_w
    if ori_h % scale_factor:
        new_h = math.ceil(ori_h / scale_factor) * scale_factor
        if max_dimension is not None and new_h > max_dimension[0]:
            new_h = math.floor(ori_h / scale_factor) * scale_factor

    if ori_w % scale_factor:
        new_w = math.ceil(ori_w / scale_factor) * scale_factor
        if max_dimension is not None and new_w > max_dimension[1]:
            new_w = math.floor(ori_w / scale_factor) * scale_factor

    return int(new_h), int(new_w)
+
def minmax_size(img, max_dimensions=None, min_dimensions=None, is_gray=True):
    """Clamp `img` into [min_dimensions, max_dimensions] (each a (h, w) pair).

    Oversized images are downscaled (aspect preserved, then snapped to a
    divisible size); undersized grayscale images are padded onto a white
    canvas. Either bound may be None to skip that side.
    """
    if max_dimensions is not None:
        # Ratios of (h, w) against the maxima; > 1 means too large.
        ratios = [a/b for a, b in zip(list(img.size)[::-1], max_dimensions)]
        if any([r > 1 for r in ratios]):
            # Shrink by the worst offender so both sides fit.
            size = np.array(img.size)/max(ratios)
            new_h, new_w = get_divisible_size(size[1], size[0], max_dimensions)
            img = img.resize((new_w, new_h), Image.LANCZOS)

    if min_dimensions is not None:
        # Ratios against the minima; < 1 means too small.
        ratios = [a/b for a, b in zip(list(img.size)[::-1], min_dimensions)]
        if any([r < 1 for r in ratios]):
            # Grow by the worst offender, then snap to a divisible size.
            new_h, new_w = img.size[1] / min(ratios), img.size[0] / min(ratios)
            new_h, new_w = get_divisible_size(new_h, new_w, max_dimensions)
            # NOTE(review): padding is only performed for grayscale input —
            # RGB images pass through unpadded here; confirm that is intended.
            if is_gray:
                MODE = 'L'
                BACKGROUND = 255
                padded_im = Image.new(MODE, (new_w, new_h), BACKGROUND)
                padded_im.paste(img, img.getbbox())
                img = padded_im

    return img
+
def resize(resizer, img: Image.Image, opt: Dict):
    """Prepare a PIL image for math-recognition inference.

    Args:
        resizer: optional width-prediction model; when falsy the image is only
            clamped/normalized, otherwise an iterative resize loop queries the
            model for the best width.
        img: input image (converted to grayscale internally).
        opt: config dict — requires 'imgH'/'imgW'; the imgH-is-None path also
            reads 'max_dimension', 'min_dimension', 'mean', 'std', 'rgb',
            'pad' and 'device'.

    Returns:
        A 4-D float tensor (1, C, H, W).

    NOTE(review): only the `opt['imgH'] is None` branch is implemented — a
    non-None imgH leaves `new_img` unbound and the final assert raises
    NameError; confirm callers always pass imgH=None.
    """
    # for math recognition problem image alway in grayscale mode
    img = img.convert('L')
    assert isinstance(opt, Dict)
    assert "imgH" in opt
    assert "imgW" in opt
    expected_H = opt['imgH']

    if expected_H is None:
        max_dimensions = opt['max_dimension'] #can be bigger than max dim in training set
        min_dimensions = opt['min_dimension']
        #equal to min dim in trainign set
        test_transform = math_transform(opt['mean'], opt['std'], not opt['rgb'])
        try:
            # Clamp (and optionally pre-pad) into the allowed size window.
            new_img = minmax_size(pad(img) if opt['pad'] else img, max_dimensions, min_dimensions, not opt['rgb'])

            if not resizer:
                # Static path: normalize and tensorize at the clamped size.
                new_img = np.asarray(new_img.convert('RGB')).astype('uint8')
                new_img = test_transform(image=new_img)['image']
                if not opt['rgb']: new_img = new_img[:1]
                new_img = new_img.unsqueeze(0)
                new_img = new_img.float()
            else:
                # Model-guided path: repeatedly resize until the width the
                # resizer predicts matches the current width (max 20 rounds).
                with torch.no_grad():
                    input_image = new_img.convert('RGB').copy()
                    r, w, h = 1, input_image.size[0], input_image.size[1]
                    for i in range(20):
                        h = int(h * r)
                        new_img = pad(minmax_size(input_image.resize((w, h), Image.BILINEAR if r > 1 else Image.LANCZOS),
                                                  max_dimensions,
                                                  min_dimensions,
                                                  not opt['rgb']
                                                  ))
                        t = test_transform(image=np.array(new_img.convert('RGB')).astype('uint8'))['image']
                        if not opt['rgb']: t = t[:1]
                        t = t.unsqueeze(0)
                        t = t.float()
                        # Predicted width bucket -> pixels (multiples of min width).
                        w = (resizer(t.to(opt['device'])).argmax(-1).item()+1)*opt['min_dimension'][1]

                        if (w == new_img.size[0]):
                            break

                        r = w/new_img.size[0]

                    new_img = t
        except ValueError as e:
            # Fallback: transform the raw image and pad with white (value=1
            # in normalized space) up to the maximum dimensions.
            print('Error:', e)
            new_img = np.asarray(img.convert('RGB')).astype('uint8')
            assert len(new_img.shape) == 3 and new_img.shape[2] == 3
            new_img = test_transform(image=new_img)['image']
            if not opt['rgb']: new_img = new_img[:1]
            new_img = new_img.unsqueeze(0)
            h, w = new_img.shape[2:]
            new_img = F.pad(new_img, (0, max_dimensions[1]-w, 0, max_dimensions[0]-h), value=1)

    assert len(new_img.shape) == 4, f'{new_img.shape}'
    return new_img
diff --git a/HybridViT/module/component/common/__init__.py b/HybridViT/module/component/common/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a2e01462a4a223e12014ec012eb119fc6a6cd84
--- /dev/null
+++ b/HybridViT/module/component/common/__init__.py
@@ -0,0 +1,6 @@
+from .conv import *
+from .droppath import *
+from .gated_sum import *
+from .maxout import *
+from .postional_encoding import *
+from .mae_posembed import *
\ No newline at end of file
diff --git a/HybridViT/module/component/common/conv.py b/HybridViT/module/component/common/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6d88622aa7c435a61e14770f75464feaf7420b4
--- /dev/null
+++ b/HybridViT/module/component/common/conv.py
@@ -0,0 +1,148 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Optional
+import warnings
+
+__all__ = ['ConvMLP', 'ConvModule']
+
class LayerNorm2d(nn.LayerNorm):
    """LayerNorm applied over the channel dimension of BCHW tensors."""

    def __init__(self, num_channels):
        super().__init__(num_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Move channels last, normalize, then restore BCHW layout.
        channels_last = x.permute(0, 2, 3, 1)
        normed = F.layer_norm(
            channels_last, self.normalized_shape, self.weight, self.bias, self.eps
        )
        return normed.permute(0, 3, 1, 2)
+
class DepthwiseSeparableConv2d(nn.Module):
    """Depthwise conv (one filter per input channel) followed by a 1x1 pointwise conv."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super().__init__()
        # Spatial filtering: groups == in_channels keeps channels independent.
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            padding=padding,
            stride=stride,
            bias=bias,
            groups=in_channels,
        )
        # Channel mixing via a 1x1 convolution.
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, groups=groups, bias=bias)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))
+
class ConvMLP(nn.Module):
    """MLP over channels implemented with 1x1 convolutions (fc1 -> LN -> ReLU -> drop -> fc2).

    Args:
        in_channels: input channel count.
        out_channels: output channels; defaults to `in_channels`.
        hidden_channels: hidden width; defaults to `in_channels`.
        drop: dropout probability after the activation.
    """

    def __init__(self, in_channels, out_channels=None, hidden_channels=None, drop=0.25):
        super().__init__()
        # BUG FIX: the original wrote `in_channels or out_channels`, which
        # always discarded the caller-supplied widths; the fallback must be
        # the other way around.
        out_channels = out_channels or in_channels
        hidden_channels = hidden_channels or in_channels
        self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1, bias=True)
        self.norm = LayerNorm2d(hidden_channels)
        self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1, bias=True)
        self.act = nn.ReLU()
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.norm(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        return x
+
class ConvModule(nn.Module):
    """Conv / Norm / Activation bundle with configurable ordering.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups: forwarded to `conv_layer`.
        bias: 'auto' disables the conv bias when a norm layer follows
            (the norm's own shift makes it redundant).
        conv_layer / norm_layer / act_layer: constructors for each stage;
            pass None for norm/act to skip that stage.
        inplace: forwarded to activations that accept it.
        with_spectral_norm: wrap the conv in spectral normalization.
        padding_mode: 'zeros'/'circular' use the conv's built-in padding;
            other modes insert an explicit padding layer before the conv.
        order: permutation of ('conv', 'norm', 'act').
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias='auto',
                 conv_layer: Optional[nn.Module] = nn.Conv2d,
                 norm_layer: Optional[nn.Module] = nn.BatchNorm2d,
                 act_layer: Optional[nn.Module] = nn.ReLU,
                 inplace=True,
                 with_spectral_norm=False,
                 padding_mode='zeros',
                 order=('conv', 'norm', 'act')
                 ):
        # BUG FIX: nn.Module.__init__ must run before any sub-module is
        # assigned, otherwise `self.conv = ...` raises AttributeError.
        super().__init__()
        official_padding_mode = ['zeros', 'circular']
        nonofficial_padding_mode = dict(zero=nn.ZeroPad2d, reflect=nn.ReflectionPad2d, replicate=nn.ReplicationPad2d)
        self.with_spectral_norm = with_spectral_norm
        self.with_explicit_padding = padding_mode not in official_padding_mode
        self.order = order
        assert isinstance(self.order, tuple) and len(self.order) == 3
        assert set(order) == set(['conv', 'norm', 'act'])

        self.with_norm = norm_layer is not None
        self.with_act = act_layer is not None

        if bias == 'auto':
            bias = not self.with_norm
        self.with_bias = bias

        if self.with_norm and self.with_bias:
            warnings.warn('ConvModule has norm and bias at the same time')

        if self.with_explicit_padding:
            assert padding_mode in list(nonofficial_padding_mode), "Not implemented padding algorithm"
            # BUG FIX: instantiate the padding layer with the requested
            # amount — the original stored the bare class, so forward()
            # would have "called" the class with the feature map.
            self.padding_layer = nonofficial_padding_mode[padding_mode](padding)

        # reset padding to 0 for conv module (the explicit layer already pads)
        conv_padding = 0 if self.with_explicit_padding else padding

        self.conv = conv_layer(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=conv_padding,
            dilation=dilation,
            groups=groups,
            bias=bias
        )
        # Mirror the conv's effective configuration for introspection.
        self.in_channels = self.conv.in_channels
        self.out_channels = self.conv.out_channels
        self.kernel_size = self.conv.kernel_size
        self.stride = self.conv.stride
        self.padding = padding
        self.dilation = self.conv.dilation
        self.transposed = self.conv.transposed
        self.output_padding = self.conv.output_padding
        self.groups = self.conv.groups

        if self.with_spectral_norm:
            self.conv = nn.utils.spectral_norm(self.conv)

        # build normalization layers
        if self.with_norm:
            # The norm's channel count depends on whether it runs before or
            # after the conv in `order`.
            if order.index('norm') > order.index('conv'):
                norm_channels = out_channels
            else:
                norm_channels = in_channels
            self.norm = norm_layer(norm_channels)

        if self.with_act:
            # BUG FIX: the original condition was inverted — Tanh/PReLU/
            # Sigmoid do NOT accept `inplace` and must be built without it.
            if act_layer in [nn.Tanh, nn.PReLU, nn.Sigmoid]:
                self.act = act_layer()
            else:
                self.act = act_layer(inplace=inplace)

    def forward(self, x, activate=True, norm=True):
        """Apply conv/norm/act in `self.order`; `activate`/`norm` can disable stages per call."""
        for layer in self.order:
            if layer == 'conv':
                if self.with_explicit_padding:
                    x = self.padding_layer(x)
                x = self.conv(x)
            elif layer == 'norm' and norm and self.with_norm:
                x = self.norm(x)
            elif layer == 'act' and activate and self.with_act:
                x = self.act(x)
        return x
\ No newline at end of file
diff --git a/HybridViT/module/component/common/droppath.py b/HybridViT/module/component/common/droppath.py
new file mode 100644
index 0000000000000000000000000000000000000000..e54a07c886136353a1f0a47c6c96962f8a28422c
--- /dev/null
+++ b/HybridViT/module/component/common/droppath.py
@@ -0,0 +1,36 @@
+from torch import nn
+
+__all__ = ['DropPath']
+
+
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    # Identity outside training or when dropping is disabled.
    if not training or drop_prob == 0.:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if scale_by_keep and keep_prob > 0.0:
        # Rescale so the expected activation magnitude is unchanged.
        mask = mask.div_(keep_prob)
    return x * mask


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        # Delegate to the functional form; active only in training mode.
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f'drop_prob={round(self.drop_prob,3):0.3f}'
\ No newline at end of file
diff --git a/HybridViT/module/component/common/gated_sum.py b/HybridViT/module/component/common/gated_sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..387ef11bbe1f1d90a91a5a44fc4c517235d6add7
--- /dev/null
+++ b/HybridViT/module/component/common/gated_sum.py
@@ -0,0 +1,36 @@
+import torch
+
+
class GatedSum(torch.nn.Module):
    """
    This `Module` represents a gated sum of two tensors `a` and `b`. Specifically:
    ```
    f = activation(W [a; b])
    out = f * a + (1 - f) * b
    ```
    # Parameters
    input_dim : `int`, required
        The dimensionality of the input. We assume the input have shape `(..., input_dim)`.
    activation : `Activation`, optional (default = `torch.nn.Sigmoid()`)
        The activation function to use.
    """

    def __init__(self, input_dim: int, activation = torch.nn.Sigmoid()) -> None:
        super().__init__()
        self.input_dim = input_dim
        # Scalar gate computed from the concatenated pair.
        self._gate = torch.nn.Linear(input_dim * 2, 1)
        self._activation = activation

    def get_input_dim(self):
        return self.input_dim

    def get_output_dim(self):
        return self.input_dim

    def forward(self, input_a: torch.Tensor, input_b: torch.Tensor) -> torch.Tensor:
        # Validate shapes before mixing.
        if input_a.size() != input_b.size():
            raise ValueError("The input must have the same size.")
        if input_a.size(-1) != self.input_dim:
            raise ValueError("Input size must match `input_dim`.")
        combined = torch.cat([input_a, input_b], -1)
        gate_value = self._activation(self._gate(combined))
        return gate_value * input_a + (1 - gate_value) * input_b
\ No newline at end of file
diff --git a/HybridViT/module/component/common/mae_posembed.py b/HybridViT/module/component/common/mae_posembed.py
new file mode 100644
index 0000000000000000000000000000000000000000..187ecd981c59df11794df9d5be8c02d8a37e04d9
--- /dev/null
+++ b/HybridViT/module/component/common/mae_posembed.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# Position embedding utils
+# --------------------------------------------------------
+
+import numpy as np
+# --------------------------------------------------------
+# 2D sine-cosine position embedding
+# References:
+# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
+# MoCo v3: https://github.com/facebookresearch/moco-v3
+# --------------------------------------------------------
+
def get_2d_sincos_pos_embed(embed_dim, grid_size_H, grid_size_W, cls_token=False):
    """Build a 2-D sine-cosine positional embedding table.

    Args:
        embed_dim: embedding width (must be even).
        grid_size_H / grid_size_W: grid height / width.
        cls_token: when True, prepend an all-zero row for the class token.

    Returns:
        pos_embed: [grid_size_H*grid_size_W, embed_dim] or
        [1+grid_size_H*grid_size_W, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size_H, dtype=np.float32)
    grid_w = np.arange(grid_size_W, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size_H, grid_size_W])

    # FIX: removed leftover debug print of the grid shape.
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a [2, ...] coordinate grid; half the dims per axis."""
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    # Geometric frequency ladder, as in the Transformer paper.
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
+
if __name__ == '__main__':
    # Manual smoke test: builds an (800*800 + 1, 256) embedding table —
    # NOTE(review): this allocates on the order of gigabytes of RAM.
    pos_embed = get_2d_sincos_pos_embed(256, 800, 800, True)
    print(pos_embed.shape)
\ No newline at end of file
diff --git a/HybridViT/module/component/common/maxout.py b/HybridViT/module/component/common/maxout.py
new file mode 100644
index 0000000000000000000000000000000000000000..07e40b8fb96dfe09dd9ecdd04eb5094a84e2c20b
--- /dev/null
+++ b/HybridViT/module/component/common/maxout.py
@@ -0,0 +1,22 @@
+from torch import nn
+
+class Maxout(nn.Module):
+ """
+ Maxout makes pools from the last dimension and keeps only the maximum value from
+ each pool.
+ """
+
+ def __init__(self, pool_size):
+ """
+ Args:
+ pool_size (int): Number of elements per pool
+ """
+ super(Maxout, self).__init__()
+ self.pool_size = pool_size
+
+ def forward(self, x):
+ [*shape, last] = x.size()
+ out = x.view(*shape, last // self.pool_size, self.pool_size)
+ out, _ = out.max(-1)
+ return out
+
diff --git a/HybridViT/module/component/common/postional_encoding.py b/HybridViT/module/component/common/postional_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ced84f7bf616a8b3deae2650fba02b9adfff2d3
--- /dev/null
+++ b/HybridViT/module/component/common/postional_encoding.py
@@ -0,0 +1,226 @@
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+from typing import Tuple
+from torch import Tensor
+from torch.nn import functional as F
+
+
class Adaptive2DPositionalEncoding(nn.Module):
    """Implement Adaptive 2D positional encoder for SATRN, see
    `SATRN `_
    Modified from https://github.com/Media-Smart/vedastr
    Licensed under the Apache License, Version 2.0 (the "License");
    Args:
        d_hid (int): Dimensions of hidden layer.
        n_height (int): Max height of the 2D feature output.
        n_width (int): Max width of the 2D feature output.
        dropout (int): Size of hidden layers of the model.
    """

    def __init__(self,
                 d_hid=512,
                 n_height=100,
                 n_width=100,
                 dropout=0.1,
                 ):
        super().__init__()

        # Precompute 1-D sinusoid tables for each axis, laid out for
        # broadcasting over BCHW feature maps.
        h_position_encoder = self._get_sinusoid_encoding_table(n_height, d_hid)
        h_position_encoder = h_position_encoder.transpose(0, 1)
        h_position_encoder = h_position_encoder.view(1, d_hid, n_height, 1)

        w_position_encoder = self._get_sinusoid_encoding_table(n_width, d_hid)
        w_position_encoder = w_position_encoder.transpose(0, 1)
        w_position_encoder = w_position_encoder.view(1, d_hid, 1, n_width)

        self.register_buffer('h_position_encoder', h_position_encoder)
        self.register_buffer('w_position_encoder', w_position_encoder)

        # Learned, input-conditioned scale factors for each axis.
        self.h_scale = self.scale_factor_generate(d_hid)
        self.w_scale = self.scale_factor_generate(d_hid)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=dropout)

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        """Sinusoid position encoding table."""
        denominator = torch.Tensor([
            1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid)
            for hid_j in range(d_hid)
        ])
        denominator = denominator.view(1, -1)
        pos_tensor = torch.arange(n_position).unsqueeze(-1).float()
        sinusoid_table = pos_tensor * denominator
        # sin on even channels, cos on odd channels.
        sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2])
        sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2])

        return sinusoid_table

    def scale_factor_generate(self, d_hid):
        """Two 1x1 convs ending in a sigmoid: per-channel scale in (0, 1)."""
        scale_factor = nn.Sequential(
            nn.Conv2d(d_hid, d_hid, kernel_size=1), nn.ReLU(inplace=True),
            nn.Conv2d(d_hid, d_hid, kernel_size=1), nn.Sigmoid())

        return scale_factor

    def init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # BUG FIX: calculate_gain expects the lowercase name 'relu';
                # 'ReLU' raises ValueError("Unsupported nonlinearity ...").
                nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))

    def forward(self, x):
        """Add the adaptively scaled h/w encodings to x (B, C, H, W)."""
        b, c, h, w = x.size()

        # Global context used to predict the axis scale factors.
        avg_pool = self.pool(x)

        h_pos_encoding = \
            self.h_scale(avg_pool) * self.h_position_encoder[:, :, :h, :]
        w_pos_encoding = \
            self.w_scale(avg_pool) * self.w_position_encoder[:, :, :, :w]

        out = x + h_pos_encoding + w_pos_encoding

        out = self.dropout(out)

        return out
+
class PositionalEncoding2D(nn.Module):
    """2-D positional encodings for the feature maps produced by the encoder.
    Following https://arxiv.org/abs/2103.06450 by Sumeet Singh.
    Reference:
        https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2021-labs/blob/main/lab9/text_recognizer/models/transformer_util.py
    """

    def __init__(self, d_model: int, max_h: int = 2000, max_w: int = 2000) -> None:
        super().__init__()
        self.d_model = d_model
        assert d_model % 2 == 0, f"Embedding depth {d_model} is not even"
        # Precomputed (d_model, max_h, max_w) table, registered as a buffer.
        self.register_buffer("pe", self.make_pe(d_model, max_h, max_w))

    @staticmethod
    def make_pe(d_model: int, max_h: int, max_w: int) -> Tensor:
        """Half the channels encode the row index, half the column index."""
        half = d_model // 2
        rows = PositionalEncoding1D.make_pe(d_model=half, max_len=max_h)
        rows = rows.permute(2, 0, 1).expand(-1, -1, max_w)  # (half, max_h, max_w)
        cols = PositionalEncoding1D.make_pe(d_model=half, max_len=max_w)
        cols = cols.permute(2, 1, 0).expand(-1, max_h, -1)  # (half, max_h, max_w)
        return torch.cat([rows, cols], dim=0)

    def forward(self, x: Tensor) -> Tensor:
        """Add the top-left HxW window of the table to x (B, d_model, H, W)."""
        assert x.shape[1] == self.pe.shape[0]  # type: ignore
        _, _, h, w = x.shape
        return x + self.pe[:, :h, :w]  # type: ignore


class PositionalEncoding1D(nn.Module):
    """Classic Attention-is-all-you-need positional encoding."""

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1000) -> None:
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.register_buffer("pe", self.make_pe(d_model, max_len))

    @staticmethod
    def make_pe(d_model: int, max_len: int) -> Tensor:
        """(max_len, 1, d_model) table: sin on even indices, cos on odd."""
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        freqs = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)
        table[:, 1::2] = torch.cos(positions * freqs)
        return table.unsqueeze(1)

    def forward(self, x: Tensor) -> Tensor:
        """Add encodings for the first S positions to x (S, B, d_model), then dropout."""
        assert x.shape[2] == self.pe.shape[2]  # type: ignore
        return self.dropout(x + self.pe[: x.size(0)])  # type: ignore
+
Size_ = Tuple[int, int]

class PosConv(nn.Module):
    # PEG from https://arxiv.org/abs/2102.10882
    def __init__(self, in_chans, embed_dim=768, stride=1):
        super(PosConv, self).__init__()
        # Depthwise 3x3 conv acts as the positional encoding generator.
        self.proj = nn.Sequential(nn.Conv2d(in_chans, embed_dim, 3, stride, 1, bias=True, groups=embed_dim), )
        self.stride = stride

    def forward(self, x, size: Size_):
        """Inject conv positional information into the patch tokens; cls token passes through."""
        batch, _, channels = x.shape
        cls_tok = x[:, :1]
        patches = x[:, 1:]
        # Tokens -> BCHW feature map at the given spatial size.
        feat_2d = patches.transpose(1, 2).view(batch, channels, *size)
        encoded = self.proj(feat_2d)
        if self.stride == 1:
            # Residual connection keeps token identity at stride 1.
            encoded = encoded + feat_2d
        flat = encoded.flatten(2).transpose(1, 2)
        return torch.cat((cls_tok, flat), dim=1)

    def no_weight_decay(self):
        return ['proj.%d.weight' % i for i in range(4)]
+
class PosConv1D(nn.Module):
    # PEG from https://arxiv.org/abs/2102.10882
    def __init__(self, in_chans, embed_dim=768, stride=1):
        super(PosConv1D, self).__init__()
        # Depthwise kernel-3 conv generates the positional signal.
        self.proj = nn.Sequential(nn.Conv1d(in_chans, embed_dim, 3, stride, 1, bias=True, groups=embed_dim), )
        self.stride = stride

    def forward(self, x, size: int):
        """1-D variant: encode the sequence tokens; cls token passes through."""
        batch, _, channels = x.shape
        cls_tok = x[:, :1]
        tokens = x[:, 1:]
        # Tokens -> (B, C, L) layout for the conv.
        feat_1d = tokens.transpose(1, 2).view(batch, channels, size)
        encoded = self.proj(feat_1d)
        if self.stride == 1:
            # Residual connection keeps token identity at stride 1.
            encoded = encoded + feat_1d
        return torch.cat((cls_tok, encoded.transpose(1, 2)), dim=1)

    def no_weight_decay(self):
        return ['proj.%d.weight' % i for i in range(4)]
+
def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=(), old_grid_shape=()):
    # Rescale the grid of position embeddings when loading from state_dict. Adapted from
    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
    print('Resized position embedding: %s to %s'%(posemb.shape, posemb_new.shape))
    grid_tokens = posemb_new.shape[1]

    # Split off prefix tokens (cls/distill); only the spatial grid is resized.
    if num_tokens:
        posemb_tok = posemb[:, :num_tokens]
        posemb_grid = posemb[0, num_tokens:]
        grid_tokens -= num_tokens
    else:
        posemb_tok = posemb[:, :0]
        posemb_grid = posemb[0]

    if not len(gs_new):  # backwards compatibility
        gs_new = [int(math.sqrt(grid_tokens))] * 2

    assert len(gs_new) >= 2

    print('Position embedding grid-size from %s to %s'%(old_grid_shape, gs_new))
    # Tokens -> 2-D grid, bicubic resample, then back to token layout.
    grid = posemb_grid.reshape(1, old_grid_shape[0], old_grid_shape[1], -1).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=gs_new, mode='bicubic', align_corners=False)
    grid = grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
    return torch.cat([posemb_tok, grid], dim=1)
diff --git a/HybridViT/module/component/feature_extractor/__init__.py b/HybridViT/module/component/feature_extractor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/HybridViT/module/component/feature_extractor/addon_module/__init__.py b/HybridViT/module/component/feature_extractor/addon_module/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5809d2e26a7c602ccdbb5d41379a5e15d395009
--- /dev/null
+++ b/HybridViT/module/component/feature_extractor/addon_module/__init__.py
@@ -0,0 +1,2 @@
+from .aspp import *
+from .visual_attention import *
\ No newline at end of file
diff --git a/HybridViT/module/component/feature_extractor/addon_module/aspp.py b/HybridViT/module/component/feature_extractor/addon_module/aspp.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3d333e8c4064a8f31007ad6b0d125e3c277bfad
--- /dev/null
+++ b/HybridViT/module/component/feature_extractor/addon_module/aspp.py
@@ -0,0 +1,59 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+__all__ = ['ASPP']
+
class ASPPModule(nn.Module):
    """One atrous-convolution + ReLU branch of the ASPP head."""

    def __init__(self, inplanes, planes, kernel_size, padding, dilation):
        super().__init__()
        self.atrous_conv = nn.Conv2d(
            inplanes, planes, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation, bias=False
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):  # skipcq: PYL-W0221
        return self.relu(self.atrous_conv(x))


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling: four dilated branches plus image-level
    pooling, fused by a 1x1 convolution."""

    def __init__(self, inplanes: int, output_stride: int, output_features: int, dropout=0.5):
        super().__init__()

        # Dilation rates depend on the backbone's output stride.
        dilation_table = {
            32: [1, 3, 6, 9],
            16: [1, 6, 12, 18],
            8: [1, 12, 24, 36],
        }
        if output_stride not in dilation_table:
            raise NotImplementedError
        dilations = dilation_table[output_stride]

        self.aspp1 = ASPPModule(inplanes, output_features, 1, padding=0, dilation=dilations[0])
        self.aspp2 = ASPPModule(inplanes, output_features, 3, padding=dilations[1], dilation=dilations[1])
        self.aspp3 = ASPPModule(inplanes, output_features, 3, padding=dilations[2], dilation=dilations[2])
        self.aspp4 = ASPPModule(inplanes, output_features, 3, padding=dilations[3], dilation=dilations[3])

        # Image-level context branch.
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(inplanes, output_features, 1, stride=1, bias=False),
            nn.ReLU(inplace=True),
        )
        self.conv1 = nn.Conv2d(output_features * 5, output_features, 1, bias=False)
        self.relu1 = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):  # skipcq: PYL-W0221
        branches = [self.aspp1(x), self.aspp2(x), self.aspp3(x), self.aspp4(x)]
        pooled = self.global_avg_pool(x)
        # Upsample the pooled context to the branch resolution before fusing.
        pooled = F.interpolate(pooled, size=branches[-1].size()[2:], mode="bilinear", align_corners=False)
        merged = torch.cat((*branches, pooled), dim=1)
        return self.dropout(self.relu1(self.conv1(merged)))
\ No newline at end of file
diff --git a/HybridViT/module/component/feature_extractor/addon_module/visual_attention.py b/HybridViT/module/component/feature_extractor/addon_module/visual_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..16e4ed2c0e294a4794c75d53a85c2523de40901d
--- /dev/null
+++ b/HybridViT/module/component/feature_extractor/addon_module/visual_attention.py
@@ -0,0 +1,325 @@
+import torch.nn as nn
+import torch
+from torch.nn import functional as F
+from functools import reduce
+
+__all__ = ['Adaptive_Global_Model', 'GlobalContext', 'SELayer', 'SKBlock', 'CBAM']
+
+
def constant_init(module, val, bias=0):
    """Fill ``module.weight`` with ``val`` and ``module.bias`` with ``bias`` (when present)."""
    weight = getattr(module, 'weight', None)
    if weight is not None:
        nn.init.constant_(weight, val)
    bias_param = getattr(module, 'bias', None)
    if bias_param is not None:
        nn.init.constant_(bias_param, bias)
+
+
def xavier_init(module, gain=1, bias=0, distribution='normal'):
    """Xavier-initialise ``module.weight``; constant-fill ``module.bias`` when present."""
    assert distribution in ['uniform', 'normal']
    init_fn = nn.init.xavier_uniform_ if distribution == 'uniform' else nn.init.xavier_normal_
    init_fn(module.weight, gain=gain)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
+
+
def normal_init(module, mean=0, std=1, bias=0):
    """Initialise ``module.weight`` from N(mean, std); constant-fill bias when present."""
    nn.init.normal_(module.weight, mean, std)
    bias_param = getattr(module, 'bias', None)
    if bias_param is not None:
        nn.init.constant_(bias_param, bias)
+
+
def uniform_init(module, a=0, b=1, bias=0):
    """Initialise ``module.weight`` from U(a, b); constant-fill bias when present."""
    nn.init.uniform_(module.weight, a, b)
    bias_param = getattr(module, 'bias', None)
    if bias_param is not None:
        nn.init.constant_(bias_param, bias)
+
+
def kaiming_init(module,
                 a=0,
                 mode='fan_out',
                 nonlinearity='relu',
                 bias=0,
                 distribution='normal'):
    """Kaiming-initialise ``module.weight`` (uniform or normal flavour); constant-fill bias."""
    assert distribution in ['uniform', 'normal']
    init_fn = nn.init.kaiming_uniform_ if distribution == 'uniform' else nn.init.kaiming_normal_
    init_fn(module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
+
+
def last_zero_init(m):
    """Zero-initialise the final layer of a Sequential (or the module itself)."""
    target = m[-1] if isinstance(m, nn.Sequential) else m
    constant_init(target, val=0)
+
+
class SELayer(nn.Module):
    """Squeeze-and-Excitation gate with a residual connection.

    Derives per-channel weights from globally pooled features and returns
    ``x + x * weights`` (SE-Residual) instead of the plain ``x * weights``.
    """

    def __init__(self, channel, reduction=16, dropout=0.1):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid()
        )

    def forward(self, x):
        batch, channels = x.size(0), x.size(1)
        squeezed = self.avg_pool(x).view(batch, channels)
        gates = self.fc(squeezed).view(batch, channels, 1, 1)
        return x + x * gates.expand_as(x)  # SE-Residual
+
+
class BasicConv(nn.Module):
    """Conv2d with optional BatchNorm and ReLU, applied in that order."""

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True,
                 bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
                              dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        out = self.conv(x)
        if self.bn is not None:
            out = self.bn(out)
        if self.relu is not None:
            out = self.relu(out)
        return out
+
+
class Flatten(nn.Module):
    """Collapse every dimension after the batch dimension."""

    def forward(self, x):
        batch = x.size(0)
        return x.view(batch, -1)
+
+
class ChannelGate(nn.Module):
    """Channel-attention gate from CBAM.

    Pools the input per channel (avg and/or max), pushes each pooled
    vector through a shared bottleneck MLP, sums the results, and
    rescales ``x`` by the sigmoid of that sum.
    """

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=('avg', 'max')):
        # NOTE: default changed from a mutable list to a tuple — a shared
        # mutable default is a classic Python pitfall; the value is only
        # iterated, so lists and tuples are interchangeable here.
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )
        self.pool_types = pool_types
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.maxpool = nn.AdaptiveMaxPool2d(1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                channel_att_raw = self.mlp(self.avgpool(x))
            elif pool_type == 'max':
                channel_att_raw = self.mlp(self.maxpool(x))
            else:
                # BUG FIX: an unknown pool type used to either raise
                # UnboundLocalError or silently reuse the previous
                # branch's result; fail loudly instead.
                raise ValueError(f"Unsupported pool type: {pool_type!r}")

            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        scale = self.sigmoid(channel_att_sum).unsqueeze(2) \
            .unsqueeze(3).expand_as(x)
        return x * scale
+
+
class ChannelPool(nn.Module):
    """Stack channel-wise max and mean into a 2-channel map."""

    def forward(self, x):
        max_map = torch.max(x, dim=1, keepdim=True)[0]
        mean_map = torch.mean(x, dim=1, keepdim=True)
        return torch.cat((max_map, mean_map), dim=1)
+
+
class SpatialGate(nn.Module):
    """Spatial-attention gate from CBAM: a 7x7 conv over the channel-pooled map."""

    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1,
                                 padding=(kernel_size - 1) // 2, relu=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        attention_map = self.sigmoid(self.spatial(self.compress(x)))
        # The single-channel map broadcasts across all channels of x.
        return x * attention_map
+
+
class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel gate, then optional spatial gate."""

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=('avg', 'max'), no_spatial=False):
        # NOTE: default pool_types is now a tuple — a shared mutable-list
        # default is a Python pitfall, and ChannelGate only iterates it.
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out
+
+
class SKBlock(nn.Module):
    """Selective Kernel block (SKNet).

    M parallel 3x3 branches with growing dilation are fused by addition;
    a squeeze step then produces per-branch soft-attention weights that
    select between the branch outputs.
    """

    def __init__(self, in_channels, out_channels, stride=1, M=2, r=16, L=32):
        super(SKBlock, self).__init__()
        d = max(in_channels // r, L)  # bottleneck width of the squeeze step
        self.M = M
        self.out_channels = out_channels
        self.conv = nn.ModuleList(
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, stride, padding=1 + i, dilation=1 + i, groups=32, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )
            for i in range(M)
        )
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Sequential(nn.Conv2d(out_channels, d, 1, bias=False),
                                 nn.BatchNorm2d(d),
                                 nn.ReLU(inplace=True))
        self.fc2 = nn.Conv2d(d, out_channels * M, 1, 1, bias=False)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, input):
        batch_size = input.size(0)
        # Split: run every branch.
        branch_outputs = [branch(input) for branch in self.conv]
        # Fuse: element-wise sum, then squeeze into per-branch logits.
        fused = branch_outputs[0]
        for extra in branch_outputs[1:]:
            fused = fused + extra
        attention = self.fc2(self.fc1(self.global_pool(fused)))
        attention = self.softmax(attention.reshape(batch_size, self.M, self.out_channels, -1))
        # Select: weight each branch by its softmaxed attention and sum.
        weights = [w.reshape(batch_size, self.out_channels, 1, 1) for w in attention.chunk(self.M, dim=1)]
        selected = branch_outputs[0] * weights[0]
        for out, w in zip(branch_outputs[1:], weights[1:]):
            selected = selected + out * w
        return selected
+
def make_divisible(v, divisor=8, min_value=None, round_limit=.9):
    """Round ``v`` to the nearest multiple of ``divisor`` (at least ``min_value``).

    If rounding shrank the value below ``round_limit * v``, bump up by one
    divisor step so we never lose more than ~10% (at the default limit).
    """
    min_value = min_value or divisor
    rounded = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if rounded < round_limit * v:
        rounded += divisor
    return rounded
+
class LayerNorm2d(nn.LayerNorm):
    """LayerNorm over the channel dimension of BCHW tensors.

    Permutes to BHWC, applies standard LayerNorm over the last (channel)
    axis, and permutes back.
    """

    def __init__(self, num_channels):
        super().__init__(num_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        channels_last = x.permute(0, 2, 3, 1)
        normed = F.layer_norm(channels_last, self.normalized_shape, self.weight, self.bias, self.eps)
        return normed.permute(0, 3, 1, 2)
+
class ConvMLP(nn.Module):
    """1x1-conv MLP over BCHW maps: fc1 -> LayerNorm2d -> ReLU -> Dropout -> fc2."""

    def __init__(self, in_channels, out_channels=None, hidden_channels=None, drop=0.25):
        super().__init__()
        # BUG FIX: the original wrote ``in_channels or out_channels`` /
        # ``in_channels or hidden_channels`` — since ``in_channels`` is
        # truthy, the caller's ``out_channels``/``hidden_channels`` were
        # always silently ignored. The operands belong the other way round.
        out_channels = out_channels or in_channels
        hidden_channels = hidden_channels or in_channels
        self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1, bias=True)
        self.norm = LayerNorm2d(hidden_channels)
        self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1, bias=True)
        self.act = nn.ReLU()
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.norm(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        return x
+
class GlobalContext(nn.Module):
    """GCNet-style global context block.

    Computes one global context vector per image — via a learned softmax
    attention map (``use_attn=True``) or plain global average pooling —
    transforms it with a bottleneck ConvMLP, and fuses it back into ``x``
    by addition and/or channel-wise scaling.
    """

    def __init__(self,
                 channel,
                 use_attn=True,
                 fuse_add=True,
                 fuse_scale=False,
                 rd_ratio=1./8,
                 rd_channels=None
                 ):
        super().__init__()
        self.use_attn = use_attn
        # 1x1 conv producing a spatial attention logit map, or plain GAP.
        self.global_cxt = nn.Conv2d(channel, 1, kernel_size=1, bias=True) if use_attn else nn.AdaptiveAvgPool2d(1)

        if rd_channels is None:
            rd_channels = make_divisible(channel * rd_ratio, divisor=1, round_limit=0.)

        self.bottleneck_add = ConvMLP(channel, hidden_channels=rd_channels) if fuse_add else None
        self.bottleneck_mul = ConvMLP(channel, hidden_channels=rd_channels) if fuse_scale else None
        if self.bottleneck_add is None and self.bottleneck_mul is None:
            # BUG FIX: the original forward() hit UnboundLocalError on
            # ``x_fuse`` in this configuration; reject it up front.
            raise ValueError("GlobalContext requires fuse_add and/or fuse_scale")

        self.init_weight()

    def init_weight(self):
        if self.use_attn:
            nn.init.kaiming_normal_(self.global_cxt.weight, mode='fan_in', nonlinearity='relu')
        # Zero the bottleneck output so the block starts as identity.
        if self.bottleneck_add is not None:
            nn.init.zeros_(self.bottleneck_add.fc2.weight)
        if self.bottleneck_mul is not None:
            nn.init.zeros_(self.bottleneck_mul.fc2.weight)

    def forward(self, x):
        B, C, H, W = x.shape
        if self.use_attn:
            # Softmax over all H*W positions -> (B, H*W, 1) attention.
            attn = self.global_cxt(x).reshape(B, H * W)
            attn = F.softmax(attn, dim=-1).unsqueeze(-1)
            # Attention-weighted sum of features: (B, C, H*W) @ (B, H*W, 1) -> (B, C, 1, 1)
            glob_cxt = torch.bmm(x.reshape(B, C, H * W), attn).unsqueeze(-1)
        else:
            glob_cxt = self.global_cxt(x)
        assert len(glob_cxt.shape) == 4

        # BUG FIX: apply scale before add (matching the reference GCNet /
        # timm implementation) so enabling both fuse modes no longer
        # discards the additive branch; torch.sigmoid replaces the
        # deprecated F.sigmoid. Behavior with only one mode enabled is
        # unchanged.
        x_fuse = x
        if self.bottleneck_mul is not None:
            x_fuse = x_fuse * torch.sigmoid(self.bottleneck_mul(glob_cxt))
        if self.bottleneck_add is not None:
            x_fuse = x_fuse + self.bottleneck_add(glob_cxt)

        return x_fuse
+
+
class Adaptive_Global_Model(nn.Module):
    """Global-context aggregator producing a width-major feature sequence.

    Runs a GlobalContext block over a BCHW map, treats the width axis as
    the sequence dimension, flattens channel x height per step, and embeds
    each step down to ``inplanes`` features.
    """

    def __init__(self, inplanes, factor=2, ratio=0.0625, dropout=0.1):
        super(Adaptive_Global_Model, self).__init__()
        # Each width step carries C * H = inplanes * factor values
        # (assumes the incoming feature map height equals ``factor`` — TODO confirm).
        self.embedding = nn.Linear(inplanes * factor, inplanes)
        # BUG FIX: GlobalContext has no ``ratio`` keyword — the original
        # ``GlobalContext(inplanes, ratio=ratio)`` raised TypeError at
        # construction. The parameter maps to ``rd_ratio``.
        self.gc_block = GlobalContext(inplanes, rd_ratio=ratio)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.gc_block(x)               # B C H W -> B C H W
        x = x.permute(0, 3, 1, 2)          # -> B W C H (width becomes the sequence axis)
        b, w, _, _ = x.shape
        x = x.contiguous().view(b, w, -1)  # -> B W (C*H)
        x = self.embedding(x)              # -> B W inplanes
        x = self.dropout(x)
        return x
+
diff --git a/HybridViT/module/component/feature_extractor/clova_impl/__init__.py b/HybridViT/module/component/feature_extractor/clova_impl/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..18f311768a9ed4ca490cac0d96b9adcccd4946ac
--- /dev/null
+++ b/HybridViT/module/component/feature_extractor/clova_impl/__init__.py
@@ -0,0 +1,2 @@
+from .resnet import ResNet_FeatureExtractor
+from .vgg import VGG_FeatureExtractor
\ No newline at end of file
diff --git a/HybridViT/module/component/feature_extractor/clova_impl/resnet.py b/HybridViT/module/component/feature_extractor/clova_impl/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..e32faa3bef418369ada99770c3bbc938fb1c0d8a
--- /dev/null
+++ b/HybridViT/module/component/feature_extractor/clova_impl/resnet.py
@@ -0,0 +1,262 @@
+from typing import Dict
+from collections import OrderedDict
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ..addon_module.visual_attention import GlobalContext
+from .....helper import clean_state_dict
+
class BasicBlock(nn.Module):
    """Standard two-conv residual block (ResNet basic block)."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = self._conv3x3(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = self._conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def _conv3x3(self, in_planes, out_planes, stride=1):
        """3x3 convolution with padding."""
        return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                         padding=1, bias=False)

    def zero_init_last_bn(self):
        # Zeroing the last BN makes the block start as an identity mapping.
        nn.init.zeros_(self.bn2.weight)

    def forward(self, x):
        shortcut = self.downsample(x) if self.downsample is not None else x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out = out + shortcut
        return self.relu(out)
+
class ResNet(nn.Module):
    """Clova-style ResNet backbone for text recognition.

    Compared to a torchvision ResNet, an extra 3x3 conv + BN follows each
    of the first three residual stages, and the later poolings/convs use
    stride (2, 1) so width (the text reading direction) is downsampled
    less than height.

    Args:
        input_channel: channels of the input image.
        output_channel: channels of the final feature map; stage widths
            are derived from it as [C/4, C/2, C, C].
        block: residual block class (e.g. BasicBlock).
        layers: number of residual blocks in each of the 4 stages.
        with_gcb: if True, append a GlobalContext block to every stage.
        debug: if True, print intermediate tensor shapes in forward().
        zero_init_last_bn: zero the last BN of each block at init.
    """

    def __init__(self, input_channel, output_channel, block, layers, with_gcb=True, debug=False, zero_init_last_bn=False):
        super(ResNet, self).__init__()
        self.with_gcb = with_gcb

        # Stage widths: [C/4, C/2, C, C] for output_channel C.
        self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel]
        self.inplanes = int(output_channel / 8)

        # Two-conv stem: input -> C/16 -> C/8 channels at full resolution.
        self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16),
                                 kernel_size=3, stride=1, padding=1, bias=False)
        self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16))

        self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes,
                                 kernel_size=3, stride=1, padding=1, bias=False)
        self.bn0_2 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)

        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0])
        self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[
            0], kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.output_channel_block[0])

        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1)
        self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[
            1], kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.output_channel_block[1])

        # Stride (2, 1): halve height only, preserve width resolution.
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1))
        self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1)
        self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[
            2], kernel_size=3, stride=1, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.output_channel_block[2])

        self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1)

        # Tail convs: kernel 2, again compressing height more than width.
        self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[
            3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False)
        self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3])

        self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[
            3], kernel_size=2, stride=1, padding=0, bias=False)
        self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3])

        self.init_weights(zero_init_last_bn=zero_init_last_bn)
        self.debug = debug

    def zero_init_last_bn(self):
        # Zero the final BN so the last tail conv starts near-identity.
        nn.init.zeros_(self.bn4_2.weight)

    def init_weights(self, zero_init_last_bn=True):
        """Kaiming-init convs and unit-init BNs.

        Submodules whose qualified name contains one of ``initialized``
        (the GlobalContext internals) are skipped — GlobalContext performs
        its own initialisation in its init_weight().
        """
        initialized = ['global_cxt', 'bottleneck_add', 'bottleneck_mul']
        for n, m in self.named_modules():
            if any([d in n for d in initialized]):
                continue
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
        if zero_init_last_bn:
            # Delegates to every submodule (and self) that defines the hook.
            for m in self.modules():
                if hasattr(m, 'zero_init_last_bn'):
                    m.zero_init_last_bn()

    def _make_layer(self, block, planes, blocks, with_gcb=False, stride=1):
        """Stack ``blocks`` residual blocks, optionally ending in a GlobalContext.

        NOTE(review): the ``with_gcb`` parameter is never read — the
        GlobalContext append below is driven by ``self.with_gcb`` from
        __init__; confirm which of the two is intended.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Projection shortcut to match shape when stride/width changes.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion

        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        if self.with_gcb:
            layers.append(GlobalContext(planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Stem -> 4 residual stages (with inter-stage convs/pools) -> two tail convs.

        With ``debug=True`` prints the shape after every stage.
        """
        if self.debug:
            print('input shape', x.shape)

        x = self.conv0_1(x)
        x = self.bn0_1(x)
        x = self.relu(x)

        if self.debug:
            print('conv1 shape', x.shape)

        x = self.conv0_2(x)
        x = self.bn0_2(x)
        x = self.relu(x)

        if self.debug:
            print('conv2 shape', x.shape)

        x = self.maxpool1(x)

        if self.debug:
            print('pool1 shape', x.shape)

        x = self.layer1(x)

        if self.debug:
            print('block1 shape', x.shape)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        if self.debug:
            print('conv3 shape', x.shape)

        x = self.maxpool2(x)

        if self.debug:
            print('pool2 shape', x.shape)

        x = self.layer2(x)

        if self.debug:
            print('block2 shape', x.shape)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        if self.debug:
            print('conv4 shape', x.shape)

        x = self.maxpool3(x)

        if self.debug:
            print('pool3 shape', x.shape)

        x = self.layer3(x)

        if self.debug:
            print('block3 shape', x.shape)

        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)

        if self.debug:
            print('conv5 shape', x.shape)

        x = self.layer4(x)

        if self.debug:
            print('block4 shape', x.shape)

        x = self.conv4_1(x)
        x = self.bn4_1(x)
        x = self.relu(x)

        if self.debug:
            print('conv6 shape', x.shape)

        x = self.conv4_2(x)
        x = self.bn4_2(x)
        x = self.relu(x)

        if self.debug:
            print('conv7 shape', x.shape)

        return x
+
class ResNet_FeatureExtractor(nn.Module):
    """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """

    def __init__(self, input_channel=3, output_channel=512, gcb=False, pretrained=False, weight_dir=None, debug=False):
        super(ResNet_FeatureExtractor, self).__init__()
        self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3], gcb, debug)
        self.in_chans = input_channel
        if pretrained:
            assert weight_dir is not None
            self.load_pretrained(weight_dir)

    def forward(self, input):
        """Run the backbone; BCHW in, final conv feature map out."""
        output = self.ConvNet(input)
        return output

    def load_pretrained(self, weight_dir):
        """Load a clova-style checkpoint, stripping the 'FeatureExtraction.' prefix.

        The first conv's weight is tiled along the input-channel axis so a
        checkpoint with a different channel count can still be loaded
        (assumes the checkpoint's conv0_1 has exactly 1 input channel —
        TODO confirm against the actual weights).
        """
        # BUG FIX: map_location='cpu' so checkpoints saved on GPU load on
        # CPU-only hosts; device placement is handled later by .to(device).
        state_dict: OrderedDict = torch.load(weight_dir, map_location='cpu')
        cleaned_state_dict = clean_state_dict(state_dict)
        new_state_dict = OrderedDict()
        name: str
        param: torch.FloatTensor
        for name, param in cleaned_state_dict.items():
            if name.startswith('FeatureExtraction'):
                output_name = name.replace('FeatureExtraction.', '')
                if output_name == 'ConvNet.conv0_1.weight':
                    print('Old', param.shape)
                    new_param = param.repeat(1, self.in_chans, 1, 1)
                    print('New', new_param.shape)
                else:
                    new_param = param
                new_state_dict[output_name] = new_param
        print("=> Loading pretrained weight for ResNet backbone")
        self.load_state_dict(new_state_dict)
+
if __name__ == '__main__':
    # Quick shape smoke-test: grayscale 128x480 input through the backbone.
    extractor = ResNet_FeatureExtractor(input_channel=1, debug=True)
    dummy = torch.rand(1, 1, 128, 480)
    features = extractor(dummy)
    print(features.shape)
\ No newline at end of file
diff --git a/HybridViT/module/component/feature_extractor/clova_impl/vgg.py b/HybridViT/module/component/feature_extractor/clova_impl/vgg.py
new file mode 100644
index 0000000000000000000000000000000000000000..beb924bd0d407a0e181fdfb5c9d2231a6dddfa48
--- /dev/null
+++ b/HybridViT/module/component/feature_extractor/clova_impl/vgg.py
@@ -0,0 +1,27 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
class VGG_FeatureExtractor(nn.Module):
    """ FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf) """

    def __init__(self, input_channel, output_channel=512):
        super(VGG_FeatureExtractor, self).__init__()
        # Channel plan: [C/8, C/4, C/2, C] -> [64, 128, 256, 512] for C=512.
        c0, c1, c2, c3 = (int(output_channel / 8), int(output_channel / 4),
                          int(output_channel / 2), output_channel)
        self.output_channel = [c0, c1, c2, c3]
        self.ConvNet = nn.Sequential(
            nn.Conv2d(input_channel, c0, 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d(2, 2),                         # c0 x H/2 x W/2
            nn.Conv2d(c0, c1, 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d(2, 2),                         # c1 x H/4 x W/4
            nn.Conv2d(c1, c2, 3, 1, 1), nn.ReLU(True),  # c2 x H/4 x W/4
            nn.Conv2d(c2, c2, 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d((2, 1), (2, 1)),               # c2 x H/8 x W/4 (width preserved)
            nn.Conv2d(c2, c3, 3, 1, 1, bias=False),
            nn.BatchNorm2d(c3), nn.ReLU(True),          # c3 x H/8 x W/4
            nn.Conv2d(c3, c3, 3, 1, 1, bias=False),
            nn.BatchNorm2d(c3), nn.ReLU(True),
            nn.MaxPool2d((2, 1), (2, 1)),               # c3 x H/16 x W/4
            nn.Conv2d(c3, c3, 2, 1, 0), nn.ReLU(True))  # valid 2x2 conv trims 1 in each dim

    def forward(self, input):
        return self.ConvNet(input)
diff --git a/HybridViT/module/component/feature_extractor/helpers.py b/HybridViT/module/component/feature_extractor/helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..f23da247054997e2d37899ce4133346b68d53933
--- /dev/null
+++ b/HybridViT/module/component/feature_extractor/helpers.py
@@ -0,0 +1,76 @@
+import math
+from typing import List, Tuple
+
+import torch.nn.functional as F
+
+# Calculate symmetric padding for a convolution
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
    """Symmetric padding that keeps output size equal to input size at stride 1."""
    effective_span = (stride - 1) + dilation * (kernel_size - 1)
    return effective_span // 2
+
+
+# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
def get_same_padding(x: int, k: int, s: int, d: int):
    """Total TensorFlow-'SAME' padding needed along one dimension of size ``x``."""
    out_size = math.ceil(x / s)
    needed = (out_size - 1) * s + (k - 1) * d + 1 - x
    return max(needed, 0)
+
+
+# Can SAME padding for given args be done statically?
def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
    """'SAME' padding is input-independent only at stride 1 with an even dilated span."""
    if stride != 1:
        return False
    return (dilation * (kernel_size - 1)) % 2 == 0
+
+
+# Dynamically pad input x with 'SAME' padding for conv with specified args
def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0):
    """Pad a BCHW tensor so a conv with kernel k / stride s / dilation d acts like TF 'SAME'."""
    ih, iw = x.size()[-2:]
    pad_h = get_same_padding(ih, k[0], s[0], d[0])
    pad_w = get_same_padding(iw, k[1], s[1], d[1])
    if pad_h or pad_w:
        # Asymmetric split: the extra pixel goes to the bottom/right.
        x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value)
    return x
+
+
def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
    """Resolve a padding spec ('same' / 'valid' / numeric) to (padding, is_dynamic)."""
    if not isinstance(padding, str):
        # Numeric padding passes straight through.
        return padding, False

    spec = padding.lower()
    if spec == 'same':
        if is_static_pad(kernel_size, **kwargs):
            # TF 'SAME' expressible as a fixed symmetric pad: no overhead.
            return get_padding(kernel_size, **kwargs), False
        # Input-dependent 'SAME': caller must pad dynamically (pad_same),
        # which costs runtime/GPU memory.
        return 0, True
    if spec == 'valid':
        # 'VALID' padding, same as padding=0.
        return 0, False
    # Any other string: default to PyTorch-style 'same'-ish symmetric padding.
    return get_padding(kernel_size, **kwargs), False
+
+
def adapt_input_conv(in_chans, conv_weight):
    """Adapt a pretrained first-conv weight (O, I, J, K) to ``in_chans`` input channels.

    1 channel: sum the RGB filters. Other counts: tile the RGB filters and
    rescale so activation magnitudes stay roughly unchanged.
    """
    conv_type = conv_weight.dtype
    conv_weight = conv_weight.float()  # half-precision weights can't be summed on CPU
    O, I, J, K = conv_weight.shape
    if in_chans == 1:
        if I > 3:
            # Space-to-depth stems store I = 3 * blocks; collapse each RGB triple.
            assert conv_weight.shape[1] % 3 == 0
            conv_weight = conv_weight.reshape(O, I // 3, 3, J, K).sum(dim=2, keepdim=False)
        else:
            conv_weight = conv_weight.sum(dim=1, keepdim=True)
    elif in_chans != 3:
        if I != 3:
            raise NotImplementedError('Weight format not supported by conversion.')
        # NOTE this strategy should be better than random init, but there could be other combinations of
        # the original RGB input layer weights that'd work better for specific cases.
        repeat = int(math.ceil(in_chans / 3))
        conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
        conv_weight *= (3 / float(in_chans))
    return conv_weight.to(conv_type)
\ No newline at end of file
diff --git a/HybridViT/module/component/feature_extractor/vgg.py b/HybridViT/module/component/feature_extractor/vgg.py
new file mode 100644
index 0000000000000000000000000000000000000000..980dac3ad05dc48656e0eaac54606f1be7efd445
--- /dev/null
+++ b/HybridViT/module/component/feature_extractor/vgg.py
@@ -0,0 +1,96 @@
+import torch
+import torch.nn as nn
+
+__all__ = ['vgg11_bn', 'vgg13_bn', 'vgg16_bn', 'vgg19_bn']
+
+
class VGG(nn.Module):
    """VGG trunk followed by a grouped 1x1 projection to ``num_channel_out`` channels."""

    def __init__(self, features, num_channel_out=512, init_weights=True):
        super(VGG, self).__init__()
        self.features = features
        self.num_out_features = 512  # channel count produced by ``features``

        self.lastlayer = nn.Sequential(
            nn.Conv2d(self.num_out_features, num_channel_out, kernel_size=1, stride=1, padding=0, groups=32, bias=False),
            nn.BatchNorm2d(num_channel_out),
            nn.ReLU(inplace=True),
        )

        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        return self.lastlayer(self.features(x))

    def _initialize_weights(self):
        """Kaiming for convs, unit BN, N(0, 0.01) for linears."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)
+
+
def make_layers(cfg, down_sample=8, batch_norm=False):
    """Build a VGG trunk from a config list.

    'M' -> 2x2 max-pool; a dict -> max-pool whose kernel/stride pair is
    looked up under ``down_sample``; an int -> 3x3 conv (+ optional BN) + ReLU.
    """
    layers = []
    in_channels = 3
    for spec in cfg:
        if spec == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        elif isinstance(spec, dict):
            pool_size = spec[down_sample]
            layers.append(nn.MaxPool2d(kernel_size=pool_size, stride=pool_size))
        else:
            layers.append(nn.Conv2d(in_channels, spec, kernel_size=3, padding=1))
            if batch_norm:
                layers.append(nn.BatchNorm2d(spec))
            layers.append(nn.ReLU(inplace=True))
            in_channels = spec
    return nn.Sequential(*layers)
+
+
# VGG layer configurations (torchvision-style): ints are conv output
# channels, 'M' is a fixed 2x2 max-pool, and each dict maps a target
# down_sample factor (4 or 8) to the (kernel, stride) of that pool stage
# — (2, 1) pools halve height while preserving width.
# 'A'/'B'/'D'/'E' correspond to VGG-11/13/16/19.
cfgs = {
    'A': [64, 'M', 128, 'M', 256, 256, {4: (2, 1), 8: (2, 2)}, 512, 512, {4: (2, 1), 8: (2, 1)}, 512, 512,
          {4: (2, 1), 8: (2, 1)}],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, {4: (2, 1), 8: (2, 2)}, 512, 512, {4: (2, 1), 8: (2, 1)}, 512, 512,
          {4: (2, 1), 8: (2, 1)}, ],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, {4: (2, 1), 8: (2, 2)}, 512, 512, 512, {4: (2, 1), 8: (2, 1)}, 512,
          512, 512, {4: (2, 1), 8: (2, 1)}, ],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, {4: (2, 1), 8: (2, 2)}, 512, 512, 512, 512,
          {4: (2, 1), 8: (2, 1)}, 512, 512, 512, 512, {4: (2, 1), 8: (2, 1)}, ],
}
+
+
def _vgg(model_path, cfg, batch_norm, pretrained, progress, num_channel_out, down_sample, **kwargs):
    """Construct a VGG variant and optionally load weights from ``model_path``."""
    if pretrained:
        # Pretrained weights replace the random init entirely.
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfgs[cfg], down_sample, batch_norm=batch_norm), num_channel_out, **kwargs)
    if pretrained and model_path:
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint, strict=False)
    return model
+
+
def vgg11_bn(model_path='', num_channel_out=512, down_sample=8, pretrained=True, progress=True, **kwargs):
    """VGG-11 with batch normalisation (configuration 'A')."""
    return _vgg(model_path, 'A', True, pretrained, progress, num_channel_out, down_sample, **kwargs)


def vgg13_bn(model_path='', num_channel_out=512, down_sample=8, pretrained=True, progress=True, **kwargs):
    """VGG-13 with batch normalisation (configuration 'B')."""
    return _vgg(model_path, 'B', True, pretrained, progress, num_channel_out, down_sample, **kwargs)


def vgg16_bn(model_path='', num_channel_out=512, down_sample=8, pretrained=True, progress=True, **kwargs):
    """VGG-16 with batch normalisation (configuration 'D')."""
    return _vgg(model_path, 'D', True, pretrained, progress, num_channel_out, down_sample, **kwargs)


def vgg19_bn(model_path='', num_channel_out=512, down_sample=8, pretrained=True, progress=True, **kwargs):
    """VGG-19 with batch normalisation (configuration 'E')."""
    return _vgg(model_path, 'E', True, pretrained, progress, num_channel_out, down_sample, **kwargs)
\ No newline at end of file
diff --git a/HybridViT/module/component/prediction_head/__init__.py b/HybridViT/module/component/prediction_head/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..48614426ef60537fb63eaf08cc105c003f79aea1
--- /dev/null
+++ b/HybridViT/module/component/prediction_head/__init__.py
@@ -0,0 +1,5 @@
+from .seq2seq import Attention
+from .seq2seq_v2 import AttentionV2
+from .tfm import TransformerPrediction
+
+__all__ = ['Attention', 'AttentionV2', 'TransformerPrediction']
\ No newline at end of file
diff --git a/HybridViT/module/component/prediction_head/addon_module/__init__.py b/HybridViT/module/component/prediction_head/addon_module/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..37e749bbfd6d410ecac27a1f772961e572ed2ec1
--- /dev/null
+++ b/HybridViT/module/component/prediction_head/addon_module/__init__.py
@@ -0,0 +1,3 @@
+from .attention1D import *
+from .attention2D import *
+from .position_encoding import *
\ No newline at end of file
diff --git a/HybridViT/module/component/prediction_head/addon_module/attention1D.py b/HybridViT/module/component/prediction_head/addon_module/attention1D.py
new file mode 100644
index 0000000000000000000000000000000000000000..df74822445875c572604f0c62060a7e601b5b4a4
--- /dev/null
+++ b/HybridViT/module/component/prediction_head/addon_module/attention1D.py
@@ -0,0 +1,218 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+torch.autograd.set_detect_anomaly(True)
+
class LuongAttention(nn.Module):
    """Luong-style ("update then attend") decoder step.

    One call advances the LSTM cell with the embedded previous token, attends
    over the encoder outputs with the *new* hidden state, and produces class
    logits from tanh([context; hidden]).

    NOTE: `input_size` is accepted but never used; kept for signature parity
    with the other attention heads.
    """

    def __init__(self, input_size, hidden_size, num_embeddings, num_classes, method='dot'):
        super().__init__()
        self.attn = LuongAttentionCell(hidden_size, method)
        self.rnn = nn.LSTMCell(num_embeddings, hidden_size)
        self.hidden_size = hidden_size
        self.generator = nn.Linear(2 * hidden_size, num_classes)

    def forward(self, prev_hidden, batch_H, embed_text):
        """Return (logits, new_hidden, alpha) for one decoding step.

        prev_hidden: (h, c) LSTM state tuple.
        batch_H:     (B, S, hidden_size) encoder outputs.
        embed_text:  (B, num_embeddings) previous-token representation.
        """
        new_hidden = self.rnn(embed_text, prev_hidden)
        h_t = new_hidden[0]

        scores = self.attn(h_t, batch_H)          # (B, S) unnormalized
        weights = F.softmax(scores, dim=1)        # normalize over sequence

        # (B, 1, S) x (B, S, H) -> (B, H) attended context
        ctx = torch.bmm(weights.unsqueeze(1), batch_H).squeeze(1)
        logits = self.generator(torch.tanh(torch.cat([ctx, h_t], 1)))

        return logits, new_hidden, weights
+
class LuongAttentionCell(nn.Module):
    """Alignment scoring for Luong attention.

    Supported `method` values:
      * 'dot'     -- score = enc . h           (no parameters)
      * 'general' -- score = enc . (W h)
      * 'concat'  -- score = v . tanh(W (h + enc))
        NOTE(review): despite the name, this variant *adds* the decoder state to
        the encoder outputs instead of concatenating; kept as-is so parameter
        shapes (and existing checkpoints) stay compatible.
    """

    def __init__(self, hidden_size, method='dot'):
        super(LuongAttentionCell, self).__init__()
        self.method = method
        self.hidden_size = hidden_size

        # Defining the layers/weights required depending on alignment scoring method
        if method == "general":
            self.fc = nn.Linear(hidden_size, hidden_size, bias=False)

        elif method == "concat":
            self.fc = nn.Linear(hidden_size, hidden_size, bias=False)
            # Fix: the original used `torch.FloatTensor(1, hidden_size)`, which is
            # uninitialized memory (may contain NaN/Inf and poison the forward).
            # Allocate explicitly and initialize with an nn.Linear-style fan-based
            # uniform range.
            self.weight = nn.Parameter(torch.empty(1, hidden_size))
            bound = hidden_size ** -0.5
            nn.init.uniform_(self.weight, -bound, bound)

    def forward(self, decoder_hidden, encoder_outputs):
        """Return unnormalized alignment scores of shape (B, S).

        decoder_hidden:  (B, hidden_size) decoder state.
        encoder_outputs: (B, S, hidden_size) encoder features.
        """
        decoder_hidden = decoder_hidden.unsqueeze(1)  # (B, 1, H)

        if self.method == "dot":
            # For the dot scoring method, no weights or linear layers are involved
            return encoder_outputs.bmm(decoder_hidden.permute(0, 2, 1)).squeeze(-1)

        elif self.method == "general":
            # Decoder state is passed through a weight matrix before the dot product
            out = self.fc(decoder_hidden)
            return encoder_outputs.bmm(out.permute(0, 2, 1)).squeeze(-1)

        elif self.method == "concat":
            # (B, 1, H) broadcasts against (B, S, H); then score against v
            out = torch.tanh(self.fc(decoder_hidden + encoder_outputs))
            return out.bmm(self.weight.unsqueeze(-1).repeat(out.shape[0], 1, 1)).squeeze(-1)
+
class BahdanauAttentionCell(nn.Module):
    """Additive (Bahdanau) alignment scorer: score = v^T tanh(W_e enc + W_h h)."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.i2h = nn.Linear(input_dim, hidden_dim, bias=False)   # encoder projection
        self.h2h = nn.Linear(hidden_dim, hidden_dim)              # decoder-state projection
        self.score = nn.Linear(hidden_dim, 1, bias=False)         # scoring vector v

    def forward(self, decoder_hidden, encoder_output):
        """Return (B, S, 1) unnormalized scores.

        decoder_hidden: (h, c) LSTM tuple; only h takes part in scoring.
        encoder_output: (B, S, input_dim) encoder features.
        """
        query = self.h2h(decoder_hidden[0]).unsqueeze(1)   # (B, 1, hidden)
        keys = self.i2h(encoder_output)                    # (B, S, hidden)
        return self.score(torch.tanh(keys + query))
+
class BahdanauAttention(nn.Module):
    """Additive-attention decoder step ("attend then update").

    Attends over the encoder outputs with the *previous* hidden state, then
    advances the LSTM cell on [context; embedded token].
    forward() returns (logits, new_hidden, alpha) with alpha of shape (B, S, 1).
    """

    def __init__(self, input_size=100, hidden_size=256, num_embeddings=10, num_classes=10):
        super(BahdanauAttention, self).__init__()
        self.attn = BahdanauAttentionCell(input_size, hidden_size)
        # One step consumes [context (input_size); embedded token (num_embeddings)].
        self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.generator = nn.Linear(hidden_size, num_classes)

    def set_mem(self, prev_attn):
        # No-op: interface parity with the coverage/location-aware variants.
        pass

    def reset_mem(self):
        # No-op: interface parity with the coverage/location-aware variants.
        pass

    def forward(self, prev_hidden, batch_H, embed_text):
        # batch_H: [batch_size x num_encoder_step x num_channel]; prev_hidden is
        # the (h, c) LSTM tuple; embed_text: (B, num_embeddings).
        e = self.attn(prev_hidden, batch_H)   # (B, S, 1) unnormalized scores
        alpha = F.softmax(e, dim=1)           # normalize over the sequence dim
        context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1)  # batch_size x num_channel
        concat_context = torch.cat([context, embed_text], 1)  # batch_size x (num_channel + num_embedding)
        cur_hidden = self.rnn(concat_context, prev_hidden)
        output = self.generator(cur_hidden[0])

        return output, cur_hidden, alpha
+
class ConstituentAttentionCell(nn.Module):
    """Placeholder for a constituent-attention scorer (not implemented yet).

    Forwards constructor args to nn.Module unchanged; defines no layers/forward.
    """
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
+
+
class ConstituentCoverageAttentionCell(ConstituentAttentionCell):
    # Placeholder: coverage flavour of the (unimplemented) constituent attention.
    pass
+
class LocationAwareAttentionCell(nn.Module):
    """Additive attention scorer with a location term.

    The previous alignment is convolved along the sequence axis, projected to
    hidden_dim, and added into the usual additive score:
        score = W_s tanh(W_k enc + W_q h + W_l conv(last_alignment))
    """

    def __init__(self, kernel_size, kernel_dim, hidden_dim, input_dim):
        super().__init__()
        self.loc_conv = nn.Conv1d(1, kernel_dim, kernel_size=2 * kernel_size + 1, padding=kernel_size, bias=True)
        self.loc_proj = nn.Linear(kernel_dim, hidden_dim)
        self.query_proj = nn.Linear(hidden_dim, hidden_dim)
        self.key_proj = nn.Linear(input_dim, hidden_dim)
        self.score = nn.Linear(hidden_dim, 1)

    def forward(self, decoder_hidden, encoder_output, last_alignment):
        """Return (B, S, 1) scores; last_alignment is (B, S, 1) or None."""
        n_batch = encoder_output.shape[0]
        n_step = encoder_output.shape[1]
        n_hidden = decoder_hidden[0].shape[1]

        if last_alignment is None:
            # First decoding step: start from an all-zero alignment history.
            last_alignment = decoder_hidden[0].new_zeros(n_batch, n_step, 1)

        keys = self.key_proj(encoder_output)
        query = self.query_proj(decoder_hidden[0]).unsqueeze(1)

        conv_feat = self.loc_conv(last_alignment.permute(0, 2, 1))  # (B, kernel_dim, S)
        conv_feat = conv_feat.transpose(1, 2)                       # (B, S, kernel_dim)
        conv_feat = self.loc_proj(conv_feat)                        # (B, S, hidden)

        assert len(conv_feat.shape) == 3
        assert conv_feat.shape[0] == n_batch, f'{conv_feat.shape[0]}-{n_batch}'
        assert conv_feat.shape[1] == n_step
        assert conv_feat.shape[2] == n_hidden

        return self.score(torch.tanh(keys + query + conv_feat))
+
class CoverageAttention(nn.Module):
    """Location-aware attention conditioned on externally accumulated alignment.

    The caller accumulates alpha across steps and feeds it back via set_mem();
    forward() returns only (context, alpha) -- no decoder state is advanced here.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, kernel_dim, temperature=1.0, smoothing=False):
        super().__init__()
        self.smoothing = smoothing
        self.temperature = temperature
        self.prev_attn = None  # cumulative alignment fed to the location term
        self.attn = LocationAwareAttentionCell(kernel_size, kernel_dim, hidden_dim, input_dim)

    def set_mem(self, prev_attn):
        """Store the (cumulative) alignment used as the location input next step."""
        self.prev_attn = prev_attn

    def reset_mem(self):
        """Clear the alignment history (call before decoding a new sequence)."""
        self.prev_attn = None

    def forward(self, prev_hidden, batch_H):
        e = self.attn(prev_hidden, batch_H, self.prev_attn)  # (B, S, 1)

        if self.smoothing:
            # Fix: `F.sigmoid(e, dim=1)` raised a TypeError (sigmoid takes no `dim`),
            # and the original normalized over the trailing singleton dim, which made
            # every weight exactly 1. Normalize sigmoid scores over the sequence dim
            # (dim=1, matching the softmax branch).
            e = torch.sigmoid(e)
            alpha = e / e.sum(dim=1, keepdim=True)
        else:
            e = e / self.temperature
            alpha = F.softmax(e, dim=1)

        context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1)  # batch_size x num_channel

        return context, alpha
+
class LocationAwareAttention(BahdanauAttention):
    """Bahdanau decoder step whose scorer also sees the previous alignment.

    Extra args vs BahdanauAttention:
        kernel_size / kernel_dim: location-convolution geometry.
        temperature: scores are divided by this before the softmax.
        smoothing: use a length-normalized sigmoid instead of softmax.
    """

    def __init__(self, kernel_size, kernel_dim, temperature=1.0, smoothing=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.smoothing = smoothing
        self.temperature = temperature
        self.prev_attn = None  # alignment (or cumulative alignment) from last step
        # Replace the parent's additive cell with a location-aware one.
        self.attn = LocationAwareAttentionCell(kernel_size, kernel_dim, self.hidden_size, self.input_size)

    def set_mem(self, prev_attn):
        """Store the alignment used as the location input for the next step."""
        self.prev_attn = prev_attn

    def reset_mem(self):
        """Clear the alignment history (call before decoding a new sequence)."""
        self.prev_attn = None

    def forward(self, prev_hidden, batch_H, embed_text):
        e = self.attn(prev_hidden, batch_H, self.prev_attn)  # (B, S, 1)

        if self.smoothing:
            # Fix: `F.sigmoid(e, dim=1)` raised a TypeError (sigmoid takes no `dim`),
            # and summing over the trailing singleton dim normalized nothing.
            # Normalize sigmoid scores over the sequence dim like the softmax branch.
            e = torch.sigmoid(e)
            alpha = e / e.sum(dim=1, keepdim=True)
        else:
            e = e / self.temperature
            alpha = F.softmax(e, dim=1)

        # batch_H: (B, S, C); alpha: (B, S, 1) -> context: (B, C)
        context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1)
        concat_context = torch.cat([context, embed_text], 1)  # batch_size x (num_channel + num_embedding)
        cur_hidden = self.rnn(concat_context, prev_hidden)
        output = self.generator(cur_hidden[0])

        return output, cur_hidden, alpha
+
+
+# class MaskAttention(nn.Module):
+# def __init__(self):
+# super().__init__()
+
+# class CoverageAttentionCell(nn.Module):
+# def __init__(self, )
+
+# class CoverageAttention(nn.Module):
+# """
+# http://home.ustc.edu.cn/~xysszjs/paper/PR2017.pdf
+# """
+# def __init__(self, input_size, hidden_size, num_embedding):
+# super().__init__()
+
+# def forward(self, prev_hidden, batch_H, embed_text):
diff --git a/HybridViT/module/component/prediction_head/addon_module/attention2D.py b/HybridViT/module/component/prediction_head/addon_module/attention2D.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f5c4c005eb1bde8cee73516d61ff5a87bcdbbce
--- /dev/null
+++ b/HybridViT/module/component/prediction_head/addon_module/attention2D.py
@@ -0,0 +1,88 @@
+import torch
+import torch.nn as nn
+from .attention1D import LocationAwareAttention
+"""
+NOTE :
+"""
class SARAttention(nn.Module):
    """2D attention decoder step in the style of SAR (Show, Attend and Read).

    NOTE(review): forward() is unfinished upstream (masking / context
    aggregation after `conv1x1_2` was still TODO), so it currently returns
    None and the computed state is discarded. `holistic_feature` is accepted
    but not yet used. Also, `conv1x1_1` expects `output_size` input channels
    while it is fed a decoder state of `input_size` channels -- the two must
    be equal for forward() to run; confirm the intended sizes.
    """

    def __init__(self,
                 input_size,
                 attention_size,
                 backbone_size,
                 output_size,
                 ):
        # Fix: the original never called nn.Module.__init__, so the very first
        # sub-module assignment below raised
        # "cannot assign module before Module.__init__() call".
        super().__init__()
        self.conv1x1_1 = nn.Conv2d(output_size, attention_size, kernel_size=1, stride=1)
        self.conv3x3 = nn.Conv2d(backbone_size, attention_size, kernel_size=3, stride=1, padding=1)
        self.conv1x1_2 = nn.Conv2d(attention_size, 1, kernel_size=1, stride=1)

        # Two stacked LSTM cells form the recurrent decoder.
        self.rnn_decoder_1 = nn.LSTMCell(input_size, input_size)
        self.rnn_decoder_2 = nn.LSTMCell(input_size, input_size)

    def forward(
        self,
        dec_input,
        feature_map,
        holistic_feature,
        hidden_1,
        cell_1,
        hidden_2,
        cell_2
    ):
        _, _, H_feat, W_feat = feature_map.size()
        hidden_1, cell_1 = self.rnn_decoder_1(dec_input, (hidden_1, cell_1))
        hidden_2, cell_2 = self.rnn_decoder_2(hidden_1, (hidden_2, cell_2))

        # Broadcast the decoder state over the spatial grid as the attention query.
        hidden_2_tile = hidden_2.view(hidden_2.size(0), hidden_2.size(1), 1, 1)
        attn_query = self.conv1x1_1(hidden_2_tile)
        attn_query = attn_query.expand(-1, -1, H_feat, W_feat)

        attn_key = self.conv3x3(feature_map)
        attn_weight = torch.tanh(torch.add(attn_query, attn_key, alpha=1))
        attn_weight = self.conv1x1_2(attn_weight)  # shape B, 1, H, W

        # TODO: apply mask for attention weight
+
+
class LocationAwareAttentionCell2D(nn.Module):
    """2D variant of the location-aware scorer.

    Inputs:
        decoder_hidden: (h, c) tuple; h is (B, hidden_dim).
        encoder_output: (B, H, W, input_dim) feature map.
        last_alignment: (B, H, W, 1) previous/cumulative alignment, or None.
    Returns:
        (B, H*W, 1) unnormalized scores over the flattened spatial grid.
    """

    def __init__(self, kernel_size, kernel_dim, hidden_dim, input_dim):
        super().__init__()
        self.loc_conv = nn.Conv2d(1, kernel_dim, kernel_size=2 * kernel_size + 1, padding=kernel_size, bias=True)
        self.loc_proj = nn.Linear(kernel_dim, hidden_dim)
        self.query_proj = nn.Linear(hidden_dim, hidden_dim)
        self.key_proj = nn.Linear(input_dim, hidden_dim)
        self.score = nn.Linear(hidden_dim, 1)

    def forward(self, decoder_hidden, encoder_output, last_alignment):
        batch_size, enc_h, enc_w, hidden_dim = encoder_output.shape[0], encoder_output.shape[1], encoder_output.shape[2], decoder_hidden[0].shape[1]

        # Keys over the flattened grid. Fix: the original kept encoder_proj 4D,
        # which could not be added to the (B, H*W, hidden) location term below.
        encoder_proj = self.key_proj(encoder_output).reshape(batch_size, enc_h * enc_w, -1)
        hidden_proj = self.query_proj(decoder_hidden[0]).unsqueeze(1)  # (B, 1, hidden)

        if last_alignment is None:
            # First decoding step: start from an all-zero alignment history.
            last_alignment = decoder_hidden[0].new_zeros(batch_size, enc_h, enc_w, 1)

        # Fix: the original applied a 3-index permute to this 4D tensor, which is
        # a RuntimeError. Move the channel dim for Conv2d and back afterwards.
        loc_context = self.loc_conv(last_alignment.permute(0, 3, 1, 2))  # (B, kernel_dim, H, W)
        loc_context = loc_context.permute(0, 2, 3, 1)                    # (B, H, W, kernel_dim)
        loc_context = self.loc_proj(loc_context)                         # (B, H, W, hidden)

        # Fix: the original asserted len(shape) == 3 while also indexing shape[3].
        assert len(loc_context.shape) == 4
        assert loc_context.shape[0] == batch_size, f'{loc_context.shape[0]}-{batch_size}'
        assert loc_context.shape[1] == enc_h
        assert loc_context.shape[2] == enc_w
        assert loc_context.shape[3] == hidden_dim

        loc_context = loc_context.reshape(batch_size, enc_h * enc_w, hidden_dim)

        score = self.score(torch.tanh(
            encoder_proj
            + hidden_proj
            + loc_context
        ))
        return score
+
class LocationAwareAttention2D(LocationAwareAttention):
    """LocationAwareAttention with the 1D scorer swapped for the Conv2d-based one.

    NOTE(review): the inherited forward() treats `batch_H` as (B, S, C) for the
    context bmm, while the 2D cell indexes it as (B, H, W, C) -- confirm the
    expected layout with the caller before use.
    """
    def __init__(self, kernel_size, kernel_dim, temperature=1, smoothing=False, *args, **kwargs):
        super().__init__(kernel_size, kernel_dim, temperature, smoothing, *args, **kwargs)
        # Replace the 1D cell built by the parent with the 2D variant.
        self.attn = LocationAwareAttentionCell2D(kernel_size, kernel_dim, self.hidden_size, self.input_size)
diff --git a/HybridViT/module/component/prediction_head/addon_module/position_encoding.py b/HybridViT/module/component/prediction_head/addon_module/position_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..cba525cc89544df367a1a784ffd94d9f941a068e
--- /dev/null
+++ b/HybridViT/module/component/prediction_head/addon_module/position_encoding.py
@@ -0,0 +1,27 @@
+import torch
+from torch import nn
+
+__all__ = ['WordPosEnc']
+
class WordPosEnc(nn.Module):
    """Fixed sinusoidal positional encoding added to word embeddings.

    A (max_len, d_model) table is precomputed once: even channels hold
    sin(pos * w_k) and odd channels cos(pos * w_k), with
    w_k = temperature^(-2k / d_model) (Vaswani et al., "Attention Is All You Need").
    """

    def __init__(
        self, d_model: int = 512, max_len: int = 500, temperature: float = 10000.0
    ) -> None:
        super().__init__()
        positions = torch.arange(0, max_len, dtype=torch.float)
        even_dims = torch.arange(0, d_model, 2, dtype=torch.float)
        freqs = 1.0 / (temperature ** (even_dims / d_model))

        # Outer product: angles[i, k] = position_i * freq_k
        angles = positions[:, None] * freqs[None, :]

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = angles.sin()
        table[:, 1::2] = angles.cos()
        # Buffer (not a Parameter): moves with the module but is never trained.
        self.register_buffer("pe", table)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional codes for the first seq_len positions to x (B, S, D)."""
        seq_len = x.size(1)
        return x + self.pe[:seq_len, :].unsqueeze(0)
diff --git a/HybridViT/module/component/prediction_head/seq2seq.py b/HybridViT/module/component/prediction_head/seq2seq.py
new file mode 100644
index 0000000000000000000000000000000000000000..0920f5035bc50914b3414a6c78302be2b83fb1cd
--- /dev/null
+++ b/HybridViT/module/component/prediction_head/seq2seq.py
@@ -0,0 +1,268 @@
+import random
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import repeat
+from ...converter import AttnLabelConverter as ATTN
+from .addon_module import *
+
class Attention(nn.Module):
    """Seq2seq prediction head: attentional LSTM decoder over encoder features.

    Training decodes greedily (optionally teacher-forced); inference decodes
    greedily or with beam search depending on `beam_size` (see forward()).
    """

    def __init__(self,
                 kernel_size,
                 kernel_dim,
                 input_size,
                 hidden_size,
                 num_classes,
                 embed_dim=None,
                 attn_type='coverage',
                 embed_target=False,
                 enc_init=False,  # init hidden state of decoder with enc output
                 teacher_forcing=1.0,
                 droprate=0.1,
                 method='concat',
                 seqmodel='ViT',
                 viz_attn: bool = False,
                 device='cuda'
                 ):
        super(Attention, self).__init__()
        if embed_dim is None: embed_dim = input_size
        if embed_target:
            # Learned target-token embedding; otherwise one-hot vectors are used.
            # NOTE(review): padding_idx is the START token id -- that embedding row
            # is pinned to zeros; confirm this is intended.
            self.embedding = nn.Embedding(num_classes, embed_dim, padding_idx=ATTN.START())

        common = {
            'input_size': input_size,
            'hidden_size': hidden_size,
            'num_embeddings': embed_dim if embed_target else num_classes,
            'num_classes': num_classes
        }

        # 'loc_aware' and 'coverage' build the same cell; they differ only in what
        # is fed back through set_mem() during decoding (last vs cumulative alpha).
        if attn_type == 'luong':
            common['method'] = method
            self.attention_cell = LuongAttention(**common)
        elif attn_type == 'loc_aware':
            self.attention_cell = LocationAwareAttention(kernel_size=kernel_size, kernel_dim=kernel_dim, **common)
        elif attn_type == 'coverage':
            self.attention_cell = LocationAwareAttention(kernel_size=kernel_size, kernel_dim=kernel_dim, **common)
        else:
            self.attention_cell = BahdanauAttention(**common)

        self.dropout = nn.Dropout(droprate)
        self.embed_target = embed_target
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.num_classes = num_classes
        self.teacher_forcing = teacher_forcing
        self.device = device
        self.attn_type = attn_type
        self.enc_init = enc_init
        self.viz_attn = viz_attn
        self.seqmodel = seqmodel

        if enc_init: self.init_hidden()

    def _embed_text(self, input_char):
        """Embed target-token ids (used only when embed_target=True)."""
        return self.embedding(input_char)

    def _char_to_onehot(self, input_char, onehot_dim=38):
        """One-hot encode a (B,) tensor of token ids into (B, onehot_dim)."""
        input_char = input_char.unsqueeze(1)
        batch_size = input_char.size(0)
        one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().to(self.device)
        one_hot = one_hot.scatter_(1, input_char, 1)
        return one_hot

    def init_hidden(self):
        # Projections that map an encoder summary to the decoder's initial (h, c).
        # NOTE: creates submodules lazily; only called from __init__ when enc_init.
        self.proj_init_h = nn.Linear(self.input_size, self.hidden_size, bias=True)
        self.proj_init_c = nn.Linear(self.input_size, self.hidden_size, bias=True)

    def forward_beam(
        self,
        batch_H: torch.Tensor,
        batch_max_length=25,
        beam_size=4,
    ):
        """Beam-search decode a single sample (batch size must be 1).

        Returns (seq, score, alphas-or-None); seq excludes the [GO] token.
        """
        batch_size = batch_H.size(0)
        assert batch_size == 1
        num_steps = batch_max_length + 1
        batch_H = batch_H.squeeze(dim=0)
        # Replicate encoder features once per live beam.
        batch_H = repeat(batch_H, "s e -> b s e", b = beam_size)

        if self.enc_init:
            # Seed (h, c) from an encoder summary: mean for BiLSTM, first token otherwise.
            if self.seqmodel == 'BiLSTM':
                init_embedding = batch_H.mean(dim=1)
            else:
                init_embedding = batch_H[:, 0, :]
            h_0 = self.proj_init_h(init_embedding)
            c_0 = self.proj_init_c(init_embedding)
            hidden = (h_0, c_0)
        else:
            hidden = (torch.zeros(beam_size, self.hidden_size, dtype=torch.float32, device=self.device),
                      torch.zeros(beam_size, self.hidden_size, dtype=torch.float32, device=self.device))

        if self.attn_type == 'coverage':
            alpha_cum = torch.zeros(beam_size, batch_H.shape[1], 1, dtype=torch.float32, device=self.device)
        self.attention_cell.reset_mem()

        k_prev_words = torch.LongTensor([[ATTN.START()]] * beam_size).to(self.device)
        seqs = k_prev_words
        targets = k_prev_words.squeeze(dim=-1)
        top_k_scores = torch.zeros(beam_size, 1).to(self.device)

        if self.viz_attn:
            seqs_alpha = torch.ones(beam_size, 1, batch_H.shape[1]).to(self.device)

        complete_seqs = list()
        if self.viz_attn:
            complete_seqs_alpha = list()
        complete_seqs_scores = list()

        for step in range(num_steps):
            embed_text = self._char_to_onehot(targets, onehot_dim=self.num_classes) if not self.embed_target else self._embed_text(targets)
            output, hidden, alpha = self.attention_cell(hidden, batch_H, embed_text)
            output = self.dropout(output)
            vocab_size = output.shape[1]

            # Accumulate log-probs; at step 0 all beams are identical, so only row 0 is used.
            scores = F.log_softmax(output, dim=-1)
            scores = top_k_scores.expand_as(scores) + scores
            if step == 0:
                top_k_scores, top_k_words = scores[0].topk(beam_size, 0, True, True)
            else:
                top_k_scores, top_k_words = scores.view(-1).topk(beam_size, 0, True, True)

            # Flat indices -> (source beam, emitted word).
            prev_word_inds = top_k_words // vocab_size
            next_word_inds = top_k_words % vocab_size

            seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)
            if self.viz_attn:
                seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].permute(0, 2, 1)],
                                       dim=1)

            incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if
                               next_word != ATTN.END()]

            complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))

            if len(complete_inds) > 0:
                complete_seqs.extend(seqs[complete_inds].tolist())
                if self.viz_attn:
                    complete_seqs_alpha.extend(seqs_alpha[complete_inds])
                complete_seqs_scores.extend(top_k_scores[complete_inds])

            # Finished beams leave the search.
            beam_size = beam_size - len(complete_inds)
            if beam_size == 0:
                break

            seqs = seqs[incomplete_inds]
            if self.viz_attn:
                seqs_alpha = seqs_alpha[incomplete_inds]
            hidden = hidden[0][prev_word_inds[incomplete_inds]], \
                     hidden[1][prev_word_inds[incomplete_inds]]
            batch_H = batch_H[prev_word_inds[incomplete_inds]]
            top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
            targets = next_word_inds[incomplete_inds]

            if self.attn_type == 'coverage':
                # NOTE(review): hidden/batch_H above are re-gathered with
                # prev_word_inds[incomplete_inds] (old-beam order), but alpha_cum /
                # alpha are indexed in candidate order here -- rows may be
                # misaligned once beams reorder; confirm.
                alpha_cum = alpha_cum + alpha
                alpha_cum = alpha_cum[incomplete_inds]
                self.attention_cell.set_mem(alpha_cum)
            elif self.attn_type == 'loc_aware':
                self.attention_cell.set_mem(alpha)

        if len(complete_inds) == 0:
            # No beam emitted END within num_steps: fall back to the best live beam.
            seq = seqs[0][1:].tolist()
            seq = torch.LongTensor(seq).unsqueeze(0)
            score = top_k_scores[0]
            if self.viz_attn:
                alphas = seqs_alpha[0][1:, ...]
                return seq, score, alphas
            else:
                return seq, score, None
        else:
            # Pick the completed hypothesis with the best length-normalized score.
            combine_lst = tuple(zip(complete_seqs, complete_seqs_scores))
            best_ind = combine_lst.index(max(combine_lst, key=lambda x: x[1] / len(x[0]))) #https://youtu.be/XXtpJxZBa2c?t=2407
            seq = complete_seqs[best_ind][1:] #not include [GO] token
            seq = torch.LongTensor(seq).unsqueeze(0)
            score = max(complete_seqs_scores)

            if self.viz_attn:
                alphas = complete_seqs_alpha[best_ind][1:, ...]
                return seq, score, alphas
            else:
                return seq, score, None

    def forward_greedy(self, batch_H, text, is_train=True, is_test=False, batch_max_length=25):
        """Greedy (or teacher-forced) decode.

        text: (B, T) gold token ids, consumed only when teacher forcing fires.
        Returns (preds_index, probs, None) with probs (B, num_steps, num_classes).
        """
        batch_size = batch_H.size(0)
        num_steps = batch_max_length + 1
        if self.enc_init:
            if self.seqmodel == 'BiLSTM':
                init_embedding = batch_H.mean(dim=1)
                encoder_hidden = batch_H
            else:
                encoder_hidden = batch_H
                init_embedding = batch_H[:, 0, :]
            h_0 = self.proj_init_h(init_embedding)
            c_0 = self.proj_init_c(init_embedding)
            hidden = (h_0, c_0)
        else:
            encoder_hidden = batch_H
            hidden = (torch.zeros(batch_size, self.hidden_size, dtype=torch.float32, device=self.device),
                      torch.zeros(batch_size, self.hidden_size, dtype=torch.float32, device=self.device))

        # NOTE(review): assumes the [GO] token id is 0 -- confirm ATTN.START() == 0.
        targets = torch.zeros(batch_size, dtype=torch.long, device=self.device)  # [GO] token
        probs = torch.zeros(batch_size, num_steps, self.num_classes, dtype=torch.float32, device=self.device)

        if self.viz_attn:
            self.alpha_stores = torch.zeros(batch_size, num_steps, encoder_hidden.shape[1], 1, dtype=torch.float32, device=self.device)
        if self.attn_type == 'coverage':
            alpha_cum = torch.zeros(batch_size, encoder_hidden.shape[1], 1, dtype=torch.float32, device=self.device)

        self.attention_cell.reset_mem()

        if is_test:
            end_flag = torch.zeros(batch_size, dtype=torch.bool, device=self.device)

        for i in range(num_steps):
            embed_text = self._char_to_onehot(targets, onehot_dim=self.num_classes) if not self.embed_target else self._embed_text(targets)
            output, hidden, alpha = self.attention_cell(hidden, encoder_hidden, embed_text)
            output = self.dropout(output)
            if self.viz_attn:
                self.alpha_stores[:, i] = alpha
            if self.attn_type == 'coverage':
                alpha_cum = alpha_cum + alpha
                self.attention_cell.set_mem(alpha_cum)
            elif self.attn_type == 'loc_aware':
                self.attention_cell.set_mem(alpha)

            probs_step = output
            probs[:, i, :] = probs_step

            if i == num_steps - 1:
                break

            if is_train:
                # teacher_forcing is the probability of feeding the gold token.
                if self.teacher_forcing < random.random():
                    _, next_input = probs_step.max(1)
                    targets = next_input
                else:
                    targets = text[:, i+1]
            else:
                _, next_input = probs_step.max(1)
                targets = next_input

            if is_test:
                # NOTE(review): if is_train and is_test are both True and teacher
                # forcing fired, `next_input` may be unbound here -- confirm callers
                # never pass both flags.
                end_flag = end_flag | (next_input == ATTN.END())
                if end_flag.all():
                    break

        _, preds_index = probs.max(2)

        return preds_index, probs, None # batch_size x num_steps x num_classes

    def forward(self, beam_size, batch_H, text, batch_max_length, is_train=True, is_test=False):
        """Dispatch: greedy while training; beam search at inference if beam_size > 1."""
        if is_train:
            return self.forward_greedy(batch_H, text, is_train, is_test, batch_max_length)
        else:
            if beam_size > 1:
                return self.forward_beam(batch_H, batch_max_length, beam_size)
            else:
                return self.forward_greedy(batch_H, text, is_train, is_test, batch_max_length)
+
diff --git a/HybridViT/module/component/prediction_head/seq2seq_v2.py b/HybridViT/module/component/prediction_head/seq2seq_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..20eec35c9de355720de2a5a1848f1a858b8485fd
--- /dev/null
+++ b/HybridViT/module/component/prediction_head/seq2seq_v2.py
@@ -0,0 +1,218 @@
+import random
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import repeat
+from ...converter import AttnLabelConverter as ATTN
+from .addon_module import *
+from .seq2seq import Attention
+
+
class AttentionV2(Attention):
    """Attention head variant that is explicit about the sequence-encoder type.

    Difference from Attention: for seqmodel 'TFM' the first token of batch_H is
    treated as a [CLS]-style summary -- it seeds the decoder state (when
    enc_init) and is dropped from the attended sequence. 'BiLSTM'/'VIG' attend
    over the full sequence and seed from the mean. Any other seqmodel raises
    ValueError.
    """

    def forward_beam(
        self,
        batch_H: torch.Tensor,
        batch_max_length=25,
        beam_size=4,
    ):
        """Beam-search decode a single sample (batch size must be 1).

        Returns (seq, score, alphas-or-None); seq excludes the [GO] token.
        """
        batch_size = batch_H.size(0)
        assert batch_size == 1
        num_steps = batch_max_length + 1
        batch_H = batch_H.squeeze(dim=0)
        # Replicate encoder features once per live beam.
        batch_H = repeat(batch_H, "s e -> b s e", b = beam_size)

        encoder_hidden = None
        if self.seqmodel in ['BiLSTM', 'VIG']:
            encoder_hidden = batch_H
        elif self.seqmodel == 'TFM':
            # Drop the summary token; attend only over the patch/step tokens.
            encoder_hidden = batch_H[:, 1:, :]
        else:
            raise ValueError('seqmodel must be either BiLSTM or TFM option')

        if self.enc_init:
            init_embedding = None
            if self.seqmodel in ['BiLSTM', 'VIG']:
                init_embedding = batch_H.mean(dim=1)
            elif self.seqmodel == 'TFM':
                init_embedding = batch_H[:, 0, :]
            else:
                raise ValueError('seqmodel must be either BiLSTM or TFM option')

            assert init_embedding is not None
            h_0 = self.proj_init_h(init_embedding)
            c_0 = self.proj_init_c(init_embedding)
            hidden = (h_0, c_0)
        else:
            hidden = (torch.zeros(beam_size, self.hidden_size, dtype=torch.float32, device=self.device),
                      torch.zeros(beam_size, self.hidden_size, dtype=torch.float32, device=self.device))

        assert encoder_hidden is not None

        if self.attn_type == 'coverage':
            alpha_cum = torch.zeros(beam_size, encoder_hidden.shape[1], 1, dtype=torch.float32, device=self.device)
        self.attention_cell.reset_mem()

        k_prev_words = torch.LongTensor([[ATTN.START()]] * beam_size).to(self.device)
        seqs = k_prev_words
        targets = k_prev_words.squeeze(dim=-1)
        top_k_scores = torch.zeros(beam_size, 1).to(self.device)

        if self.viz_attn:
            seqs_alpha = torch.ones(beam_size, 1, encoder_hidden.shape[1]).to(self.device)

        complete_seqs = list()
        if self.viz_attn:
            complete_seqs_alpha = list()
        complete_seqs_scores = list()

        for step in range(num_steps):
            embed_text = self._char_to_onehot(targets, onehot_dim=self.num_classes) if not self.embed_target else self._embed_text(targets)
            output, hidden, alpha = self.attention_cell(hidden, encoder_hidden, embed_text)
            output = self.dropout(output)
            vocab_size = output.shape[1]

            # Accumulate log-probs; at step 0 all beams are identical, so only row 0 is used.
            scores = F.log_softmax(output, dim=-1)
            scores = top_k_scores.expand_as(scores) + scores
            if step == 0:
                top_k_scores, top_k_words = scores[0].topk(beam_size, 0, True, True)
            else:
                top_k_scores, top_k_words = scores.view(-1).topk(beam_size, 0, True, True)

            # Flat indices -> (source beam, emitted word).
            prev_word_inds = top_k_words // vocab_size
            next_word_inds = top_k_words % vocab_size

            seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)
            if self.viz_attn:
                seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].permute(0, 2, 1)],
                                       dim=1)

            incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if
                               next_word != ATTN.END()]

            complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))

            if len(complete_inds) > 0:
                complete_seqs.extend(seqs[complete_inds].tolist())
                if self.viz_attn:
                    complete_seqs_alpha.extend(seqs_alpha[complete_inds])
                complete_seqs_scores.extend(top_k_scores[complete_inds])

            # Finished beams leave the search.
            beam_size = beam_size - len(complete_inds)
            if beam_size == 0:
                break

            seqs = seqs[incomplete_inds]
            if self.viz_attn:
                seqs_alpha = seqs_alpha[incomplete_inds]
            hidden = hidden[0][prev_word_inds[incomplete_inds]], \
                     hidden[1][prev_word_inds[incomplete_inds]]
            encoder_hidden = encoder_hidden[prev_word_inds[incomplete_inds]]
            top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
            targets = next_word_inds[incomplete_inds]

            if self.attn_type == 'coverage':
                # NOTE(review): hidden/encoder_hidden above are re-gathered with
                # prev_word_inds[incomplete_inds], but alpha_cum / alpha are indexed
                # in candidate order here -- rows may be misaligned once beams
                # reorder; confirm.
                alpha_cum = alpha_cum + alpha
                alpha_cum = alpha_cum[incomplete_inds]
                self.attention_cell.set_mem(alpha_cum)
            elif self.attn_type == 'loc_aware':
                self.attention_cell.set_mem(alpha)

        if len(complete_inds) == 0:
            # No beam emitted END within num_steps: fall back to the best live beam.
            seq = seqs[0][1:].tolist()
            seq = torch.LongTensor(seq).unsqueeze(0)
            score = top_k_scores[0]
            if self.viz_attn:
                alphas = seqs_alpha[0][1:, ...]
                return seq, score, alphas
            else:
                return seq, score, None
        else:
            # Pick the completed hypothesis with the best length-normalized score.
            combine_lst = tuple(zip(complete_seqs, complete_seqs_scores))
            best_ind = combine_lst.index(max(combine_lst, key=lambda x: x[1] / len(x[0]))) #https://youtu.be/XXtpJxZBa2c?t=2407
            seq = complete_seqs[best_ind][1:] #not include [GO] token
            seq = torch.LongTensor(seq).unsqueeze(0)
            score = max(complete_seqs_scores)

            if self.viz_attn:
                alphas = complete_seqs_alpha[best_ind][1:, ...]
                return seq, score, alphas
            else:
                return seq, score, None

    def forward_greedy(self, batch_H, text, is_train=True, is_test=False, batch_max_length=25):
        """Greedy (or teacher-forced) decode; see class docstring for seqmodel handling."""
        batch_size = batch_H.size(0)
        num_steps = batch_max_length + 1
        encoder_hidden = None
        if self.seqmodel in ['BiLSTM', 'VIG']:
            encoder_hidden = batch_H
        elif self.seqmodel == 'TFM':
            # Drop the summary token; attend only over the patch/step tokens.
            encoder_hidden = batch_H[:, 1:, :]
        else:
            raise ValueError('seqmodel must be either BiLSTM or TFM option')

        if self.enc_init:
            init_embedding = None
            if self.seqmodel in ['BiLSTM', 'VIG']:
                init_embedding = batch_H.mean(dim=1)
            elif self.seqmodel == 'TFM':
                init_embedding = batch_H[:, 0, :]
            else:
                raise ValueError('seqmodel must be either BiLSTM or TFM option')
            h_0 = self.proj_init_h(init_embedding)
            c_0 = self.proj_init_c(init_embedding)
            hidden = (h_0, c_0)
        else:
            hidden = (torch.zeros(batch_size, self.hidden_size, dtype=torch.float32, device=self.device),
                      torch.zeros(batch_size, self.hidden_size, dtype=torch.float32, device=self.device))

        # NOTE(review): assumes the [GO] token id is 0 -- confirm ATTN.START() == 0.
        targets = torch.zeros(batch_size, dtype=torch.long, device=self.device) # [GO] token
        probs = torch.zeros(batch_size, num_steps, self.num_classes, dtype=torch.float32, device=self.device)

        assert encoder_hidden is not None

        if self.viz_attn:
            self.alpha_stores = torch.zeros(batch_size, num_steps, encoder_hidden.shape[1], 1, dtype=torch.float32, device=self.device)
        if self.attn_type == 'coverage':
            alpha_cum = torch.zeros(batch_size, encoder_hidden.shape[1], 1, dtype=torch.float32, device=self.device)

        self.attention_cell.reset_mem()

        if is_test:
            end_flag = torch.zeros(batch_size, dtype=torch.bool, device=self.device)

        for i in range(num_steps):
            embed_text = self._char_to_onehot(targets, onehot_dim=self.num_classes) if not self.embed_target else self._embed_text(targets)
            output, hidden, alpha = self.attention_cell(hidden, encoder_hidden, embed_text)
            output = self.dropout(output)
            if self.viz_attn:
                self.alpha_stores[:, i] = alpha
            if self.attn_type == 'coverage':
                alpha_cum = alpha_cum + alpha
                self.attention_cell.set_mem(alpha_cum)
            elif self.attn_type == 'loc_aware':
                self.attention_cell.set_mem(alpha)

            probs_step = output
            probs[:, i, :] = probs_step

            if i == num_steps - 1:
                break

            if is_train:
                # teacher_forcing is the probability of feeding the gold token.
                if self.teacher_forcing < random.random():
                    _, next_input = probs_step.max(1)
                    targets = next_input
                else:
                    targets = text[:, i+1]
            else:
                _, next_input = probs_step.max(1)
                targets = next_input

            if is_test:
                # NOTE(review): if is_train and is_test are both True and teacher
                # forcing fired, `next_input` may be unbound here -- confirm callers
                # never pass both flags.
                end_flag = end_flag | (next_input == ATTN.END())
                if end_flag.all():
                    break

        _, preds_index = probs.max(2)

        return preds_index, probs, None # batch_size x num_steps x num_classes
diff --git a/HybridViT/module/component/prediction_head/tfm.py b/HybridViT/module/component/prediction_head/tfm.py
new file mode 100644
index 0000000000000000000000000000000000000000..e62a5f1d7d4991d61a3ad5d74f16a2ad4e257478
--- /dev/null
+++ b/HybridViT/module/component/prediction_head/tfm.py
@@ -0,0 +1,207 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from torch import FloatTensor, LongTensor
+from .addon_module import WordPosEnc
+from ...converter.tfm_converter import TFMLabelConverter as TFM
+from ....beam import Beam
+
def _build_transformer_decoder(
    d_model: int,
    nhead: int,
    num_decoder_layers: int,
    dim_feedforward: int,
    dropout: float,
) -> nn.TransformerDecoder:
    """Construct a stack of ``num_decoder_layers`` standard decoder layers.

    All matrix-shaped parameters are re-initialized with Xavier-uniform;
    vectors (biases, layer-norm weights) keep PyTorch's defaults.
    """
    layer = nn.TransformerDecoderLayer(
        d_model=d_model,
        nhead=nhead,
        dim_feedforward=dim_feedforward,
        dropout=dropout,
    )
    decoder = nn.TransformerDecoder(layer, num_decoder_layers)

    # Xavier init only for weight matrices (dim > 1).
    for param in decoder.parameters():
        if param.dim() > 1:
            nn.init.xavier_uniform_(param)

    return decoder
+
+
class TransformerPrediction(nn.Module):
    """Autoregressive Transformer decoder head.

    Decodes a source feature sequence (batch, time, d_model) into class
    logits. Training uses full teacher forcing; inference uses greedy
    decoding (any batch size) or beam search (batch size 1).
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        num_decoder_layers: int,
        dim_feedforward: int,
        dropout: float,
        num_classes: int,
        max_seq_len: int,
        padding_idx: int,
        device: str = 'cuda:1'
    ):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.padding_idx = padding_idx
        self.num_classes = num_classes
        self.device = device
        # PAD embeds to a frozen zero vector.
        self.word_embed = nn.Embedding(
            num_classes, d_model, padding_idx=padding_idx
        )

        self.pos_enc = WordPosEnc(d_model=d_model)
        self.d_model = d_model
        self.model = _build_transformer_decoder(
            d_model=d_model,
            nhead=nhead,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )

        self.proj = nn.Linear(d_model, num_classes)
        self.beam = Beam(
            ignore_w=TFM.PAD(),
            start_w=TFM.START(),
            stop_w=TFM.END(),
            max_len=self.max_seq_len,
            device=self.device
        )

    def reset_beam(self):
        """Discard all beam state; call before decoding a new sample."""
        self.beam = Beam(
            ignore_w=TFM.PAD(),
            start_w=TFM.START(),
            stop_w=TFM.END(),
            max_len=self.max_seq_len,
            device=self.device
        )

    def _build_attention_mask(self, length):
        """Additive causal mask: 0.0 on allowed (past/self) positions,
        -inf on future positions, shape (length, length)."""
        mask = torch.full(
            (length, length),
            fill_value=1,
            dtype=torch.bool,
            device=self.device
        )
        # Upper-triangular then transpose -> lower-triangular "visible" area.
        mask = torch.triu(mask).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def _embedd_tgt(self, tgt: LongTensor, tgt_len: int):
        """Embed target ids, add positional encoding, and build the causal
        mask. The padding mask is only needed under teacher forcing."""
        tgt_mask = self._build_attention_mask(tgt_len)
        if self.training:
            tgt_pad_mask = tgt == self.padding_idx
        else:
            tgt_pad_mask = None
        tgt = self.word_embed(tgt)
        tgt = self.pos_enc(tgt * math.sqrt(self.d_model))
        return tgt, tgt_mask, tgt_pad_mask

    def forward_greedy(
        self, src: FloatTensor, tgt: LongTensor, output_weight: bool = False, is_test: bool = False
    ) -> FloatTensor:
        """Teacher-forced decoding when training, greedy decoding otherwise.

        Args:
            src: encoder features, (batch, time, d_model).
            tgt: target token ids, (batch, len); at inference time only its
                leading tokens (e.g. the start token) are used as the prompt.
            is_test: when True, decoding stops early once every sample has
                emitted the END token.

        Returns:
            (preds_index, logits) with logits (batch, len, num_classes).
        """
        if self.training:
            _, l = tgt.size()
            tgt, tgt_mask, tgt_pad_mask = self._embedd_tgt(tgt, l)

            src = rearrange(src, "b t d -> t b d")
            tgt = rearrange(tgt, "b l d -> l b d")

            out = self.model(
                tgt=tgt,
                memory=src,
                tgt_mask=tgt_mask,
                tgt_key_padding_mask=tgt_pad_mask
            )

            out = rearrange(out, "l b d -> b l d")
            out = self.proj(out)
        else:
            out = None
            # BUG FIX: per-sample stop flags must be sized by the BATCH dim.
            # The original allocated end_flag from src.shape[0] *after*
            # rearranging src to (time, batch, dim), i.e. from the time axis.
            batch_size = src.shape[0]
            src = rearrange(src, "b t d -> t b d")

            end_flag = torch.zeros(batch_size, dtype=torch.bool, device=self.device)

            for step in range(self.max_seq_len + 1):
                b, l = tgt.size()
                emb_tgt, tgt_mask, tgt_pad_mask = self._embedd_tgt(tgt, l)
                emb_tgt = rearrange(emb_tgt, "b l d -> l b d")

                out = self.model(
                    tgt=emb_tgt,
                    memory=src,
                    tgt_mask=tgt_mask
                )

                out = rearrange(out, "l b d -> b l d")
                out = self.proj(out)
                probs = F.softmax(out, dim=-1)
                next_text = torch.argmax(probs[:, -1:, :], dim=-1)  # (batch, 1)
                tgt = torch.cat([tgt, next_text], dim=-1)

                # BUG FIX: drop the length-1 step dim so the OR stays (batch,)
                # instead of broadcasting end_flag to a 2-D tensor.
                end_flag = end_flag | (next_text.squeeze(-1) == TFM.END())
                if end_flag.all() and is_test:
                    break

        _, preds_index = out.max(dim=2)
        return preds_index, out

    def forward_beam(self,
        src: torch.FloatTensor,
        beam_size: int
    ):
        """Beam-search decode a single sample.

        Returns (best_sequence LongTensor of shape (1, len), its score).
        NOTE(review): relies on ``self.beam`` being freshly reset by the
        caller (see ``reset_beam``) — confirm call sites.
        """
        assert src.size(0) == 1, f'beam search should only have single source, encounter with batch size: {src.size(0)}'
        out = None
        src = src.squeeze(0)

        for step in range(self.max_seq_len + 1):
            hypotheses = self.beam.hypotheses
            hyp_num = hypotheses.size(0)
            l = hypotheses.size(1)
            assert hyp_num <= beam_size, f"hyp_num: {hyp_num}, beam_size: {beam_size}"

            emb_tgt = self.word_embed(hypotheses)
            emb_tgt = self.pos_enc(emb_tgt * math.sqrt(self.d_model))
            tgt_mask = self._build_attention_mask(l)
            emb_tgt = rearrange(emb_tgt, "b l d -> l b d")

            # Broadcast the single source across all live hypotheses.
            exp_src = repeat(src.squeeze(1), "s e -> s b e", b=hyp_num)

            out = self.model(
                tgt=emb_tgt,
                memory=exp_src,
                tgt_mask=tgt_mask
            )

            out = rearrange(out, "l b d -> b l d")
            out = self.proj(out)
            # Score only the distribution at the current step.
            log_prob = F.log_softmax(out[:, step, :], dim=-1)
            new_hypotheses, new_hyp_scores = self.beam.advance(log_prob, step, beam_size=beam_size)

            if self.beam.done(beam_size):
                break

            self.beam.set_current_state(new_hypotheses)
            self.beam.set_current_score(new_hyp_scores)

        self.beam.set_hypothesis()
        # Length-normalized score picks the best completed hypothesis.
        best_hyp = max(self.beam.completed_hypotheses, key=lambda h: h.score / len(h))
        output = best_hyp.seq
        output = torch.LongTensor(output).unsqueeze(0)
        score = best_hyp.score

        return output, score

    def forward(self, beam_size, batch_H, text, is_test):
        """Dispatch: teacher forcing in training, else beam search when
        beam_size > 1, else greedy decoding."""
        if self.training:
            return self.forward_greedy(batch_H, text)
        else:
            if beam_size > 1:
                return self.forward_beam(batch_H, beam_size)
            else:
                return self.forward_greedy(batch_H, text, is_test=is_test)
+
diff --git a/HybridViT/module/component/seq_modeling/__init__.py b/HybridViT/module/component/seq_modeling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f86fbf5890ae0ec641b6dbb61dbfdbfd1c4c98f
--- /dev/null
+++ b/HybridViT/module/component/seq_modeling/__init__.py
@@ -0,0 +1,2 @@
+from .bilstm import *
+from .vit_encoder import *
diff --git a/HybridViT/module/component/seq_modeling/addon_module/__init__.py b/HybridViT/module/component/seq_modeling/addon_module/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2891afee2178461b0a6b62ba12544dc0222b127
--- /dev/null
+++ b/HybridViT/module/component/seq_modeling/addon_module/__init__.py
@@ -0,0 +1 @@
+from .patchembed import *
diff --git a/HybridViT/module/component/seq_modeling/addon_module/patchembed.py b/HybridViT/module/component/seq_modeling/addon_module/patchembed.py
new file mode 100644
index 0000000000000000000000000000000000000000..c35148e12b6522c3ce7de660da4f0d77d501234f
--- /dev/null
+++ b/HybridViT/module/component/seq_modeling/addon_module/patchembed.py
@@ -0,0 +1,161 @@
+import math
+import torch.nn as nn
+import torch
+from torch.nn import functional as F
+from timm.models.layers.helpers import to_2tuple
+from typing import Tuple, Union, List
+
class PatchEmbed(nn.Module):
    """Image to patch embedding.

    Splits an image into non-overlapping ``patch_size`` patches via a strided
    convolution. Inputs whose spatial size is not a multiple of the patch size
    are zero-padded on the bottom/right first.
    """

    def __init__(self, img_size=(224, 224), patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(img_size, tuple)
        # Normalize patch_size to an (h, w) pair (same behavior as timm's
        # to_2tuple, inlined to drop the extra dependency here).
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        else:
            patch_size = tuple(patch_size)
        # Nominal size: img_size rounded UP to a multiple of the patch size.
        self.img_size = (
            patch_size[0] * self._ceil_div(img_size[0], patch_size[0]),
            patch_size[1] * self._ceil_div(img_size[1], patch_size[1]),
        )
        self.grid_size = (self.img_size[0] // patch_size[0], self.img_size[1] // patch_size[1])
        self.patch_size = patch_size
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    @staticmethod
    def _ceil_div(n, d):
        # Integer ceiling division without going through floats.
        return -(-n // d)

    def forward(self, x):
        """Embed ``x`` (B, C, H, W).

        Returns:
            tokens (B, num_patches, embed_dim),
            {'height', 'width'} of the padded input,
            bool flag: True when the padded size differs from the nominal
            ``img_size`` (caller must interpolate positional embeddings).
        """
        _, _, H, W = x.shape
        # Zero-pad bottom/right so H and W become patch-size multiples.
        pad_h = self.patch_size[0] * self._ceil_div(H, self.patch_size[0]) - H
        pad_w = self.patch_size[1] * self._ceil_div(W, self.patch_size[1]) - W
        x = F.pad(x, (0, pad_w, 0, pad_h))
        proj_x = self.proj(x).flatten(2).transpose(1, 2)
        return (
            proj_x,
            {'height': x.shape[2], 'width': x.shape[3]},
            (x.shape[2] != self.img_size[0] or x.shape[3] != self.img_size[1]),
        )
+
class HybridEmbed(nn.Module):
    """CNN feature-map embedding.

    Runs a CNN backbone, pads the resulting feature map up to a multiple of
    ``patch_size``, and projects patches into ``embed_dim`` tokens.
    """

    # BUG FIX: the original signature read `patch_size=Union[List, int]`,
    # which assigns the typing object as the DEFAULT VALUE (unusable) instead
    # of annotating the parameter.
    def __init__(self, backbone, img_size: Tuple[int], patch_size: Union[List, int] = 16,
                 feature_size=None, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        else:
            patch_size = tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.backbone = backbone
        if feature_size is None:
            with torch.no_grad():
                # NOTE Most reliable way of determining output dims is to run
                # a dry forward pass through the backbone.
                was_training = backbone.training
                if was_training:
                    backbone.eval()
                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
                if isinstance(o, (list, tuple)):
                    o = o[-1]  # last feature if backbone outputs list/tuple of features
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(was_training)
        else:
            if isinstance(feature_size, int):
                feature_size = (feature_size, feature_size)
            else:
                feature_size = tuple(feature_size)
            if hasattr(self.backbone, 'feature_info'):
                feature_dim = self.backbone.feature_info.channels()[-1]
            else:
                feature_dim = self.backbone.num_features

        assert feature_size[0] >= patch_size[0] and feature_size[1] >= patch_size[1]

        # Nominal feature size rounded UP to a multiple of the patch size.
        self.feature_size = (
            patch_size[0] * -(-feature_size[0] // patch_size[0]),
            patch_size[1] * -(-feature_size[1] // patch_size[1]),
        )
        self.grid_size = (self.feature_size[0] // patch_size[0], self.feature_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Embed ``x`` (B, C, H, W).

        Returns:
            tokens (B, num_patches, embed_dim),
            (pad_w, pad_h) applied to the feature map,
            {'height', 'width'} of the padded feature map,
            bool flag: True when that size differs from the nominal one.
        """
        x = self.backbone(x)
        # BUG FIX: unwrap multi-output backbones BEFORE touching .shape —
        # the original read x.shape[2:] / padded first and would crash when
        # the backbone returns a list or tuple of feature maps.
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
        f_h, f_w = x.shape[2:]

        pad_h = self.patch_size[0] * -(-f_h // self.patch_size[0]) - f_h
        pad_w = self.patch_size[1] * -(-f_w // self.patch_size[1]) - f_w
        x = F.pad(x, (0, pad_w, 0, pad_h))

        proj_x = self.proj(x).flatten(2).transpose(1, 2)
        return (
            proj_x,
            (pad_w, pad_h),
            {'height': x.shape[2], 'width': x.shape[3]},
            (x.shape[2] != self.feature_size[0] or x.shape[3] != self.feature_size[1]),
        )
+
class HybridEmbed1D(nn.Module):
    """CNN feature-map embedding using 1D patching
    (from https://arxiv.org/pdf/2111.08314.pdf, which benefits text
    recognition — see the paper for details).

    Each feature-map row is projected with a shared Conv1d; rows are then
    averaged, producing one token per horizontal window.
    """

    def __init__(self, backbone, img_size: Tuple[int], feature_size=None, patch_size=1, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        self.img_size = img_size
        self.backbone = backbone
        self.embed_dim = embed_dim
        if feature_size is None:
            with torch.no_grad():
                # NOTE Most reliable way of determining output dims is to run
                # a dry forward pass through the backbone.
                was_training = backbone.training
                if was_training:
                    backbone.eval()
                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
                if isinstance(o, (list, tuple)):
                    o = o[-1]  # last feature if backbone outputs list/tuple of features
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(was_training)
        else:
            if isinstance(feature_size, int):
                feature_size = (feature_size, feature_size)
            else:
                feature_size = tuple(feature_size)
            if hasattr(self.backbone, 'feature_info'):
                feature_dim = self.backbone.feature_info.channels()[-1]
            else:
                feature_dim = self.backbone.num_features

        self.window_width = patch_size
        assert feature_size[1] >= self.window_width
        # Nominal feature width rounded UP to a multiple of the window width.
        self.feature_size = (feature_size[0], self.window_width * -(-feature_size[1] // self.window_width))
        self.grid_size = (1, self.feature_size[1] // self.window_width)
        self.num_patches = self.grid_size[1]
        self.proj = nn.Conv1d(feature_dim, embed_dim, kernel_size=self.window_width, stride=self.window_width, bias=True)

    def forward(self, x):
        """Embed ``x`` (B, C, H, W) into (B, W', embed_dim) tokens.

        Returns:
            tokens, (pad_w,), {'height', 'width'} of the padded feature map,
            bool flag: True when that size differs from the nominal one.
        """
        batch_size = x.shape[0]
        x = self.backbone(x)
        # BUG FIX: unwrap multi-output backbones BEFORE reading .shape or
        # padding — the original unwrapped only after both.
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
        f_h, f_w = x.shape[2:]
        assert f_w >= self.window_width

        pad_w = self.window_width * -(-f_w // self.window_width) - f_w
        x = F.pad(x, (0, pad_w))

        # Project every feature row with the shared Conv1d.
        proj_x = torch.zeros(batch_size, self.embed_dim, f_h, x.shape[3] // self.window_width, device=x.device, dtype=x.dtype)
        for row in range(f_h):
            proj_x[:, :, row, :] = self.proj(x[:, :, row, :])

        # Average over rows, then move channels last: (B, W', embed_dim).
        proj_x = proj_x.mean(dim=2).transpose(1, 2)

        return proj_x, (pad_w, ), {'height': x.shape[2], 'width': x.shape[3]}, (x.shape[2] != self.feature_size[0] or x.shape[3] != self.feature_size[1])
\ No newline at end of file
diff --git a/HybridViT/module/component/seq_modeling/bilstm.py b/HybridViT/module/component/seq_modeling/bilstm.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d4e5b360b2b41435a8b3c78737acfd0d14bb398
--- /dev/null
+++ b/HybridViT/module/component/seq_modeling/bilstm.py
@@ -0,0 +1,33 @@
+import torch.nn as nn
+
+__all__ = ['BiLSTM_Seq_Modeling', 'BidirectionalLSTM']
+
class BidirectionalLSTM(nn.Module):
    """One bidirectional LSTM layer followed by a linear projection.

    Maps [batch_size x T x input_size] to [batch_size x T x output_size].
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)
        self.linear = nn.Linear(hidden_size * 2, output_size)

    def forward(self, input):
        """Return contextual features for a visual feature sequence."""
        # Re-compact the LSTM weights (needed after replication, e.g. DataParallel).
        self.rnn.flatten_parameters()
        context, _ = self.rnn(input)       # -> batch_size x T x (2 * hidden_size)
        return self.linear(context)        # -> batch_size x T x output_size
+
class BiLSTM_Seq_Modeling(nn.Module):
    """Sequential stack of ``num_layers`` BidirectionalLSTM layers.

    The first layer maps input_size -> hidden_size, middle layers keep
    hidden_size, and the last maps hidden_size -> output_size.
    NOTE(review): the stack always contains the first and last layer, so it
    effectively assumes num_layers >= 2 — confirm with callers.
    """

    def __init__(self, num_layers, input_size, hidden_size, output_size):
        super(BiLSTM_Seq_Modeling, self).__init__()
        self.num_layers = num_layers
        stack = [BidirectionalLSTM(input_size, hidden_size, hidden_size)]
        stack.extend(
            BidirectionalLSTM(hidden_size, hidden_size, hidden_size)
            for _ in range(num_layers - 2)
        )
        stack.append(BidirectionalLSTM(hidden_size, hidden_size, output_size))
        self.lstm = nn.Sequential(*stack)

    def forward(self, input):
        return self.lstm(input)
\ No newline at end of file
diff --git a/HybridViT/module/component/seq_modeling/vit/utils.py b/HybridViT/module/component/seq_modeling/vit/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..842796b84f9418c8b6d60c51f09a219ef7347c70
--- /dev/null
+++ b/HybridViT/module/component/seq_modeling/vit/utils.py
@@ -0,0 +1,59 @@
+import torch
+import math
+import warnings
+
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place from a truncated normal via inverse-CDF sampling.

    Cut & paste from PyTorch official master until it's in a few official
    releases - RW. Method based on
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    """
    def norm_cdf(v):
        # Standard normal cumulative distribution function.
        return (1. + math.erf(v / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # CDF values of the truncation bounds.
        lower = norm_cdf((a - mean) / std)
        upper = norm_cdf((b - mean) / std)

        # Sample uniformly in [2*lower - 1, 2*upper - 1], then push through the
        # inverse error function to get a truncated standard normal.
        tensor.uniform_(2 * lower - 1, 2 * upper - 1)
        tensor.erfinv_()

        # Rescale to the requested mean / std.
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to guarantee the values stay inside [a, b].
        tensor.clamp_(min=a, max=b)
    return tensor
+
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.
    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Returns:
        the same tensor, filled in place (no gradient is recorded)
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    # Thin public wrapper around the no-grad implementation above.
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
diff --git a/HybridViT/module/component/seq_modeling/vit/vision_transformer.py b/HybridViT/module/component/seq_modeling/vit/vision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed828b0f098edfc7f811076831676d35eb6eada1
--- /dev/null
+++ b/HybridViT/module/component/seq_modeling/vit/vision_transformer.py
@@ -0,0 +1,184 @@
+import torch
+from torch import nn
+from functools import partial
+from collections import OrderedDict
+from ...common import DropPath
+from .utils import trunc_normal_
+
+
class Mlp(nn.Module):
    """Transformer feed-forward block: Linear -> activation -> Linear,
    with dropout after each linear layer."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        # Unspecified widths default to the input width.
        hidden = hidden_features or in_features
        final = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, final)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
+
+
class ConvFFN(nn.Module):
    """Placeholder for a convolutional feed-forward block.

    NOTE(review): currently unimplemented — it only forwards construction
    arguments to ``nn.Module`` and defines no layers or ``forward``.
    """
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
+
+
class Attention(nn.Module):
    """Multi-head self-attention as used in ViT (fused QKV projection)."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually
        # to be compat with prev weights.
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads
        # One fused projection, then split into per-head q, k, v.
        qkv = (
            self.qkv(x)
            .reshape(batch, tokens, 3, self.num_heads, per_head)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        merged = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
+
+
class Block(nn.Module):
    """Pre-norm transformer encoder block: MHSA then MLP, each with a
    residual connection and optional stochastic depth."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth; identity when the rate is zero.
        self.drop_path = nn.Identity() if not drop_path > 0. else DropPath(drop_path)
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        return x + self.drop_path(self.mlp(self.norm2(x)))
+
+
class VisionTransformer(nn.Module):
    """ Vision Transformer

    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
        https://arxiv.org/abs/2010.11929

    NOTE(review): this base class leaves ``patch_embed`` as None; the
    subclasses in this package assign a real patch-embedding module and
    re-create ``pos_embed`` with the matching patch count.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module
            norm_layer: (nn.Module): normalization layer
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

        # Placeholder: subclasses assign the actual patch-embedding module.
        self.patch_embed = None
        # With no patch_embed yet, fall back to 128 so pos_embed gets a shape;
        # subclasses rebuild pos_embed once their patch_embed is known.
        num_patches = getattr(self.patch_embed, 'num_patches', 128)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Representation layer (optional pre-logits projection + tanh).
        if representation_size:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(embed_dim, representation_size)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()

        # Classifier head; Identity when num_classes == 0 (feature extractor mode).
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal linear weights, zero biases, unit LayerNorm gain.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        # Parameters excluded from weight decay by the optimizer setup.
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        # Swap the classification head for a different number of classes.
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        # NOTE(review): this base implementation assumes patch_embed returns a
        # plain tensor; the encoder subclasses in this package return tuples
        # and override this method — confirm before using the base directly.
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        # Return only the [CLS] token's features.
        x = self.norm(x)[:, 0]
        x = self.pre_logits(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
diff --git a/HybridViT/module/component/seq_modeling/vit_encoder.py b/HybridViT/module/component/seq_modeling/vit_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..36f53ce4864e5a358f6bd1a30a705267ae5a58fc
--- /dev/null
+++ b/HybridViT/module/component/seq_modeling/vit_encoder.py
@@ -0,0 +1,276 @@
+import torch.nn as nn
+import torch
+from torch.nn import functional as F
+from .vit.utils import trunc_normal_
+from .vit.vision_transformer import VisionTransformer
+from ..feature_extractor.clova_impl import ResNet_FeatureExtractor
+from .addon_module import *
+from ..common.mae_posembed import get_2d_sincos_pos_embed
+
+__all__ = ['ViTEncoder', 'ViTEncoderV2', 'ViTEncoderV3', 'TRIGBaseEncoder', 'create_vit_modeling']
+
class ViTEncoder(VisionTransformer):
    '''
    ViT encoder that returns the full (cls + patch) token sequence instead of
    a classification output. Patches come from raw pixels (PatchEmbed) or a
    CNN backbone (HybridEmbed); when the padded input size differs from the
    nominal one, the learned positional embedding is bicubically interpolated.
    '''
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Choose the patch-embedding front-end based on the backbone option.
        if kwargs['hybrid_backbone'] is None:
            self.patch_embed = PatchEmbed(
                img_size=kwargs['img_size'],
                in_chans=kwargs['in_chans'],
                patch_size=kwargs['patch_size'],
                embed_dim=kwargs['embed_dim'],
            )
        else:
            self.patch_embed = HybridEmbed(
                backbone=kwargs['hybrid_backbone'],
                img_size=kwargs['img_size'],
                in_chans=kwargs['in_chans'],
                patch_size=kwargs['patch_size'],
                embed_dim=kwargs['embed_dim'],
            )
        # Rebuild pos_embed with the real patch count (the base class used a
        # placeholder count), then re-run weight init.
        num_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches+1, kwargs['embed_dim']))
        self.emb_height = self.patch_embed.grid_size[0]
        self.emb_width = self.patch_embed.grid_size[1]
        trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)

    def reset_classifier(self, num_classes):
        # Swap the classification head for a different number of classes.
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def interpolating_pos_embedding(self, embedding, height, width):
        """
        Bicubically resize the learned patch positional embedding to match the
        current (height, width) patch grid; the class token keeps its own slot.
        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        npatch = embedding.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        # Fast path: grid unchanged, reuse the stored embedding as-is.
        if npatch == N and height == width:
            return self.pos_embed

        class_pos_embedding = self.pos_embed[:, 0]
        patch_pos_embedding = self.pos_embed[:, 1:]
        dim = self.pos_embed.shape[-1]

        # Target grid size in patches.
        h0 = height // self.patch_embed.patch_size[0]
        w0 = width // self.patch_embed.patch_size[1]
        # add a small number to avoid floating point error
        # https://github.com/facebookresearch/dino/issues/8

        h0 = h0 + 0.1
        w0 = w0 + 0.1

        patch_pos_embedding = nn.functional.interpolate(
            patch_pos_embedding.reshape(1, self.emb_height, self.emb_width, dim).permute(0, 3, 1, 2),
            scale_factor=(h0 / self.emb_height, w0 / self.emb_width),
            mode='bicubic',
            align_corners=False
        )
        assert int(h0) == patch_pos_embedding.shape[-2] and int(w0) == patch_pos_embedding.shape[-1]
        patch_pos_embedding = patch_pos_embedding.permute(0, 2, 3, 1).view(1, -1, dim)
        class_pos_embedding = class_pos_embedding.unsqueeze(0)

        return torch.cat((class_pos_embedding, patch_pos_embedding), dim=1)

    def forward_features(self, x):
        # Returns (tokens, pad_info, size) — the full token sequence, the
        # padding applied by the embedder, and the padded feature size.
        # NOTE(review): the 4-value unpack below matches HybridEmbed's return;
        # PatchEmbed returns only 3 values — confirm the backbone-less path.
        B, C, _, _ = x.shape

        x, pad_info, size, interpolating_pos = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

        # Interpolate the positional table only when the grid size changed.
        if interpolating_pos:
            x = x + self.interpolating_pos_embedding(x, size['height'], size['width'])
        else:
            x = x + self.pos_embed

        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)

        return x, pad_info, size
+
+
class TRIGBaseEncoder(ViTEncoder):
    '''
    ViT encoder using 1D (column-window) patching from
    https://arxiv.org/pdf/2111.08314.pdf — one token per horizontal window,
    which suits text recognition. Positional embeddings are interpolated
    along the width axis only.
    '''
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Replace the 2D patch embedder installed by ViTEncoder.__init__ with
        # the 1D variant, then rebuild pos_embed for the new patch count.
        self.patch_embed = HybridEmbed1D(
            backbone=kwargs['hybrid_backbone'],
            img_size=kwargs['img_size'],
            in_chans=kwargs['in_chans'],
            patch_size=kwargs['patch_size'],
            embed_dim=kwargs['embed_dim'],
        )
        num_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches+1, kwargs['embed_dim']))
        # 1D patching: the token grid is a single row.
        self.emb_height = 1
        self.emb_width = self.patch_embed.grid_size[1]
        trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)

    def interpolating_pos_embedding(self, embedding, height, width):
        """
        Width-only bicubic resize of the learned positional embedding; the
        class token keeps its own slot.
        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        npatch = embedding.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and height == width:
            return self.pos_embed

        class_pos_embedding = self.pos_embed[:, 0]
        patch_pos_embedding = self.pos_embed[:, 1:]
        dim = self.pos_embed.shape[-1]

        # Target number of window tokens.
        w0 = width // self.patch_embed.window_width

        # add a small number to avoid floating point error
        # https://github.com/facebookresearch/dino/issues/8

        w0 = w0 + 0.1

        patch_pos_embedding = nn.functional.interpolate(
            patch_pos_embedding.reshape(1, self.emb_height, self.emb_width, dim).permute(0, 3, 1, 2),
            scale_factor=(1, w0 / self.emb_width),
            mode='bicubic',
            align_corners=False
        )

        assert int(w0) == patch_pos_embedding.shape[-1]
        patch_pos_embedding = patch_pos_embedding.permute(0, 2, 3, 1).view(1, -1, dim)
        class_pos_embedding = class_pos_embedding.unsqueeze(0)

        return torch.cat((class_pos_embedding, patch_pos_embedding), dim=1)

    def forward_features(self, x):
        # Returns (tokens, padinfo, size) — see ViTEncoder.forward_features.
        B, _, _, _ = x.shape
        x, padinfo, size, interpolating_pos = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks

        x = torch.cat((cls_tokens, x), dim=1)  # cls_tokens is init_embedding in TRIG paper

        if interpolating_pos:
            x = x + self.interpolating_pos_embedding(x, size['height'], size['width'])
        else:
            x = x + self.pos_embed

        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)

        return x, padinfo, size
+
+
class ViTEncoderV2(ViTEncoder):
    """ViT encoder that truncates the learned positional table to the token
    count instead of interpolating it when the patch grid shrinks."""

    def forward(self, x):
        batch = x.shape[0]

        tokens, pad_info, size, _ = self.patch_embed(x)
        n_patches = tokens.shape[1]

        # Prepend one [CLS] token per sample.
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat((cls, tokens), dim=1)

        # Slice (never interpolate) the positional embedding to fit.
        tokens = self.pos_drop(tokens + self.pos_embed[:, :(n_patches + 1)])

        for block in self.blocks:
            tokens = block(tokens)

        return self.norm(tokens), pad_info, size
+
class ViTEncoderV3(ViTEncoder):
    """ViT encoder with a FIXED (non-learned) 2D sin-cos positional embedding,
    truncated — not interpolated — to the current token count."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Drop the learned table from the parent and replace it with a frozen
        # buffer-like parameter filled with sin-cos values.
        if hasattr(self, 'pos_embed'):
            del self.pos_embed
        n_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(
            torch.zeros(1, n_patches + 1, kwargs['embed_dim']),
            requires_grad=False,
        )
        self.initialize_posembed()

    def initialize_posembed(self):
        """Fill ``pos_embed`` with the fixed 2D sin-cos table (plus cls slot)."""
        table = get_2d_sincos_pos_embed(
            self.pos_embed.shape[-1],
            self.patch_embed.grid_size[0],
            self.patch_embed.grid_size[1],
            cls_token=True
        )
        self.pos_embed.data.copy_(torch.from_numpy(table).float().unsqueeze(0))

    def forward(self, x):
        batch = x.shape[0]

        tokens, pad_info, size, _ = self.patch_embed(x)
        n_patches = tokens.shape[1]

        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat((cls, tokens), dim=1)

        tokens = self.pos_drop(tokens + self.pos_embed[:, :(n_patches + 1)])

        for block in self.blocks:
            tokens = block(tokens)

        return self.norm(tokens), pad_info, size
+
def create_vit_modeling(opt):
    """Build a ViT sequence-modeling encoder from an option dict.

    Reads ``opt['SequenceModeling']['params']`` to pick an optional ResNet
    backbone and one of the encoder variants (2D learned / truncated / fixed
    sin-cos embedding, or the TRIG 1D-patching encoder), then constructs it.
    """
    seq_modeling = opt['SequenceModeling']['params']

    # BUG FIX: default the backbone up front — the original only assigned
    # `backbone` inside the `backbone is not None` branch, so a config with
    # no backbone section raised NameError at the encoder call below.
    backbone = None
    if seq_modeling['backbone'] is not None and seq_modeling['backbone']['name'] == 'resnet':
        param_kwargs = dict()
        if seq_modeling['backbone'].get('pretrained', None) is not None:
            param_kwargs['pretrained'] = seq_modeling['backbone']['pretrained']
        if seq_modeling['backbone'].get('weight_dir', None) is not None:
            param_kwargs['weight_dir'] = seq_modeling['backbone']['weight_dir']
        print('kwargs', param_kwargs)

        backbone = ResNet_FeatureExtractor(
            seq_modeling['backbone']['input_channel'],
            seq_modeling['backbone']['output_channel'],
            seq_modeling['backbone']['gcb'],
            **param_kwargs
        )
    # Any other backbone name (e.g. 'cnn') falls through to backbone=None.

    # Fixed height (when imgH is set) with the configured maximum width.
    max_dimension = (opt['imgH'], opt['max_dimension'][1]) if opt['imgH'] else opt['max_dimension']

    # Encoder variant selection.
    if seq_modeling['patching_style'] == '2d':
        if seq_modeling.get('fix_embed', False):
            encoder = ViTEncoderV3      # fixed 2D sin-cos positional embedding
        elif not seq_modeling.get('interpolate_embed', True):
            encoder = ViTEncoderV2      # learned embedding, truncated not interpolated
        else:
            encoder = ViTEncoder        # learned embedding with interpolation
    else:
        encoder = TRIGBaseEncoder       # 1D patching (TRIG)

    return encoder(
        img_size=max_dimension,
        patch_size=seq_modeling['patch_size'],
        in_chans=seq_modeling['input_channel'],
        depth=seq_modeling['depth'],
        num_classes=0,
        embed_dim=seq_modeling['hidden_size'],
        num_heads=seq_modeling['num_heads'],
        hybrid_backbone=backbone
    )
diff --git a/HybridViT/module/converter/__init__.py b/HybridViT/module/converter/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..00fc5ab8375cbb78fdca2e9b6a1eda0af3de1de3
--- /dev/null
+++ b/HybridViT/module/converter/__init__.py
@@ -0,0 +1,3 @@
+from .builder import create_converter
+from .attn_converter import AttnLabelConverter
+from .tfm_converter import TFMLabelConverter
\ No newline at end of file
diff --git a/HybridViT/module/converter/attn_converter.py b/HybridViT/module/converter/attn_converter.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e00de083ad4f790fd8e4fe0c32308777c5d2d9c
--- /dev/null
+++ b/HybridViT/module/converter/attn_converter.py
@@ -0,0 +1,71 @@
+import torch
+import numpy as np
+
class AttnLabelConverter(object):
    """Convert between text labels and token-index tensors for attention decoders.

    Vocabulary layout: the special tokens [GO], [s], [UNK] (indices 0..2)
    followed by the user-supplied character list. [GO] (index 0) doubles as
    the padding token.
    """

    list_token = ['[GO]', '[s]', '[UNK]']

    def __init__(self, character, device):
        # character: list of vocabulary symbols (characters or words).
        self.character = AttnLabelConverter.list_token + character
        self.device = device
        # Symbol -> index lookup used by encode().
        self.dict = {char: i for i, char in enumerate(self.character)}
        self.ignore_idx = self.dict['[GO]']

    @staticmethod
    def START() -> int:
        """Index of the [GO] (start / padding) token."""
        return AttnLabelConverter.list_token.index('[GO]')

    @staticmethod
    def END() -> int:
        """Index of the [s] (end-of-sequence) token."""
        return AttnLabelConverter.list_token.index('[s]')

    @staticmethod
    def UNK() -> int:
        """Index of the [UNK] (out-of-vocabulary) token."""
        return AttnLabelConverter.list_token.index('[UNK]')

    def encode(self, text, batch_max_length=25):
        """Tokenize a batch of strings.

        Args:
            text: list of strings.
            batch_max_length: maximum label length (excluding special tokens).

        Returns:
            (batch_text, length): LongTensor of shape (B, batch_max_length + 2)
            starting with an implicit [GO] column and padded with [GO] after
            [s], and an IntTensor of per-item lengths (+1 for [s]).
            NOTE: lengths reflect the untruncated input strings.
        """
        length = [len(s) + 1 for s in text]  # +1 for [s] at end of sentence.
        # batch_max_length = max(length)  # not allowed in a multi-GPU setting
        batch_max_length += 1
        # additional +1 for [GO] at the first step; padding after [s] is [GO] (0).
        batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(0)
        for i, t in enumerate(text):
            # Use a fresh local; the original rebound the `text` parameter here.
            chars = list(t)
            if len(chars) > batch_max_length:
                chars = chars[:(batch_max_length - 1)]
            chars.append('[s]')
            ids = [self.dict.get(char, self.dict['[UNK]']) for char in chars]
            batch_text[i][1:1 + len(ids)] = torch.LongTensor(ids)  # batch_text[:, 0] = [GO]
        return (batch_text.to(self.device), torch.IntTensor(length).to(self.device))

    def decode(self, text_index, token_level='word'):
        """Convert index rows to strings; 'word' joins tokens with spaces."""
        sep = ' ' if token_level == 'word' else ''
        texts = []
        for index in range(text_index.shape[0]):
            texts.append(sep.join([self.character[i] for i in text_index[index, :]]))
        return texts

    def detokenize(self, token_ids):
        """Convert id sequences to token lists, truncating at the first [s]."""
        b_toks = []
        for tok in token_ids:
            toks = []
            for i in tok:
                if self.character[i] == '[s]':
                    break
                toks.append(self.character[i])
            b_toks.append(toks)

        return b_toks
\ No newline at end of file
diff --git a/HybridViT/module/converter/builder.py b/HybridViT/module/converter/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..0466749d19d75576d1bde83523c3293f77e6e8d0
--- /dev/null
+++ b/HybridViT/module/converter/builder.py
@@ -0,0 +1,6 @@
+from .attn_converter import AttnLabelConverter
+
def create_converter(config, device):
    """Instantiate the label converter selected by config['Prediction']['name'].

    Args:
        config: model configuration; reads config['Prediction']['name'] and
            config['character'].
        device: torch device string forwarded to the converter.

    Raises:
        ValueError: if the prediction head name is not recognised (the
            original silently returned None here).
    """
    if 'Attn' in config['Prediction']['name']:
        return AttnLabelConverter(config['character'], device)
    raise ValueError(
        f"Unsupported Prediction name for converter: {config['Prediction']['name']!r}"
    )
\ No newline at end of file
diff --git a/HybridViT/module/converter/tfm_converter.py b/HybridViT/module/converter/tfm_converter.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f18575f887608fd39cd0d67d483158111fb3ef9
--- /dev/null
+++ b/HybridViT/module/converter/tfm_converter.py
@@ -0,0 +1,90 @@
+import torch
+import numpy as np
+
+
class TFMLabelConverter(object):
    """Convert between text labels and token-index tensors for transformer decoders.

    Vocabulary layout: the special tokens [PAD], [GO], [s], [UNK] (indices
    0..3) followed by the user-supplied character list. Padding uses [PAD].
    """

    list_token = ['[PAD]', '[GO]', '[s]', '[UNK]']

    def __init__(self, character, device):
        # character: list of vocabulary symbols (characters or words).
        self.character = TFMLabelConverter.list_token + character
        self.device = device
        # Symbol -> index lookup used by encode().
        self.dict = {char: i for i, char in enumerate(self.character)}
        self.ignore_idx = self.dict['[PAD]']

    @staticmethod
    def START() -> int:
        """Index of the [GO] (start) token."""
        return TFMLabelConverter.list_token.index('[GO]')

    @staticmethod
    def END() -> int:
        """Index of the [s] (end-of-sequence) token."""
        return TFMLabelConverter.list_token.index('[s]')

    @staticmethod
    def UNK() -> int:
        """Index of the [UNK] (out-of-vocabulary) token."""
        return TFMLabelConverter.list_token.index('[UNK]')

    @staticmethod
    def PAD() -> int:
        """Index of the [PAD] (padding) token."""
        return TFMLabelConverter.list_token.index('[PAD]')

    def encode(self, text, batch_max_length=25):
        """Tokenize a batch of strings.

        Args:
            text: list of strings.
            batch_max_length: maximum label length (excluding special tokens).

        Returns:
            (batch_text, length): LongTensor of shape (B, batch_max_length + 2)
            beginning with [GO], ending each label with [s], padded with [PAD],
            and an IntTensor of per-item lengths (+1 for [s]).
            NOTE: lengths reflect the untruncated input strings.
        """
        length = [len(s) + 1 for s in text]  # +1 for [s] at end of sentence.
        batch_max_length += 1
        batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(self.ignore_idx)
        for i, t in enumerate(text):
            # Use a fresh local; the original rebound the `text` parameter here.
            chars = list(t)
            if len(chars) > batch_max_length:
                chars = chars[:(batch_max_length - 1)]
            chars.append('[s]')
            ids = [self.dict.get(char, self.dict['[UNK]']) for char in chars]
            batch_text[i][0] = self.dict['[GO]']  # batch_text[:, 0] = [GO]
            batch_text[i][1:1 + len(ids)] = torch.LongTensor(ids)
        return (batch_text.to(self.device), torch.IntTensor(length).to(self.device))

    def decode(self, text_index, token_level='word'):
        """Convert index rows to strings; 'word' joins tokens with spaces."""
        sep = ' ' if token_level == 'word' else ''
        texts = []
        for index in range(text_index.shape[0]):
            texts.append(sep.join([self.character[i] for i in text_index[index, :]]))
        return texts

    def detokenize(self, token_ids):
        """Convert id sequences to token lists, truncating at the first [s]."""
        b_toks = []
        for tok in token_ids:
            toks = []
            for i in tok:
                if self.character[i] == '[s]':
                    break
                toks.append(self.character[i])
            b_toks.append(toks)

        return b_toks
+
if __name__ == '__main__':
    # Manual smoke test: encode a few Vietnamese address strings and decode
    # them back, printing both directions for visual inspection.
    vocab = ['S', 'ố', ' ', '2', '5', '3', 'đ', 'ư', 'ờ', 'n', 'g', 'T', 'r', 'ầ', 'P', 'h', 'ú', ',', 'ị', 't', 'ấ', 'N', 'a', 'm', 'á', 'c', 'H', 'u', 'y', 'ệ', 'ả', 'i', 'D', 'ơ', '8', '9', 'Đ', 'B', 'ộ', 'L', 'ĩ', '6', 'Q', 'ậ', 'ì', 'ạ', 'ồ', 'C', 'í', 'M', '4', 'E', '/', 'K', 'p', '1', 'A', 'x', 'ặ', 'ễ', '0', 'â', 'à', 'ế', 'ừ', 'ê', '-', '7', 'o', 'V', 'ô', 'ã', 'G', 'ớ', 'Y', 'I', 'ề', 'ò', 'l', 'R', 'ỹ', 'ủ', 'X', "'", 'e', 'ắ', 'ổ', 'ằ', 'k', 's', '.', 'ợ', 'ù', 'ứ', 'ă', 'ỳ', 'ẵ', 'ý', 'ó', 'ẩ', 'ọ', 'J', 'ũ', 'ữ', 'ự', 'õ', 'ỉ', 'ỏ', 'v', 'd', 'Â', 'W', 'U', 'O', 'é', 'ở', 'ỷ', '(', ')', 'ử', 'è', 'ể', 'ụ', 'ỗ', 'F', 'q', 'ẻ', 'ỡ', 'b', 'ỵ', 'Ứ', '#', 'ẽ', 'Ô', 'Ê', 'Ơ', '+', 'z', 'Ấ', 'w', 'Z', '&', 'Á', '~', 'f', 'Ạ', 'Ắ', 'j', ':', 'Ă', '<', '>', 'ẹ', '_', 'À', 'Ị', 'Ư', 'Ễ']
    text = [
        "190B Trần Quang Khải, Phường Tân Định, Quận 1, TP Hồ Chí Minh",
        "164/2B, Quốc lộ 1A, Phường Lê Bình, Quận Cái Răng, Cần Thơ",
        "Cẩm Huy, Huyện Cẩm Xuyên, Hà Tĩnh"
    ]
    tfm_convert = TFMLabelConverter(vocab, 'cpu')
    # 70 = batch_max_length: every row is padded to 72 slots ([GO] + 70 + [s]).
    texts, lengths = tfm_convert.encode(text, 70)
    print(texts)
    for text in texts:
        print('Encode', text)
        text = text.unsqueeze(0)  # decode() expects a (batch, seq_len) tensor
        decode_text = tfm_convert.decode(text, 'char')
        print('Decode', decode_text)
\ No newline at end of file
diff --git a/HybridViT/recog_flow.py b/HybridViT/recog_flow.py
new file mode 100644
index 0000000000000000000000000000000000000000..0186d84e77e65fe192eb24bf74134e2f7cadd601
--- /dev/null
+++ b/HybridViT/recog_flow.py
@@ -0,0 +1,113 @@
+import re
+from typing import Any
+from collections import OrderedDict
+import re
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from PIL import Image
+from timm.models.resnetv2 import ResNetV2
+
+from .recognizers.build_model import Model
+from .module.converter import AttnLabelConverter, TFMLabelConverter
+from .helper import resize
+
class MathRecognition(object):
    """Inference pipeline for math-expression recognition.

    Wraps preprocessing (resize), the recognizer model, label decoding and
    LaTeX spacing postprocessing behind a single __call__.
    """

    def __init__(self, config, resizer):
        self.args = config
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.args["device"] = device
        self.device = device
        self._prepare_vocab()
        self.model = self._get_model()
        self.resizer = resizer

    def _mapping_ckpt(self, state_dict):
        """Rename legacy checkpoint keys to the current module layout.

        'Transformation.*' weights are dropped; the three recognizer stages are
        re-rooted under featextractor / seqmodeler / predicter; any other key
        passes through unchanged.
        """
        prefix_map = {
            'FeatureExtraction': 'featextractor.FeatureExtraction',
            'SequenceModeling': 'seqmodeler.SequenceModeling',
            'Prediction': 'predicter.Prediction',
        }
        new_state_dict = OrderedDict()
        for name, param in state_dict.items():
            if name.startswith('Transformation'):
                continue
            for old_prefix, new_prefix in prefix_map.items():
                if name.startswith(old_prefix):
                    new_state_dict[name.replace(old_prefix, new_prefix)] = param
                    break
            else:
                new_state_dict[name] = param

        return new_state_dict

    def _get_model(self):
        """Build the recognizer, load the remapped weights, move to the device."""
        model = Model(self.args)
        state_dict = torch.load(self.args["weight_path"], map_location='cpu')
        model.load_state_dict(self._mapping_ckpt(state_dict))
        model = model.eval()

        if self.device == 'cuda':
            # Wrap in DataParallel only when several GPUs are available.
            if torch.cuda.device_count() > 1:
                model = nn.DataParallel(model).to(self.device)
            else:
                model.to(self.device)

        return model

    def _prepare_vocab(self):
        """Extend args['character'] from the vocab file and build the converter."""
        # `with` closes the file on every path; the old explicit close() was redundant.
        with open(self.args["vocab"], 'rt') as f:
            for line in f:
                self.args["character"] += [line.rstrip()]

        if 'Attn' in self.args['Prediction']['name']:
            self.converter = AttnLabelConverter(self.args["character"], self.device)
        else:
            self.converter = TFMLabelConverter(self.args["character"], self.device)

        self.args["num_class"] = len(self.converter.character)

    def _preprocess(self, image: "Image.Image"):
        """Resize/normalize a PIL image into the model's input tensor."""
        img_tensor = resize(self.resizer, image, self.args)
        return img_tensor

    def _postprocess(self, s: str):
        """Collapse spaces inside \\mathrm-style groups and h/vspace arguments."""
        text_reg = r'(\\(operatorname|mathrm|mathbf|mathsf|mathit|mathfrak|mathnormal)\s?\*? {.*?})'
        # Remove all spaces inside each matched text-mode group, in order.
        names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
        s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
        news = s

        for space in ["hspace", "vspace"]:
            # Strip spaces inside the length argument, e.g. "{2 em}" -> "{2em}".
            match = re.finditer(space + " {(.*?)}", news)
            if match:
                new_l = ""
                last = 0
                for m in match:
                    new_l = new_l + news[last:m.start(1)] + m.group(1).replace(" ", "")
                    last = m.end(1)
                new_l = new_l + news[last:]
                news = new_l

        return news

    def __call__(self, image: "Image.Image", name=None, *arg: Any, **kwargs):
        """Recognize the expression in *image*; returns the postprocessed string."""
        assert image.mode == 'RGB', 'input image must be RGB image'
        with torch.no_grad():
            img_tensor = self._preprocess(image).to(self.device)
            text_for_pred = torch.LongTensor(1, self.args["batch_max_length"] + 1).fill_(0).to(self.device)
            preds_index, _, _ = self.model(img_tensor, text_for_pred, is_train=False, is_test=True)
            pred_str = self.converter.decode(preds_index, self.args.get('token_level', 'word'))[0]

        # Truncate at the end-of-sequence marker only when present; previously
        # find() == -1 silently chopped the final character via pred_str[:-1].
        pred_EOS = pred_str.find('[s]')
        if pred_EOS != -1:
            pred_str = pred_str[:pred_EOS]

        process_str = self._postprocess(pred_str)

        return process_str
diff --git a/HybridViT/recognizers/__init__.py b/HybridViT/recognizers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/HybridViT/recognizers/build_feat.py b/HybridViT/recognizers/build_feat.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfc781151a476253ee28af01dc84a6af2961a325
--- /dev/null
+++ b/HybridViT/recognizers/build_feat.py
@@ -0,0 +1,45 @@
+import torch
+import torch.nn as nn
+from ..module.component.feature_extractor.clova_impl import VGG_FeatureExtractor, ResNet_FeatureExtractor
+
+
class FeatExtractorBuilder(nn.Module):
    """Visual feature-extraction stage: VGG/ResNet CNN, or identity for ViT flows."""

    def __init__(self, flow: dict, config):
        super().__init__()
        self.config = config
        self.flow = flow
        self.feat_name = flow['Feat']

        if self.feat_name != 'None':
            # NOTE: pop() mutates the shared config on purpose so 'mean_height'
            # is not forwarded to the extractor constructors below.
            mean_height = config['FeatureExtraction']['params'].pop('mean_height', True)

            if self.feat_name == 'VGG':
                self.FeatureExtraction = VGG_FeatureExtractor(**config['FeatureExtraction']['params'])
            elif self.feat_name == 'ResNet':
                self.FeatureExtraction = ResNet_FeatureExtractor(**config['FeatureExtraction']['params'])
            else:
                # Fail fast; the original deferred this to an AttributeError in forward().
                raise ValueError(f'Unsupported FeatureExtraction name: {self.feat_name}')
            self.FeatureExtraction_output = config['FeatureExtraction']['params']['output_channel']

            if mean_height:
                self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1
            else:
                # Project flattened (channel * height) features back to channel size.
                # assumes a pooled feature height of 3 -- TODO confirm against the extractor.
                self.proj_feat = nn.Linear(self.FeatureExtraction_output * 3, self.FeatureExtraction_output)
        elif flow['Seq'] in ['ViT', 'MS_ViT', 'MS_ViTV2', 'MS_ViTV3', 'ViG']:
            # ViT-style sequence modelers consume the raw image directly.
            self.FeatureExtraction = nn.Identity()
        else:
            raise Exception('No FeatureExtraction module specified')

    def forward(self, input):
        """Run the extractor; for BiLSTM flows, reduce to a (B, W, C) sequence."""
        visual_feature = self.FeatureExtraction(input)

        if self.flow['Seq'] in ['BiLSTM', 'BiLSTM_3L']:
            if hasattr(self, 'AdaptiveAvgPool'):
                visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2))  # [b, c, h, w] -> [b, w, c, 1]
                visual_feature = visual_feature.squeeze(3)
            else:
                visual_feature = visual_feature.permute(0, 3, 1, 2)
                visual_feature = visual_feature.flatten(start_dim=-2)  # [b, c, h, w] -> [b, w, c*h]
                visual_feature = self.proj_feat(visual_feature)

        return visual_feature
diff --git a/HybridViT/recognizers/build_model.py b/HybridViT/recognizers/build_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6e4be9560c20c1291279a16c309a1add19d77fe
--- /dev/null
+++ b/HybridViT/recognizers/build_model.py
@@ -0,0 +1,82 @@
+import torch.nn as nn
+from .build_feat import FeatExtractorBuilder
+from .build_seq import SeqModelingBuilder
+from .build_pred import PredictBuilder
+
class Model(nn.Module):
    """Three-stage recognizer: feature extraction -> sequence modeling -> prediction."""

    def __init__(self, opt):
        super(Model, self).__init__()
        self.opt = opt

        stages = {
            'Feat': opt['FeatureExtraction']['name'],
            'Seq': opt['SequenceModeling']['name'],
            'Pred': opt['Prediction']['name'],
        }
        self.stages = stages
        # ViT-style sequence modelers take the raw image, so no CNN stage is allowed.
        if 'Vi' in stages['Seq']:
            assert stages['Feat'] == 'None'

        """ FeatureExtraction """
        self.featextractor = FeatExtractorBuilder(stages, opt)
        feat_out = getattr(self.featextractor, 'FeatureExtraction_output', None)

        """ Sequence modeling"""
        self.seqmodeler = SeqModelingBuilder(stages, opt, feat_out)
        seq_out = getattr(self.seqmodeler, 'SequenceModeling_output', None)

        """ Prediction """
        self.predicter = PredictBuilder(stages, opt, seq_out)

    def forward_encoder(self, input, *args, **kwargs):
        """Run feature extraction then sequence modeling.

        Returns (contextual_feature, output_shape, feat_pad) as produced by the
        sequence-modeling builder.
        """
        visual_feature = self.featextractor(input)
        return self.seqmodeler(visual_feature, *args, **kwargs)

    def forward_decoder(self, contextual_feature, text, is_train=True,
                        is_test=False, rtl_text=None
                        ):
        """Run the prediction stage.

        Returns (prediction, logits, decoder_attn, addition_outputs).
        """
        return self.predicter(contextual_feature, text, is_train, is_test, rtl_text)

    def forward(self, input, text, is_train=True, is_test=False, rtl_text=None):
        """Full pipeline; attaches reshaped decoder attention to addition_outputs."""
        contextual_feature, output_shape, feat_pad = self.forward_encoder(input)
        prediction, logits, decoder_attn, addition_outputs = self.forward_decoder(
            contextual_feature, text=text, is_train=is_train, is_test=is_test, rtl_text=rtl_text
        )

        if decoder_attn is not None and output_shape is not None:
            attn = decoder_attn
            if self.stages['Pred'] == 'Attn' and self.stages['Seq'] == 'ViT':
                # Drop the CLS-token attention column before the grid reshape.
                attn = attn[:, 1:]
            attn = attn.reshape(-1, output_shape[0], output_shape[1])

            addition_outputs.update({
                'decoder_attn': attn,
                'feat_width': output_shape[0],
                'feat_height': output_shape[1],
                'feat_pad': feat_pad,
            })

        return prediction, logits, addition_outputs
+
if __name__ == '__main__':
    import yaml
    import torch

    # Manual smoke test against a local experiment config (developer-machine path).
    with open('/media/huynhtruc0309/DATA/Math_Expression/my_source/Math_Recognition/config/train/paper_experiments/report_paper/best_model_noaugment/experiment_1806.yaml', 'r') as f:
        config = yaml.safe_load(f)

    config['num_class'] = 499
    config['device'] = 'cpu'
    model = Model(config)

    a = torch.rand(1, 1, 32, 224)  # dummy grayscale batch: (B, C, H, W)

    # NOTE(review): Model.forward requires a `text` argument; this call passes
    # only the image and would raise TypeError — confirm intended usage.
    output = model(a)

    print('pred', output[0].shape)
diff --git a/HybridViT/recognizers/build_pred.py b/HybridViT/recognizers/build_pred.py
new file mode 100644
index 0000000000000000000000000000000000000000..8caa1be4de6b79477b871a6d00b13ed515c4d3fd
--- /dev/null
+++ b/HybridViT/recognizers/build_pred.py
@@ -0,0 +1,61 @@
+import torch
+import torch.nn as nn
+from ..module.component.prediction_head import (Attention,
+ AttentionV2,
+ TransformerPrediction
+ )
+
+
class PredictBuilder(nn.Module):
    """Builds and runs the prediction head (CTC linear, attention decoders, transformer)."""

    def __init__(self, flow, config, SequenceModeling_output):
        super().__init__()
        self.flow = flow
        self.config = config

        pred_name = flow['Pred']
        if pred_name == 'CTC':
            self.Prediction = nn.Linear(SequenceModeling_output, config['num_class'])
        elif pred_name in ('Attn', 'Attnv2', 'Multistage_Attn', 'TFM'):
            config['Prediction']['params']['num_classes'] = config['num_class']
            config['Prediction']['params']['device'] = config['device']
            if pred_name == 'Attn':
                self.Prediction = Attention(
                    **config['Prediction']['params']
                )
            elif pred_name == 'TFM':
                self.Prediction = TransformerPrediction(
                    **config['Prediction']['params']
                )
            else:
                # Attnv2 and Multistage_Attn share the AttentionV2 implementation.
                self.Prediction = AttentionV2(
                    **config['Prediction']['params']
                )
        else:
            raise ValueError('Prediction name is not suppported')

    def forward(self, contextual_feature, text, is_train=True, is_test=False, rtl_text=None):
        """Run the prediction head.

        Returns (prediction, logits, decoder_attn, addition_outputs);
        logits / decoder_attn are None for heads that do not produce them.
        """
        beam_size = self.config.get('beam_size', 1)

        addition_outputs = {}
        decoder_attn = None
        # CTC has no separate logits tensor; the original left `logits` unbound
        # on that path, raising UnboundLocalError at the return below.
        logits = None

        if self.flow['Pred'] == 'CTC':
            prediction = self.Prediction(contextual_feature.contiguous())

        elif self.flow['Pred'] in ['Attn', 'Attnv2', 'Multistage_Attn']:
            # Multistage_Attn uses the same AttentionV2 head; the original fell
            # through every branch for it, leaving `prediction` unbound.
            prediction, logits, decoder_attn = self.Prediction(beam_size, contextual_feature.contiguous(), text, is_train=is_train,
                                                               is_test=is_test, batch_max_length=self.config['batch_max_length'])

        elif self.flow['Pred'] == 'TFM':
            prediction, logits = self.Prediction(beam_size, contextual_feature.contiguous(), text, is_test)
            self.Prediction.reset_beam()

        else:
            raise ValueError(f"Unsupported Prediction flow: {self.flow['Pred']!r}")

        return prediction, logits, decoder_attn, addition_outputs
diff --git a/HybridViT/recognizers/build_seq.py b/HybridViT/recognizers/build_seq.py
new file mode 100644
index 0000000000000000000000000000000000000000..eddcc7c70c66d1bb7588ea41b2938dba465f4a8c
--- /dev/null
+++ b/HybridViT/recognizers/build_seq.py
@@ -0,0 +1,60 @@
+import torch.nn as nn
+from ..module.component.seq_modeling import BidirectionalLSTM, create_vit_modeling
+from ..module.component.seq_modeling.bilstm import BiLSTM_Seq_Modeling
+from ..module.component.common import GatedSum
+from ..module.component.common import PositionalEncoding2D, PositionalEncoding1D
+
+
class SeqModelingBuilder(nn.Module):
    """Builds the sequence-modeling stage (BiLSTM variants, ViT encoder, or none).

    Sets `SequenceModeling_output` (feature size fed to the prediction head)
    for the BiLSTM and no-op flows; the ViT branch does not set it, so callers
    must use getattr(..., None).
    """

    def __init__(self, flow: dict, config, FeatureExtraction_output):
        super().__init__()
        self.config = config
        self.flow = flow

        if flow['Seq'] == 'BiLSTM':
            hidden_size = config['SequenceModeling']['params']['hidden_size']
            use_pos_enc = config['SequenceModeling']['params'].get('pos_enc', False)

            if use_pos_enc:
                # Optional 1-D positional signal, fused with the LSTM output via a gate.
                self.image_positional_encoder = PositionalEncoding1D(hidden_size)
                self.gated = GatedSum(hidden_size)

            self.SequenceModeling = nn.Sequential(
                BidirectionalLSTM(FeatureExtraction_output, hidden_size, hidden_size),
                BidirectionalLSTM(hidden_size, hidden_size, hidden_size))

            self.SequenceModeling_output = hidden_size

        elif flow['Seq'] == 'BiLSTM_3L':
            hidden_size = config['SequenceModeling']['params']['hidden_size']
            # Three stacked bidirectional LSTM layers.
            self.SequenceModeling = BiLSTM_Seq_Modeling(3, FeatureExtraction_output, hidden_size, hidden_size)

            self.SequenceModeling_output = hidden_size

        elif flow['Seq'] == 'ViT':
            assert config['max_dimension'] is not None, "ViT encoder require exact height or max height and max width"
            self.SequenceModeling = create_vit_modeling(config)
        else:
            print('No SequenceModeling module specified')
            if flow['Pred'] == 'TFM':
                # NOTE(review): forward() below never applies this encoder and has
                # no branch for this flow — `contextual_feature` would be unbound
                # there. Presumably the encoder is used by another call path; confirm.
                self.image_positional_encoder = PositionalEncoding2D(FeatureExtraction_output)

            self.SequenceModeling_output = FeatureExtraction_output

    def forward(self, visual_feature, *args, **kwargs):
        """Returns (contextual_feature, output_shape, pad_info).

        output_shape is always None here; pad_info is set only by the ViT branch.
        """
        output_shape = None
        pad_info = None

        if self.flow['Seq'] in ['BiLSTM', 'BiLSTM_3L']:
            contextual_feature = self.SequenceModeling(visual_feature)

            if hasattr(self, 'image_positional_encoder'):
                assert len(contextual_feature.shape) == 3
                # Encoder appears to expect (seq, batch, dim); permute in and back out.
                contextual_feature_1 = self.image_positional_encoder(visual_feature.permute(1, 0, 2))
                contextual_feature_1 = contextual_feature_1.permute(1, 0, 2)
                contextual_feature = self.gated(contextual_feature_1, contextual_feature)

        elif self.flow['Seq'] == 'ViT':
            contextual_feature, pad_info, _ = self.SequenceModeling(visual_feature)

        return contextual_feature, output_shape, pad_info
\ No newline at end of file
diff --git a/HybridViT/resizer.py b/HybridViT/resizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ScanSSD/IOU_lib/BoundingBox.py b/ScanSSD/IOU_lib/BoundingBox.py
new file mode 100644
index 0000000000000000000000000000000000000000..12b82673302031367c28f1fcd25e6f8250369d24
--- /dev/null
+++ b/ScanSSD/IOU_lib/BoundingBox.py
@@ -0,0 +1,164 @@
+from .iou_utils import *
+
class BoundingBox:
    """A single 2-D bounding box with class id, optional confidence and image size."""

    def __init__(self,
                 imageName,
                 classId,
                 x,
                 y,
                 w,
                 h,
                 typeCoordinates=CoordinatesType.Absolute,
                 imgSize=None,
                 bbType=BBType.GroundTruth,
                 classConfidence=None,
                 format=BBFormat.XYWH):
        """Constructor.
        Args:
            imageName: String representing the image name.
            classId: String value representing class id.
            x: Float value representing the X upper-left coordinate of the bounding box.
            y: Float value representing the Y upper-left coordinate of the bounding box.
            w: Float value representing the width bounding box.
            h: Float value representing the height bounding box.
            typeCoordinates: (optional) Enum (Relative or Absolute) represents if the bounding box
                coordinates (x,y,w,h) are absolute or relative to size of the image. Default:'Absolute'.
            imgSize: (optional) 2D vector (width, height)=>(int, int) represents the size of the
                image of the bounding box. If typeCoordinates is 'Relative', imgSize is required.
            bbType: (optional) Enum (Groundtruth or Detection) identifies if the bounding box
                represents a ground truth or a detection. If it is a detection, the classConfidence
                has to be informed.
            classConfidence: (optional) Float value representing the confidence of the detected
                class. If detectionType is Detection, classConfidence needs to be informed.
            format: (optional) Enum (BBFormat.XYWH or BBFormat.XYX2Y2) indicating the format of the
                coordinates of the bounding boxes.
        """
        self._imageName = imageName
        self._typeCoordinates = typeCoordinates
        if typeCoordinates == CoordinatesType.Relative and imgSize is None:
            raise IOError(
                'Parameter \'imgSize\' is required. It is necessary to inform the image size.')
        if bbType == BBType.Detected and classConfidence is None:
            raise IOError(
                'For bbType=\'Detection\', it is necessary to inform the classConfidence value.')

        self._classConfidence = classConfidence
        self._bbType = bbType
        self._classId = classId
        self._format = format

        # If relative coordinates, convert to absolute values.
        # For relative coords: (x,y,w,h)=(X_center/img_width , Y_center/img_height)
        if (typeCoordinates == CoordinatesType.Relative):
            # convertToAbsoluteValues appears to return corner coordinates, which
            # are then converted back to width/height below -- confirm in iou_utils.
            (self._x, self._y, self._w, self._h) = convertToAbsoluteValues(imgSize, (x, y, w, h))
            self._width_img = imgSize[0]
            self._height_img = imgSize[1]
            if format == BBFormat.XYWH:
                self._x2 = self._w
                self._y2 = self._h
                self._w = self._x2 - self._x
                self._h = self._y2 - self._y
            else:
                raise IOError(
                    'For relative coordinates, the format must be XYWH (x,y,width,height)')
        # For absolute coords: (x,y,w,h)=real bb coords
        else:
            self._x = x
            self._y = y
            if format == BBFormat.XYWH:
                self._w = w
                self._h = h
                self._x2 = self._x + self._w
                self._y2 = self._y + self._h
            else:  # format == BBFormat.XYX2Y2: (x, y, w, h) actually carry (x1, y1, x2, y2)
                self._x2 = w
                self._y2 = h
                self._w = self._x2 - self._x
                self._h = self._y2 - self._y
        if imgSize is None:
            self._width_img = None
            self._height_img = None
        else:
            self._width_img = imgSize[0]
            self._height_img = imgSize[1]

    def __str__(self):
        return "{} {} {} {} {}".format(self.getImageName(), self._x, self._y, self._w, self._h)

    def getAbsoluteBoundingBox(self, format=BBFormat.XYWH):
        """Return the box in absolute pixel coordinates, in the requested format."""
        if format == BBFormat.XYWH:
            return (self._x, self._y, self._w, self._h)
        elif format == BBFormat.XYX2Y2:
            return (self._x, self._y, self._x2, self._y2)

    def getRelativeBoundingBox(self, imgSize=None):
        """Return the box in coordinates relative to the image size.

        Uses the explicitly supplied imgSize when given, otherwise the size
        stored at construction time. The original branches were inverted and
        dereferenced imgSize when it was None.
        """
        if imgSize is None and self._width_img is None and self._height_img is None:
            raise IOError(
                'Parameter \'imgSize\' is required. It is necessary to inform the image size.')
        if imgSize is not None:
            return convertToRelativeValues((imgSize[0], imgSize[1]),
                                           (self._x, self._y, self._w, self._h))
        return convertToRelativeValues((self._width_img, self._height_img),
                                       (self._x, self._y, self._w, self._h))

    def getImageName(self):
        return self._imageName

    def getConfidence(self):
        return self._classConfidence

    def getFormat(self):
        return self._format

    def getClassId(self):
        return self._classId

    def getImageSize(self):
        return (self._width_img, self._height_img)

    def getCoordinatesType(self):
        return self._typeCoordinates

    def getBBType(self):
        return self._bbType

    @staticmethod
    def compare(det1, det2):
        """True when two boxes match in class, confidence, coordinates and image size.

        Fixed: the original read the nonexistent attribute `det1.classConfidence`,
        called a misspelled `det2.classConfidenc()`, and compared each image size
        against itself instead of against the other box.
        """
        det1BB = det1.getAbsoluteBoundingBox()
        det1ImgSize = det1.getImageSize()
        det2BB = det2.getAbsoluteBoundingBox()
        det2ImgSize = det2.getImageSize()

        if det1.getClassId() == det2.getClassId() and \
           det1.getConfidence() == det2.getConfidence() and \
           det1BB == det2BB and \
           det1ImgSize == det2ImgSize:
            return True
        return False

    @staticmethod
    def clone(boundingBox):
        """Return a new BoundingBox with the same attributes as *boundingBox*."""
        absBB = boundingBox.getAbsoluteBoundingBox(format=BBFormat.XYWH)
        newBoundingBox = BoundingBox(
            boundingBox.getImageName(),
            boundingBox.getClassId(),
            absBB[0],
            absBB[1],
            absBB[2],
            absBB[3],
            typeCoordinates=boundingBox.getCoordinatesType(),
            imgSize=boundingBox.getImageSize(),
            bbType=boundingBox.getBBType(),
            classConfidence=boundingBox.getConfidence(),
            format=BBFormat.XYWH)
        return newBoundingBox
diff --git a/ScanSSD/IOU_lib/Evaluator.py b/ScanSSD/IOU_lib/Evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..d16aecde1e09a382b1fa29eaa0bb5e29bec2a2d8
--- /dev/null
+++ b/ScanSSD/IOU_lib/Evaluator.py
@@ -0,0 +1,87 @@
+###########################################################################################
+# #
+# Evaluator class: Implements the most popular metrics for object detection #
+# #
+# Developed by: Rafael Padilla (rafael.padilla@smt.ufrj.br) #
+# SMT - Signal Multimedia and Telecommunications Lab #
+# COPPE - Universidade Federal do Rio de Janeiro #
+# Last modification: Oct 9th 2018 #
+###########################################################################################
+
+import os
+import sys
+from collections import Counter
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from .BoundingBox import *
+from .iou_utils import *
+
+
class Evaluator:
    """IOU helpers for axis-aligned boxes in (x1, y1, x2, y2) inclusive pixel coordinates."""

    @staticmethod
    def _getAllIOUs(reference, detections):
        """Pair *reference* with every detection, sorted by descending IOU."""
        refBB = reference.getAbsoluteBoundingBox(BBFormat.XYX2Y2)
        pairs = [
            (Evaluator.iou(refBB, d.getAbsoluteBoundingBox(BBFormat.XYX2Y2)), reference, d)
            for d in detections
        ]
        pairs.sort(key=lambda item: item[0], reverse=True)
        return pairs

    @staticmethod
    def iou(boxA, boxB):
        """Intersection-over-union of two boxes; 0 when they do not overlap."""
        if not Evaluator._boxesIntersect(boxA, boxB):
            return 0
        inter = Evaluator._getIntersectionArea(boxA, boxB)
        result = inter / Evaluator._getUnionAreas(boxA, boxB, interArea=inter)
        assert result >= 0
        return result

    @staticmethod
    def _boxesIntersect(boxA, boxB):
        """True when boxA = (Ax1,Ay1,Ax2,Ay2) and boxB overlap."""
        return not (boxA[0] > boxB[2]       # boxA entirely right of boxB
                    or boxB[0] > boxA[2]    # boxA entirely left of boxB
                    or boxA[3] < boxB[1]    # boxA entirely above boxB
                    or boxA[1] > boxB[3])   # boxA entirely below boxB

    @staticmethod
    def _getIntersectionArea(boxA, boxB):
        """Area of the overlap region (+1: coordinates are inclusive pixels)."""
        left = max(boxA[0], boxB[0])
        top = max(boxA[1], boxB[1])
        right = min(boxA[2], boxB[2])
        bottom = min(boxA[3], boxB[3])
        return (right - left + 1) * (bottom - top + 1)

    @staticmethod
    def _getUnionAreas(boxA, boxB, interArea=None):
        """Union area of the two boxes; reuses *interArea* when already computed."""
        if interArea is None:
            interArea = Evaluator._getIntersectionArea(boxA, boxB)
        return float(Evaluator._getArea(boxA) + Evaluator._getArea(boxB) - interArea)

    @staticmethod
    def _getArea(box):
        """Area of a single (x1, y1, x2, y2) box with inclusive coordinates."""
        return (box[2] - box[0] + 1) * (box[3] - box[1] + 1)
diff --git a/ScanSSD/IOU_lib/IOUevaluater.py b/ScanSSD/IOU_lib/IOUevaluater.py
new file mode 100644
index 0000000000000000000000000000000000000000..16a1af36412ddd7c83397edec3e5e121fcefbecf
--- /dev/null
+++ b/ScanSSD/IOU_lib/IOUevaluater.py
@@ -0,0 +1,433 @@
+from zipfile import ZipFile
+import os
+from .Evaluator import *
+from utils import *
+import copy
+import argparse
+import sys
+import ntpath
+#import cStringIO
+from io import BytesIO
+import shutil
+
+
def read_file(filename, bboxes, flag):
    """
    Parse a .csv file of math regions into ``bboxes``, keyed by page number.

    Each non-empty line is ``page,x,y,x2,y2``; boxes on the same page are
    numbered 1..n and named ``<flag>_<n>``.

    :param filename: .csv file containing math regions
    :param bboxes: dict mapping page number -> list of BoundingBox objects,
        filled in place
    :param flag: prefix used when naming each bounding box
    :return: None
    """
    prev_page = -1
    counter = 1
    # Bug fix: use 'with' so the handle is closed even when a line fails to
    # parse (the original open()/close() leaked the handle on exceptions).
    with open(filename, "r") as fh1:
        for line in fh1:
            line = line.replace("\n", "")
            if line.replace(' ', '') == '':
                continue  # skip blank lines
            splitLine = line.split(",")
            idClass = float(splitLine[0])
            # Restart box numbering whenever a new page starts.
            if prev_page == -1:
                prev_page = idClass
            elif idClass != prev_page:
                counter = 1
                prev_page = idClass
            x = float(splitLine[1])
            y = float(splitLine[2])
            x2 = float(splitLine[3])
            y2 = float(splitLine[4])
            bb = BoundingBox(
                flag + "_" + str(counter),
                1,
                x,
                y,
                x2,
                y2,
                CoordinatesType.Absolute, (200, 200),
                BBType.GroundTruth,
                format=BBFormat.XYX2Y2)
            counter += 1
            if idClass not in bboxes:
                bboxes[idClass] = []
            bboxes[idClass].append(bb)
+
+
def extract_zipfile(zip_filename, target_dir):
    """
    Extract a zip archive into the target directory.

    :param zip_filename: full file path of the zip file
    :param target_dir: directory to extract the contents into
    :return: None
    """
    # Renamed the handle from 'zip' to 'zf' so the builtin is not shadowed.
    with ZipFile(zip_filename, 'r') as zf:
        print('Extracting all the files now...')
        zf.extractall(target_dir)
        print('Done!')
+
+
def create_doc_bboxes_map(dir_path, flag):
    """
    Read every .csv/.math file directly inside ``dir_path`` and return a map
    of per-page bounding boxes for each document.

    Hidden files, sub-directories and files with other extensions are
    skipped; a file that fails to parse is reported and contributes only the
    boxes read before the failure.

    :param dir_path: directory containing math region files
    :param flag: naming prefix forwarded to read_file
    :return: dict: document name -> (dict: page number -> list of boxes)
    """
    pdf_bboxes_map = {}

    for filename in os.listdir(dir_path):
        full_filepath = os.path.join(dir_path, filename)
        filename_key = os.path.splitext(os.path.basename(full_filepath))[0]
        # Bug fix: the hidden-file test must inspect the file NAME, not the
        # joined path — full_filepath only starts with '.' when dir_path is
        # itself a relative dot-path, so hidden files were never skipped.
        if filename.startswith(".") or not (full_filepath.endswith(".csv") or full_filepath.endswith(".math")):
            continue
        bboxes_map = {}

        if os.path.isdir(full_filepath):
            continue

        try:
            read_file(full_filepath, bboxes_map, flag)
        except Exception as e:
            # Best-effort: report and keep whatever parsed successfully.
            print('exception occurred in reading file', full_filepath, str(e))

        pdf_bboxes_map[filename_key] = copy.deepcopy(bboxes_map)

    return pdf_bboxes_map
+
def unique_values(input_dict):
    """
    Find ground-truth boxes whose best detection is claimed by another GT box.

    :param input_dict: dict: gt name -> (best_iou, best_det_label, ious, labels)
    :return: list of lists, each holding the gt keys that share one detection
        label (one list per duplicate occurrence, as in the original)
    """
    seen_labels = []
    repair_keys = []
    for entry in input_dict.values():
        label = entry[1]
        # A label seen before means at least two GT boxes claim the same
        # detection; collect every key mapping to that label.
        if label in seen_labels:
            repair_keys.append([key for key, val in input_dict.items() if val[1] == label])
        seen_labels.append(label)
    return repair_keys
+
def generate_validpairs(pairs):
    """
    Normalize key groups into pairwise entries, deduplicating along the way.

    Groups longer than two are broken into consecutive ``(k[i], k[i+1])``
    tuples; two-element groups are kept as-is.

    :param pairs: list of key groups (lists) as produced by unique_values
    :return: list of pairs with duplicates dropped
    """
    result = []
    for group in pairs:
        if len(group) > 2:
            # Walk consecutive members, keeping each new tuple only once.
            for left, right in zip(group, group[1:]):
                candidate = (left, right)
                if candidate not in result:
                    result.append(candidate)
        elif group not in result:
            result.append(group)
    return result
+
+def fix_preds(input_dict,keyPairs,thre):
+
+ validPairs=generate_validpairs(keyPairs)
+
+ for pair in validPairs:
+ #check if both pair exists"
+ if pair[0] not in list(input_dict.keys()) or pair[1] not in list(input_dict.keys()):
+ continue
+ val0=input_dict[pair[0]][0]
+ val1=input_dict[pair[1]][0]
+ if val0>=val1: #change prediction for second pair
+ values=input_dict[pair[1]]
+ newprob=values[2][1]
+ if newprobval0: #change prediction for first pair
+ values=input_dict[pair[0]]
+ newprob=values[2][1]
+ if newprob> for ground truth bboxes
+ :param det_page_bboxes_map: Map> for detection bboxes
+ :return:
+ '''
+ evaluator = Evaluator()
+
+ correct_pred_coarse=0
+ correct_pred_fine=0
+
+ pdf_gt_boxes=0
+ pdf_det_boxes=0
+
+ coarse_keys = {}
+ fine_keys = {}
+
+ for page_num in gt_page_bboxes_map:
+ if page_num not in det_page_bboxes_map:
+ print('Detections not found for page', str(page_num + 1), ' in', pdf_name)
+ continue
+ gt_boxes = gt_page_bboxes_map[page_num]
+ det_boxes = det_page_bboxes_map[page_num]
+
+ pdf_gt_boxes+=len(gt_boxes)
+ pdf_det_boxes+=len(det_boxes)
+
+ pred_dict={}
+ for gt_box in gt_boxes:
+ ious = evaluator._getAllIOUs(gt_box, det_boxes)
+ preds=[]
+ labels=[]
+ for i in range(len(ious)):
+ preds.append(round(ious[i][0],2))
+ labels.append(ious[i][2].getImageName())
+
+ pred_dict[gt_box.getImageName()]=preds[0],labels[0],preds,labels
+
+ coarse,coarse_dict=count_true_box(copy.deepcopy(pred_dict),0.5)
+ fine,fine_dict=count_true_box(copy.deepcopy(pred_dict),0.75)
+
+ coarse_keys[page_num] = coarse_dict.keys()
+ fine_keys[page_num] = fine_dict.keys()
+
+ #count correct preds for coarse 0.5 and fine 0.75 in one page
+ correct_pred_coarse= correct_pred_coarse+coarse
+ correct_pred_fine= correct_pred_fine+fine
+ #write iou per page
+ if outdir:
+ out_file = open(os.path.join(outdir,pdf_name.split(".csv")[0]+"_"+str(page_num)+"_eval.txt"), "w")
+ out_file.write('#page num '+str(page_num)+", gt_box:"+str(len(gt_boxes))+
+ ", pred_box:"+str(len(det_boxes))+"\n")
+ out_file.write('\n')
+ out_file.write('#COARSE DETECTION (iou>0.5):\n#number of correct prediction:'+ str(coarse)+ '\n#correctly detected:'+
+ str(list(coarse_dict.keys()))+'\n')
+ out_file.write('\n')
+ out_file.write('#FINE DETECTION (iou>0.75):\n#number of correct prediction:'+ str(fine)+ '\n#correctly detected:'+
+ str(list(fine_dict.keys()))+'\n')
+ out_file.write('\n')
+ out_file.write('#Sorted IOU scores for each GT box:\n')
+ for gt_box in gt_boxes:
+ ious = evaluator._getAllIOUs(gt_box, det_boxes)
+ out_file.write(gt_box.getImageName()+",")
+ for i in range(len(ious)-1):
+ out_file.write("("+str(round(ious[i][0],2))+" "+ str(ious[i][2].getImageName())+"),")
+ out_file.write( "("+str(round(ious[-1][0],2))+" "+ str(ious[-1][2].getImageName())+")\n" )
+ out_file.close()
+
+ return correct_pred_coarse, correct_pred_fine, pdf_gt_boxes, pdf_det_boxes, coarse_keys, fine_keys
+
def count_box(input_dict):
    """
    Total number of boxes across every page of every document.

    :param input_dict: dict: document -> (dict: page -> list of boxes)
    :return: int total box count
    """
    return sum(len(page_boxes)
               for pdf in input_dict.values()
               for page_boxes in pdf.values())
+
+# Zip every uploading files
def archive_iou_txt(username, task_id, sub_id, userpath):
    """
    Zip the per-page IOU evaluation .txt files for one submission.

    The archive is written to
    ``<userpath>/IOU_stats_archive/<task_id>_<sub_id>.zip``.

    :param username: submitting user (kept for interface compatibility; unused here)
    :param task_id: task identifier used in the archive name
    :param sub_id: submission identifier used in the archive name
    :param userpath: directory expected to contain 'iouEval_stats'
    :return: None
    """
    inputdir = os.path.join(userpath, 'iouEval_stats')

    # Bug fix: the original used 'pass' here, so execution fell through and
    # make_archive crashed on the missing directory; return early instead.
    if not os.path.exists(inputdir):
        print('No txt file is generated for IOU evaluation')
        return

    dest_uploader = os.path.join(userpath, 'IOU_stats_archive')
    if not os.path.exists(dest_uploader):
        os.makedirs(dest_uploader)

    # make_archive appends '.zip' to the base name itself.
    base_name = os.path.join(dest_uploader, task_id + '_' + sub_id)
    shutil.make_archive(base_name, 'zip', inputdir)
+
+def write_html(gtFile,resultsFile,info,scores,destFile):
+
+ destFile.write('')
+ destFile.write('')
+ destFile.write('')
+ #writeCSS(destFile)
+ destFile.write ("
CROHME 2019
Formula Detection Results ( TASK 3 )
")
+ destFile.write("Submitted Files
Output: "+ ntpath.basename(resultsFile) +"
")
+ destFile.write ("
Ground-truth: " + ntpath.basename(gtFile) + "
")
+ if info['allGTbox'] == 0:
+ sys.stderr.write("Error : no sample in this GT list !\n")
+ exit(-1)
+ #all detection and gt boxes
+ destFile.write ("
Number of ground truth bounding boxes: " + str(info['allGTbox']) + " Number of detected bounding boxes: " + str(info['allDet']))
+ destFile.write ("