duycse1603 committed on
Commit
6163604
·
1 Parent(s): 16db6af

[Add] source

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. HybridViT/beam.py +131 -0
  2. HybridViT/helper.py +182 -0
  3. HybridViT/module/component/common/__init__.py +6 -0
  4. HybridViT/module/component/common/conv.py +148 -0
  5. HybridViT/module/component/common/droppath.py +36 -0
  6. HybridViT/module/component/common/gated_sum.py +36 -0
  7. HybridViT/module/component/common/mae_posembed.py +72 -0
  8. HybridViT/module/component/common/maxout.py +22 -0
  9. HybridViT/module/component/common/postional_encoding.py +226 -0
  10. HybridViT/module/component/feature_extractor/__init__.py +0 -0
  11. HybridViT/module/component/feature_extractor/addon_module/__init__.py +2 -0
  12. HybridViT/module/component/feature_extractor/addon_module/aspp.py +59 -0
  13. HybridViT/module/component/feature_extractor/addon_module/visual_attention.py +325 -0
  14. HybridViT/module/component/feature_extractor/clova_impl/__init__.py +2 -0
  15. HybridViT/module/component/feature_extractor/clova_impl/resnet.py +262 -0
  16. HybridViT/module/component/feature_extractor/clova_impl/vgg.py +27 -0
  17. HybridViT/module/component/feature_extractor/helpers.py +76 -0
  18. HybridViT/module/component/feature_extractor/vgg.py +96 -0
  19. HybridViT/module/component/prediction_head/__init__.py +5 -0
  20. HybridViT/module/component/prediction_head/addon_module/__init__.py +3 -0
  21. HybridViT/module/component/prediction_head/addon_module/attention1D.py +218 -0
  22. HybridViT/module/component/prediction_head/addon_module/attention2D.py +88 -0
  23. HybridViT/module/component/prediction_head/addon_module/position_encoding.py +27 -0
  24. HybridViT/module/component/prediction_head/seq2seq.py +268 -0
  25. HybridViT/module/component/prediction_head/seq2seq_v2.py +218 -0
  26. HybridViT/module/component/prediction_head/tfm.py +207 -0
  27. HybridViT/module/component/seq_modeling/__init__.py +2 -0
  28. HybridViT/module/component/seq_modeling/addon_module/__init__.py +1 -0
  29. HybridViT/module/component/seq_modeling/addon_module/patchembed.py +161 -0
  30. HybridViT/module/component/seq_modeling/bilstm.py +33 -0
  31. HybridViT/module/component/seq_modeling/vit/utils.py +59 -0
  32. HybridViT/module/component/seq_modeling/vit/vision_transformer.py +184 -0
  33. HybridViT/module/component/seq_modeling/vit_encoder.py +276 -0
  34. HybridViT/module/converter/__init__.py +3 -0
  35. HybridViT/module/converter/attn_converter.py +71 -0
  36. HybridViT/module/converter/builder.py +6 -0
  37. HybridViT/module/converter/tfm_converter.py +90 -0
  38. HybridViT/recog_flow.py +113 -0
  39. HybridViT/recognizers/__init__.py +0 -0
  40. HybridViT/recognizers/build_feat.py +45 -0
  41. HybridViT/recognizers/build_model.py +82 -0
  42. HybridViT/recognizers/build_pred.py +61 -0
  43. HybridViT/recognizers/build_seq.py +60 -0
  44. HybridViT/resizer.py +0 -0
  45. ScanSSD/IOU_lib/BoundingBox.py +164 -0
  46. ScanSSD/IOU_lib/Evaluator.py +87 -0
  47. ScanSSD/IOU_lib/IOUevaluater.py +433 -0
  48. ScanSSD/IOU_lib/__init__.py +0 -0
  49. ScanSSD/IOU_lib/iou_utils.py +113 -0
  50. ScanSSD/README.md +11 -0
HybridViT/beam.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import List
3
+ from einops import rearrange, repeat
4
+ from typing import Optional
5
+
6
class Hypothesis:
    """A beam-search hypothesis: a decoded token sequence, its cumulative
    log-probability score and (optionally) one attention weight per token."""

    seq: List[int]
    score: float
    attn_weights: List[float]

    def __init__(
        self,
        seq_tensor: torch.LongTensor,
        score: float,
        weights: Optional[torch.FloatTensor] = None
    ) -> None:
        raw_seq = seq_tensor.tolist()

        self.seq = raw_seq
        self.score = score
        # BUG FIX: `if weights:` raises "Boolean value of Tensor with more
        # than one element is ambiguous" for multi-token weights, and is
        # wrongly falsy for a single zero weight. Test for presence instead.
        if weights is not None:
            self.attn_weights = weights.tolist()
            assert len(self.attn_weights) == len(self.seq)
        else:
            self.attn_weights = None

    def __len__(self):
        # An empty sequence still reports length 1 so that callers that
        # normalize scores by hypothesis length never divide by zero.
        if len(self.seq) != 0:
            return len(self.seq)
        else:
            return 1

    def __str__(self):
        return f"seq: {self.seq}, score: {self.score}, weight: {self.attn_weights}"
35
+
36
+
37
class Beam:
    """Beam-search bookkeeping for autoregressive decoding.

    Rows of ``self.hypotheses`` are partial token sequences (start token in
    column 0, padded with ``ignore_w``); ``self.hyp_scores`` holds their
    cumulative log-probabilities. Sequences that emit ``stop_w`` are moved to
    ``self.completed_hypotheses``.
    """

    def __init__(self,
                 start_w=1,
                 stop_w=2,
                 ignore_w=0,
                 max_len=150,
                 viz_attn=False,
                 device='cuda'
                 ):
        self.stop_w = stop_w
        self.start_w = start_w

        # One initial hypothesis, padded to max_len + 2 (room for start/stop).
        self.hypotheses = torch.full(
            (1, max_len + 2),
            fill_value=ignore_w,
            dtype=torch.long,
            device=device,
        )
        if viz_attn:
            # Per-token attention weights kept in lock-step with `hypotheses`.
            self.hyp_alpha = torch.ones(1, max_len + 2, dtype=torch.float, device=device)

        self.hypotheses[:, 0] = start_w
        self.hyp_scores = torch.zeros(1, dtype=torch.float, device=device)
        self.completed_hypotheses: List[Hypothesis] = []
        self.device = device
        self.viz_attn = viz_attn

    def advance(self, next_log_probs, step, beam_size):
        """Extend every live hypothesis by one token.

        Args:
            next_log_probs: (num_live_hyps, vocab_size) next-token log-probs.
            step: current 0-based decoding step; the chosen token is written
                into column ``step + 1`` of ``self.hypotheses``.
            beam_size: total beam width; only ``beam_size - #completed``
                candidates stay live.

        Returns:
            ``(new_hypotheses, new_hyp_scores)`` for candidates that did NOT
            emit the stop token this step.
        """
        vocab_size = next_log_probs.shape[1]
        live_hyp_num = beam_size - len(self.completed_hypotheses)
        # Broadcast each hypothesis score over the vocabulary, then flatten so
        # a single topk ranks all (hypothesis, next-token) pairs jointly.
        exp_hyp_scores = repeat(self.hyp_scores, "b -> b e", e=vocab_size)
        continuous_hyp_scores = rearrange(exp_hyp_scores + next_log_probs, "b e -> (b e)")
        top_cand_hyp_scores, top_cand_hyp_pos = torch.topk(
            continuous_hyp_scores, k=live_hyp_num
        )

        # Recover (source hypothesis row, token id) from the flattened index.
        prev_hyp_ids = top_cand_hyp_pos // vocab_size
        hyp_word_ids = top_cand_hyp_pos % vocab_size

        step += 1
        new_hypotheses = []
        new_hyp_scores = []

        for prev_hyp_id, hyp_word_id, cand_new_hyp_score in zip(
            prev_hyp_ids, hyp_word_ids, top_cand_hyp_scores
        ):
            cand_new_hyp_score = cand_new_hyp_score.detach().item()
            self.hypotheses[prev_hyp_id, step] = hyp_word_id

            if hyp_word_id == self.stop_w:
                # Finished: archive the sequence without the leading start token.
                self.completed_hypotheses.append(
                    Hypothesis(
                        seq_tensor=self.hypotheses[prev_hyp_id, 1:step+1]
                        .detach()
                        .clone(),  # remove START_W at first
                        score=cand_new_hyp_score,
                    )
                )
            else:
                new_hypotheses.append(self.hypotheses[prev_hyp_id].detach().clone())
                new_hyp_scores.append(cand_new_hyp_score)

        return new_hypotheses, new_hyp_scores

    def get_incomplete_inds(self, hyp_word_ids):
        """Indices of candidates whose chosen token is not the stop token."""
        return [ind for ind, next_word in enumerate(hyp_word_ids) if
                next_word != self.stop_w]

    def get_complete_inds(self, hyp_word_ids, incomplete_inds):
        """Complement of `incomplete_inds` over all candidate indices."""
        return list(set(range(len(hyp_word_ids))) - set(incomplete_inds))

    def set_current_state(self, hypotheses):
        "Set the outputs for the current timestep."
        self.hypotheses = torch.stack(hypotheses, dim=0)
        return

    def set_current_score(self, hyp_scores):
        "Set the scores for the current timestep."
        self.hyp_scores = torch.tensor(
            hyp_scores, dtype=torch.float, device=self.device
        )
        return

    def done(self, beam_size):
        """True once `beam_size` hypotheses have emitted the stop token."""
        return len(self.completed_hypotheses) == beam_size

    def set_hypothesis(self):
        """Fallback: if nothing finished, archive the best live hypothesis."""
        if len(self.completed_hypotheses) == 0:
            self.completed_hypotheses.append(
                Hypothesis(
                    seq_tensor=self.hypotheses[0, 1:].detach().clone(),
                    score=self.hyp_scores[0].detach().item(),
                )
            )
        return
HybridViT/helper.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import random
3
+ import numpy as np
4
+ from PIL import Image
5
+ from typing import Dict
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ from PIL import Image
9
+ import cv2
10
+ import math
11
+ import albumentations as alb
12
+ from albumentations.pytorch.transforms import ToTensorV2
13
+ from collections import OrderedDict
14
+ from itertools import repeat
15
+ import collections.abc
16
+
17
+
18
+ # From PyTorch internals
19
+ def _ntuple(n):
20
+ def parse(x):
21
+ if isinstance(x, collections.abc.Iterable):
22
+ return x
23
+ return tuple(repeat(x, n))
24
+ return parse
25
+
26
+ to_3tuple = _ntuple(3)
27
+
28
def clean_state_dict(state_dict):
    """Strip the ``module.`` prefix (left by DataParallel/DDP training) from
    every key, preserving insertion order."""
    prefix = 'module.'
    cleaned_state_dict = OrderedDict()
    for key, value in state_dict.items():
        cleaned_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned_state_dict[cleaned_key] = value
    return cleaned_state_dict
35
+
36
+
37
def math_transform(mean, std, is_gray: bool):
    """Build the albumentations test-time pipeline: optional grayscale
    conversion, then CLAHE, per-channel normalization, and tensor conversion."""
    steps = []
    if is_gray:
        steps.append(alb.ToGray(always_apply=True))
    steps.extend([
        alb.CLAHE(clip_limit=2, tile_grid_size=(2, 2), always_apply=True),
        alb.Normalize(to_3tuple(mean), to_3tuple(std)),
        ToTensorV2(),
    ])
    return alb.Compose(steps)
50
+
51
+
52
def pad(img: Image.Image, divable=32):
    """Pad an Image to the next full divisible value of `divable`. Also normalizes the PIL.image and invert if needed.

    Args:
        img (PIL.Image): input PIL.image
        divable (int, optional): . Defaults to 32.

    Returns:
        PIL.Image
    """
    # Work on a luminance+alpha copy; channel 0 is gray, channel -1 is alpha.
    data = np.array(img.convert('LA'))

    # Stretch intensities to the full 0..255 range.
    data = (data-data.min())/(data.max()-data.min())*255
    if data[..., 0].mean() > 128:
        # Bright background: threshold so text pixels become white.
        gray = 255*(data[..., 0] < 128).astype(np.uint8)  # To invert the text to white
    else:
        # Dark background: threshold the other way, then invert the gray channel.
        gray = 255*(data[..., 0] > 128).astype(np.uint8)
        data[..., 0] = 255-data[..., 0]

    coords = cv2.findNonZero(gray)  # Find all non-zero points (text)
    a, b, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box
    rect = data[b:b+h, a:a+w]

    # Constant alpha carries no information: use the gray channel; otherwise
    # reconstruct the content from the inverted alpha channel.
    if rect[..., -1].var() == 0:
        im = Image.fromarray((rect[..., 0]).astype(np.uint8)).convert('L')
    else:
        im = Image.fromarray((255-rect[..., -1]).astype(np.uint8)).convert('L')
    dims = []

    # Round each side (width, height) up to the next multiple of `divable`.
    for x in [w, h]:
        div, mod = divmod(x, divable)
        dims.append(divable*(div + (1 if mod > 0 else 0)))

    # White canvas of the padded size.
    # NOTE(review): the paste offset is im.getbbox() (the content's own
    # bounding box), not (0, 0) — confirm this offset is intended.
    padded = Image.new('L', dims, 255)
    padded.paste(im, im.getbbox())

    return padded
89
+
90
def get_divisible_size(ori_h, ori_w, max_dimension=None, scale_factor=32):
    """Round (ori_h, ori_w) up to the nearest multiple of `scale_factor`.

    If rounding up would exceed `max_dimension` ((max_h, max_w)), that side is
    rounded down instead. Dimensions already divisible are left unchanged.

    BUG FIX: the original indexed `max_dimension` unconditionally, so the
    documented default of None crashed with a TypeError; the cap is now only
    applied when a maximum is supplied.

    Returns:
        (new_h, new_w) as ints.
    """
    new_h, new_w = ori_h, ori_w
    if ori_h % scale_factor:
        new_h = math.ceil(ori_h/scale_factor)*scale_factor
        if max_dimension is not None and new_h > max_dimension[0]:
            new_h = math.floor(ori_h/scale_factor)*scale_factor

    if ori_w % scale_factor:
        new_w = math.ceil(ori_w/scale_factor)*scale_factor
        if max_dimension is not None and new_w > max_dimension[1]:
            new_w = math.floor(ori_w/scale_factor)*scale_factor

    return int(new_h), int(new_w)
103
+
104
def minmax_size(img, max_dimensions=None, min_dimensions=None, is_gray=True):
    """Clamp a PIL image between min_dimensions and max_dimensions (both (H, W)).

    Oversized images are downscaled (aspect preserved, snapped to a divisible
    size); undersized ones are placed on a larger white canvas.
    """
    if max_dimensions is not None:
        # ratios: image (H, W) over the allowed maxima.
        ratios = [a/b for a, b in zip(list(img.size)[::-1], max_dimensions)]
        if any([r > 1 for r in ratios]):
            # Shrink by the largest overshoot so both dims fit.
            size = np.array(img.size)/max(ratios)
            new_h, new_w = get_divisible_size(size[1], size[0], max_dimensions)
            img = img.resize((new_w, new_h), Image.LANCZOS)

    if min_dimensions is not None:
        ratios = [a/b for a, b in zip(list(img.size)[::-1], min_dimensions)]
        if any([r < 1 for r in ratios]):
            # Grow the canvas (not the content) up to the minimum size.
            new_h, new_w = img.size[1] / min(ratios), img.size[0] / min(ratios)
            new_h, new_w = get_divisible_size(new_h, new_w, max_dimensions)
            if is_gray:
                MODE = 'L'
                BACKGROUND = 255
            # NOTE(review): MODE/BACKGROUND are only bound when is_gray is
            # True — reaching this branch with is_gray=False raises NameError;
            # confirm callers always pass grayscale images here.
            padded_im = Image.new(MODE, (new_w, new_h), BACKGROUND)
            padded_im.paste(img, img.getbbox())
            img = padded_im

    return img
125
+
126
def resize(resizer, img: Image.Image, opt: Dict):
    """Preprocess an input image into a normalized (1, C, H, W) float tensor.

    Args:
        resizer: optional model that predicts a better target width from the
            current tensor (iteratively refined below); falsy to skip.
        img: input PIL image.
        opt: options dict; must contain 'imgH' and 'imgW', plus
            'max_dimension', 'min_dimension', 'mean', 'std', 'rgb', 'pad'
            and 'device' when imgH is None.

    NOTE(review): when opt['imgH'] is not None, `new_img` is never assigned
    and the final assert raises NameError — confirm callers always use
    imgH=None with this flow.
    """
    # for math recognition the image is always in grayscale mode
    img = img.convert('L')
    assert isinstance(opt, Dict)
    assert "imgH" in opt
    assert "imgW" in opt
    expected_H = opt['imgH']

    if expected_H is None:
        max_dimensions = opt['max_dimension']  # can be bigger than max dim in training set
        min_dimensions = opt['min_dimension']
        # equal to min dim in training set
        test_transform = math_transform(opt['mean'], opt['std'], not opt['rgb'])
        try:
            new_img = minmax_size(pad(img) if opt['pad'] else img, max_dimensions, min_dimensions, not opt['rgb'])

            if not resizer:
                # Single-pass path: clamp, transform, and add a batch dim.
                new_img = np.asarray(new_img.convert('RGB')).astype('uint8')
                new_img = test_transform(image=new_img)['image']
                if not opt['rgb']: new_img = new_img[:1]  # keep one channel for grayscale
                new_img = new_img.unsqueeze(0)
                new_img = new_img.float()
            else:
                # Iterative path: let the resizer model propose a width until
                # it agrees with the current image (at most 20 rounds).
                with torch.no_grad():
                    input_image = new_img.convert('RGB').copy()
                    r, w, h = 1, input_image.size[0], input_image.size[1]
                    for i in range(20):
                        h = int(h * r)  # rescale height by the last ratio
                        new_img = pad(minmax_size(input_image.resize((w, h), Image.BILINEAR if r > 1 else Image.LANCZOS),
                                                  max_dimensions,
                                                  min_dimensions,
                                                  not opt['rgb']
                                                  ))
                        t = test_transform(image=np.array(new_img.convert('RGB')).astype('uint8'))['image']
                        if not opt['rgb']: t = t[:1]
                        t = t.unsqueeze(0)
                        t = t.float()
                        # Predicted width = (argmax bucket + 1) * minimum width.
                        w = (resizer(t.to(opt['device'])).argmax(-1).item()+1)*opt['min_dimension'][1]

                        if (w == new_img.size[0]):
                            break

                        r = w/new_img.size[0]

                    new_img = t
        except ValueError as e:
            # Fallback: transform the raw image and pad up to max dimensions.
            print('Error:', e)
            new_img = np.asarray(img.convert('RGB')).astype('uint8')
            assert len(new_img.shape) == 3 and new_img.shape[2] == 3
            new_img = test_transform(image=new_img)['image']
            if not opt['rgb']: new_img = new_img[:1]
            new_img = new_img.unsqueeze(0)
            h, w = new_img.shape[2:]
            new_img = F.pad(new_img, (0, max_dimensions[1]-w, 0, max_dimensions[0]-h), value=1)

    assert len(new_img.shape) == 4, f'{new_img.shape}'
    return new_img
HybridViT/module/component/common/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .conv import *
2
+ from .droppath import *
3
+ from .gated_sum import *
4
+ from .maxout import *
5
+ from .postional_encoding import *
6
+ from .mae_posembed import *
HybridViT/module/component/common/conv.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import Optional
5
+ import warnings
6
+
7
+ __all__ = ['ConvMLP', 'ConvModule']
8
+
9
class LayerNorm2d(nn.LayerNorm):
    """LayerNorm for channels of '2D' spatial BCHW tensors (normalizes over C
    at every spatial location)."""

    def __init__(self, num_channels):
        super().__init__(num_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        channels_last = x.permute(0, 2, 3, 1)
        normed = F.layer_norm(channels_last, self.normalized_shape, self.weight, self.bias, self.eps)
        return normed.permute(0, 3, 1, 2)
17
+
18
class DepthwiseSeparableConv2d(nn.Module):
    """Depthwise (one filter per input channel) conv followed by a 1x1
    pointwise conv that mixes channels."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super(DepthwiseSeparableConv2d, self).__init__()
        # Spatial filtering only: groups == in_channels keeps channels separate.
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            padding=padding,
            stride=stride,
            bias=bias,
            groups=in_channels,
        )
        # Channel mixing: 1x1 conv (optionally grouped).
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, groups=groups, bias=bias)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))
37
+
38
class ConvMLP(nn.Module):
    """MLP built from 1x1 convolutions for BCHW feature maps:
    fc1 -> LayerNorm2d -> ReLU -> Dropout -> fc2.

    Args:
        in_channels: input channel count.
        out_channels: output channels (defaults to in_channels).
        hidden_channels: hidden channels (defaults to in_channels).
        drop: dropout probability between norm/act and fc2.
    """

    def __init__(self, in_channels, out_channels=None, hidden_channels=None, drop=0.25):
        super().__init__()
        # BUG FIX: the operands of `or` were swapped
        # (`out_channels = in_channels or out_channels`), so any caller-supplied
        # out_channels/hidden_channels was silently ignored whenever
        # in_channels was non-zero. Default to in_channels only when omitted.
        out_channels = out_channels or in_channels
        hidden_channels = hidden_channels or in_channels
        self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1, bias=True)
        self.norm = LayerNorm2d(hidden_channels)
        self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1, bias=True)
        self.act = nn.ReLU()
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """x: (B, C_in, H, W) -> (B, C_out, H, W)."""
        x = self.fc1(x)
        x = self.norm(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        return x
56
+
57
class ConvModule(nn.Module):
    """Conv/norm/activation bundle applied in the order given by `order`.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups: forwarded to `conv_layer`.
        bias: True/False, or 'auto' to add a conv bias only when there is no
            norm layer (a bias directly before a norm is redundant).
        conv_layer: conv class to instantiate (default nn.Conv2d).
        norm_layer: norm class, or None to disable normalization.
        act_layer: activation class, or None to disable activation.
        inplace: passed to activations that support it.
        with_spectral_norm: wrap the conv in spectral normalization.
        padding_mode: 'zeros'/'circular' are handled by the conv itself;
            'zero', 'reflect' and 'replicate' use an explicit padding layer.
        order: permutation of ('conv', 'norm', 'act').
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias='auto',
                 conv_layer: Optional[nn.Module] = nn.Conv2d,
                 norm_layer: Optional[nn.Module] = nn.BatchNorm2d,
                 act_layer: Optional[nn.Module] = nn.ReLU,
                 inplace=True,
                 with_spectral_norm=False,
                 padding_mode='zeros',
                 order=('conv', 'norm', 'act')
                 ):
        # BUG FIX: nn.Module.__init__ was never called; assigning submodules
        # below would raise "cannot assign module before Module.__init__() call".
        super().__init__()
        official_padding_mode = ['zeros', 'circular']
        nonofficial_padding_mode = dict(zero=nn.ZeroPad2d, reflect=nn.ReflectionPad2d, replicate=nn.ReplicationPad2d)
        self.with_spectral_norm = with_spectral_norm
        self.with_explicit_padding = padding_mode not in official_padding_mode
        self.order = order
        assert isinstance(self.order, tuple) and len(self.order) == 3
        assert set(order) == set(['conv', 'norm', 'act'])

        self.with_norm = norm_layer is not None
        self.with_act = act_layer is not None

        if bias == 'auto':
            bias = not self.with_norm
        self.with_bias = bias

        if self.with_norm and self.with_bias:
            warnings.warn('ConvModule has norm and bias at the same time')

        if self.with_explicit_padding:
            assert padding_mode in list(nonofficial_padding_mode), "Not implemented padding algorithm"
            # BUG FIX: instantiate the padding layer with the requested amount;
            # the original stored the bare class, so forward() would have
            # called the class constructor on the feature tensor.
            self.padding_layer = nonofficial_padding_mode[padding_mode](padding)

        # reset padding to 0 for the conv itself when an explicit pad layer is used
        conv_padding = 0 if self.with_explicit_padding else padding

        self.conv = conv_layer(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=conv_padding,
            dilation=dilation,
            groups=groups,
            bias=bias
        )
        # Mirror the conv's hyper-parameters for external introspection.
        self.in_channels = self.conv.in_channels
        self.out_channels = self.conv.out_channels
        self.kernel_size = self.conv.kernel_size
        self.stride = self.conv.stride
        self.padding = padding
        self.dilation = self.conv.dilation
        self.transposed = self.conv.transposed
        self.output_padding = self.conv.output_padding
        self.groups = self.conv.groups

        if self.with_spectral_norm:
            self.conv = nn.utils.spectral_norm(self.conv)

        # build normalization layers
        if self.with_norm:
            # norm layer is after conv layer
            if order.index('norm') > order.index('conv'):
                norm_channels = out_channels
            else:
                norm_channels = in_channels
            self.norm = norm_layer(norm_channels)

        if self.with_act:
            # BUG FIX: the branches were swapped — Tanh/PReLU/Sigmoid do NOT
            # accept an `inplace` argument (the original passed it to them and
            # would raise TypeError), while ReLU-style activations do.
            if act_layer in [nn.Tanh, nn.PReLU, nn.Sigmoid]:
                self.act = act_layer()
            else:
                self.act = act_layer(inplace=inplace)

    def forward(self, x, activate=True, norm=True):
        """Apply conv/norm/act in `self.order`; norm and act can be skipped per call."""
        for layer in self.order:
            if layer == 'conv':
                if self.with_explicit_padding:
                    x = self.padding_layer(x)
                x = self.conv(x)
            elif layer == 'norm' and norm and self.with_norm:
                x = self.norm(x)
            elif layer == 'act' and activate and self.with_act:
                x = self.act(x)
        return x
HybridViT/module/component/common/droppath.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+
3
+ __all__ = ['DropPath']
4
+
5
+
6
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Stochastic Depth: randomly zero entire samples of `x` (when applied in
    the main path of residual blocks).

    Same as the DropConnect impl for EfficientNet-style networks, but named
    'drop path' — 'DropConnect' is a different form of dropout (see
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956).
    Surviving samples are rescaled by 1/keep_prob when `scale_by_keep`.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if scale_by_keep and keep_prob > 0.0:
        mask = mask.div_(keep_prob)
    return x * mask
22
+
23
+
24
class DropPath(nn.Module):
    """Module wrapper for :func:`drop_path` — drops whole residual paths per
    sample (Stochastic Depth), active only in training mode."""

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        # self.training toggles the stochastic behavior on/off.
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f'drop_prob={round(self.drop_prob,3):0.3f}'
HybridViT/module/component/common/gated_sum.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
class GatedSum(torch.nn.Module):
    """Gated sum of two equally-shaped tensors `a` and `b`:

    ```
    f = activation(W [a; b])
    out = f * a + (1 - f) * b
    ```

    # Parameters
    input_dim : `int`, required
        Trailing dimension of both inputs; shape `(..., input_dim)`.
    activation : `Activation`, optional (default = `torch.nn.Sigmoid()`)
        Non-linearity applied to the gate logit.
    """

    def __init__(self, input_dim: int, activation = torch.nn.Sigmoid()) -> None:
        super().__init__()
        self.input_dim = input_dim
        self._gate = torch.nn.Linear(input_dim * 2, 1)
        self._activation = activation

    def get_input_dim(self):
        return self.input_dim

    def get_output_dim(self):
        return self.input_dim

    def forward(self, input_a: torch.Tensor, input_b: torch.Tensor) -> torch.Tensor:
        if input_a.size() != input_b.size():
            raise ValueError("The input must have the same size.")
        if input_a.size(-1) != self.input_dim:
            raise ValueError("Input size must match `input_dim`.")
        concatenated = torch.cat([input_a, input_b], -1)
        gate_value = self._activation(self._gate(concatenated))
        return gate_value * input_a + (1 - gate_value) * input_b
HybridViT/module/component/common/mae_posembed.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # Position embedding utils
8
+ # --------------------------------------------------------
9
+
10
+ import numpy as np
11
+ # --------------------------------------------------------
12
+ # 2D sine-cosine position embedding
13
+ # References:
14
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
15
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
16
+ # --------------------------------------------------------
17
+
18
def get_2d_sincos_pos_embed(embed_dim, grid_size_H, grid_size_W, cls_token=False):
    """Build a 2D sine-cosine positional embedding.

    Args:
        embed_dim: embedding dimension (must be even).
        grid_size_H / grid_size_W: grid height / width.
        cls_token: if True, prepend one all-zero row for the [CLS] token.

    Returns:
        pos_embed: [grid_size_H*grid_size_W, embed_dim] or
        [1 + grid_size_H*grid_size_W, embed_dim] (w/ or w/o cls_token).
    """
    grid_h = np.arange(grid_size_H, dtype=np.float32)
    grid_w = np.arange(grid_size_W, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size_H, grid_size_W])

    # FIX: removed a stray debug print of the grid shape left in the hot path.
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
37
+
38
+
39
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Concatenate 1D sin-cos embeddings of the H and W coordinate grids
    into a single (H*W, embed_dim) table."""
    assert embed_dim % 2 == 0
    half = embed_dim // 2

    # Half of the channels encode the h coordinate, half the w coordinate.
    emb_h = get_1d_sincos_pos_embed_from_grid(half, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(half, grid[1])  # (H*W, D/2)

    return np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
48
+
49
+
50
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """1D sine-cosine embedding for a flat array of positions.

    embed_dim: output dimension per position (must be even)
    pos: array of positions, any shape (flattened to (M,))
    out: (M, embed_dim) array laid out as [sin | cos]
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    # Geometric frequency ladder: 1 / 10000^(2i/D) for i in [0, D/2).
    freqs = np.arange(half_dim, dtype=np.float32) / (embed_dim / 2.)
    freqs = 1. / 10000**freqs  # (D/2,)

    angles = np.einsum('m,d->md', pos.reshape(-1), freqs)  # (M, D/2), outer product

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
69
+
70
if __name__ == '__main__':
    # Manual smoke test: 800x800 grid, 256-dim, with a cls-token row prepended
    # -> expected shape (1 + 800*800, 256).
    pos_embed = get_2d_sincos_pos_embed(256, 800, 800, True)
    print(pos_embed.shape)
HybridViT/module/component/common/maxout.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+
3
class Maxout(nn.Module):
    """Maxout over the last dimension: split it into consecutive pools of
    `pool_size` elements and keep only each pool's maximum."""

    def __init__(self, pool_size):
        """
        Args:
            pool_size (int): Number of elements per pool
        """
        super().__init__()
        self.pool_size = pool_size

    def forward(self, x):
        *lead, last = x.size()
        pooled = x.view(*lead, last // self.pool_size, self.pool_size)
        return pooled.max(-1)[0]
+ return out
22
+
HybridViT/module/component/common/postional_encoding.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from typing import Tuple
7
+ from torch import Tensor
8
+ from torch.nn import functional as F
9
+
10
+
11
class Adaptive2DPositionalEncoding(nn.Module):
    """Implement Adaptive 2D positional encoder for SATRN, see
    `SATRN <https://arxiv.org/abs/1910.04396>`_
    Modified from https://github.com/Media-Smart/vedastr
    Licensed under the Apache License, Version 2.0 (the "License");
    Args:
        d_hid (int): Dimensions of hidden layer.
        n_height (int): Max height of the 2D feature output.
        n_width (int): Max width of the 2D feature output.
        dropout (int): Size of hidden layers of the model.
    """

    def __init__(self,
                 d_hid=512,
                 n_height=100,
                 n_width=100,
                 dropout=0.1,
                 ):
        super().__init__()

        # Separate 1D sinusoid tables for the height and width axes, reshaped
        # so they broadcast over (B, C, H, W) feature maps.
        h_position_encoder = self._get_sinusoid_encoding_table(n_height, d_hid)
        h_position_encoder = h_position_encoder.transpose(0, 1)
        h_position_encoder = h_position_encoder.view(1, d_hid, n_height, 1)

        w_position_encoder = self._get_sinusoid_encoding_table(n_width, d_hid)
        w_position_encoder = w_position_encoder.transpose(0, 1)
        w_position_encoder = w_position_encoder.view(1, d_hid, 1, n_width)

        # Buffers: move with the module's device/dtype but are not trained.
        self.register_buffer('h_position_encoder', h_position_encoder)
        self.register_buffer('w_position_encoder', w_position_encoder)

        # Content-conditioned gates that rescale each axis' encoding per channel.
        self.h_scale = self.scale_factor_generate(d_hid)
        self.w_scale = self.scale_factor_generate(d_hid)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=dropout)

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        """Sinusoid position encoding table of shape (n_position, d_hid)."""
        denominator = torch.Tensor([
            1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid)
            for hid_j in range(d_hid)
        ])
        denominator = denominator.view(1, -1)
        pos_tensor = torch.arange(n_position).unsqueeze(-1).float()
        sinusoid_table = pos_tensor * denominator
        # Even channels get sine, odd channels cosine.
        sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2])
        sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2])

        return sinusoid_table

    def scale_factor_generate(self, d_hid):
        """Two 1x1 convs ending in a sigmoid: per-channel scale in (0, 1)."""
        scale_factor = nn.Sequential(
            nn.Conv2d(d_hid, d_hid, kernel_size=1), nn.ReLU(inplace=True),
            nn.Conv2d(d_hid, d_hid, kernel_size=1), nn.Sigmoid())

        return scale_factor

    def init_weight(self):
        # NOTE(review): not called from __init__; callers must invoke it
        # explicitly if Xavier initialization of the gate convs is wanted.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('ReLU'))

    def forward(self, x):
        """x: (B, d_hid, H, W) -> same shape, with gated positional encodings added."""
        b, c, h, w = x.size()

        # Global average pool conditions the scale gates on image content.
        avg_pool = self.pool(x)

        h_pos_encoding = \
            self.h_scale(avg_pool) * self.h_position_encoder[:, :, :h, :]
        w_pos_encoding = \
            self.w_scale(avg_pool) * self.w_position_encoder[:, :, :, :w]

        out = x + h_pos_encoding + w_pos_encoding

        out = self.dropout(out)

        return out
88
+
89
class PositionalEncoding2D(nn.Module):
    """2-D positional encodings for the feature maps produced by the encoder.
    Following https://arxiv.org/abs/2103.06450 by Sumeet Singh.
    Reference:
    https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2021-labs/blob/main/lab9/text_recognizer/models/transformer_util.py
    """

    def __init__(self, d_model: int, max_h: int = 2000, max_w: int = 2000) -> None:
        super().__init__()
        self.d_model = d_model
        assert d_model % 2 == 0, f"Embedding depth {d_model} is not even"
        pe = self.make_pe(d_model, max_h, max_w)  # (d_model, max_h, max_w)
        self.register_buffer("pe", pe)

    @staticmethod
    def make_pe(d_model: int, max_h: int, max_w: int) -> Tensor:
        """Compute positional encoding: the first d_model//2 channels encode
        the row index, the second half the column index."""
        pe_h = PositionalEncoding1D.make_pe(d_model=d_model // 2, max_len=max_h)  # (max_h, 1, d_model // 2)
        pe_h = pe_h.permute(2, 0, 1).expand(-1, -1, max_w)  # (d_model // 2, max_h, max_w)

        pe_w = PositionalEncoding1D.make_pe(d_model=d_model // 2, max_len=max_w)  # (max_w, 1, d_model // 2)
        pe_w = pe_w.permute(2, 1, 0).expand(-1, max_h, -1)  # (d_model // 2, max_h, max_w)

        pe = torch.cat([pe_h, pe_w], dim=0)  # (d_model, max_h, max_w)
        return pe

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass.
        Args:
            x: (B, d_model, H, W)
        Returns:
            (B, d_model, H, W)
        """
        assert x.shape[1] == self.pe.shape[0]  # type: ignore
        # Crop the precomputed table to the input's spatial size and add it.
        x = x + self.pe[:, : x.size(2), : x.size(3)]  # type: ignore
        return x
125
+
126
+
127
class PositionalEncoding1D(nn.Module):
    """Classic Attention-is-all-you-need sinusoidal positional encoding."""

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1000) -> None:
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Precomputed table, registered so it follows the module's device.
        self.register_buffer("pe", self.make_pe(d_model, max_len))  # (max_len, 1, d_model)

    @staticmethod
    def make_pe(d_model: int, max_len: int) -> Tensor:
        """Build the (max_len, 1, d_model) sin/cos table."""
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        inv_freq = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * inv_freq)
        table[:, 1::2] = torch.cos(positions * inv_freq)
        return table.unsqueeze(1)

    def forward(self, x: Tensor) -> Tensor:
        """Add positional encodings.
        Args:
            x: (S, B, d_model)
        Returns:
            (S, B, d_model)
        """
        assert x.shape[2] == self.pe.shape[2]  # type: ignore
        return self.dropout(x + self.pe[: x.size(0)])  # type: ignore
157
+
158
Size_ = Tuple[int, int]

class PosConv(nn.Module):
    """PEG conditional positional encoding (https://arxiv.org/abs/2102.10882):
    a depthwise 3x3 conv over the reshaped token grid, added back residually
    when stride is 1. The cls token bypasses the convolution."""

    def __init__(self, in_chans, embed_dim=768, stride=1):
        super().__init__()
        self.proj = nn.Sequential(nn.Conv2d(in_chans, embed_dim, 3, stride, 1, bias=True, groups=embed_dim), )
        self.stride = stride

    def forward(self, x, size: Size_):
        batch, _, channels = x.shape
        cls_token, feat_token = x[:, 0], x[:, 1:]
        # Tokens -> spatial grid (B, C, H, W) for the convolution.
        grid = feat_token.transpose(1, 2).view(batch, channels, *size)
        out = self.proj(grid)
        if self.stride == 1:
            out += grid
        # Back to token layout and re-attach the cls token in front.
        out = out.flatten(2).transpose(1, 2)
        return torch.cat((cls_token.unsqueeze(1), out), dim=1)

    def no_weight_decay(self):
        return ['proj.%d.weight' % i for i in range(4)]
180
+
181
class PosConv1D(nn.Module):
    """1-D PEG conditional positional encoding (https://arxiv.org/abs/2102.10882):
    depthwise width-3 conv over the token sequence, added back residually
    when stride is 1. The cls token bypasses the convolution."""

    def __init__(self, in_chans, embed_dim=768, stride=1):
        super().__init__()
        self.proj = nn.Sequential(nn.Conv1d(in_chans, embed_dim, 3, stride, 1, bias=True, groups=embed_dim), )
        self.stride = stride

    def forward(self, x, size: int):
        batch, _, channels = x.shape
        cls_token, feat_token = x[:, 0], x[:, 1:]
        # Tokens -> (B, C, L) layout for the 1-D convolution.
        seq = feat_token.transpose(1, 2).view(batch, channels, size)
        out = self.proj(seq)
        if self.stride == 1:
            out += seq
        out = out.transpose(1, 2)
        return torch.cat((cls_token.unsqueeze(1), out), dim=1)

    def no_weight_decay(self):
        return ['proj.%d.weight' % i for i in range(4)]
201
+
202
def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=(), old_grid_shape=()):
    # Rescale the grid of position embeddings when loading from state_dict. Adapted from
    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
    #
    # Args:
    #   posemb: checkpoint embedding, (1, num_tokens + old_H*old_W, dim).
    #   posemb_new: target embedding; only its token count is used.
    #   num_tokens: leading special tokens (e.g. cls) copied through untouched.
    #   gs_new: target (H, W) grid; assumed square from the token count when empty.
    #   old_grid_shape: (H, W) grid of `posemb`'s spatial part.

    print('Resized position embedding: %s to %s'%(posemb.shape, posemb_new.shape))
    ntok_new = posemb_new.shape[1]

    if num_tokens:
        # Split off the special tokens; only the spatial grid is interpolated.
        posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
        ntok_new -= num_tokens
    else:
        posemb_tok, posemb_grid = posemb[:, :0], posemb[0]

    if not len(gs_new):  # backwards compatibility
        gs_new = [int(math.sqrt(ntok_new))] * 2

    assert len(gs_new) >= 2

    print('Position embedding grid-size from %s to %s'%(old_grid_shape, gs_new))
    # Reshape to (1, dim, H, W), bicubically resize, then flatten back to tokens.
    posemb_grid = posemb_grid.reshape(1, old_grid_shape[0], old_grid_shape[1], -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)

    return posemb
HybridViT/module/component/feature_extractor/__init__.py ADDED
File without changes
HybridViT/module/component/feature_extractor/addon_module/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .aspp import *
2
+ from .visual_attention import *
HybridViT/module/component/feature_extractor/addon_module/aspp.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ __all__ = ['ASPP']
6
+
7
class ASPPModule(nn.Module):
    """A single atrous-convolution branch of ASPP: dilated conv + ReLU."""

    def __init__(self, inplanes, planes, kernel_size, padding, dilation):
        super(ASPPModule, self).__init__()
        self.atrous_conv = nn.Conv2d(
            inplanes,
            planes,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            dilation=dilation,
            bias=False,
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):  # skipcq: PYL-W0221
        return self.relu(self.atrous_conv(x))
19
+
20
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head.

    Runs four parallel atrous branches plus a global-average branch, then
    fuses them with a 1x1 conv, ReLU and dropout. Dilation rates depend on
    the backbone's output stride.
    """

    def __init__(self, inplanes: int, output_stride: int, output_features: int, dropout=0.5):
        super(ASPP, self).__init__()

        rate_table = {32: [1, 3, 6, 9], 16: [1, 6, 12, 18], 8: [1, 12, 24, 36]}
        if output_stride not in rate_table:
            raise NotImplementedError
        dilations = rate_table[output_stride]

        self.aspp1 = ASPPModule(inplanes, output_features, 1, padding=0, dilation=dilations[0])
        self.aspp2 = ASPPModule(inplanes, output_features, 3, padding=dilations[1], dilation=dilations[1])
        self.aspp3 = ASPPModule(inplanes, output_features, 3, padding=dilations[2], dilation=dilations[2])
        self.aspp4 = ASPPModule(inplanes, output_features, 3, padding=dilations[3], dilation=dilations[3])

        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(inplanes, output_features, 1, stride=1, bias=False),
            nn.ReLU(inplace=True),
        )
        self.conv1 = nn.Conv2d(output_features * 5, output_features, 1, bias=False)
        self.relu1 = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):  # skipcq: PYL-W0221
        a1 = self.aspp1(x)
        a2 = self.aspp2(x)
        a3 = self.aspp3(x)
        a4 = self.aspp4(x)
        # Upsample the pooled branch back to the spatial size of the others.
        pooled = self.global_avg_pool(x)
        pooled = F.interpolate(pooled, size=a4.size()[2:], mode="bilinear", align_corners=False)
        fused = torch.cat((a1, a2, a3, a4, pooled), dim=1)
        fused = self.relu1(self.conv1(fused))
        return self.dropout(fused)
HybridViT/module/component/feature_extractor/addon_module/visual_attention.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+ from torch.nn import functional as F
4
+ from functools import reduce
5
+
6
+ __all__ = ['Adaptive_Global_Model', 'GlobalContext', 'SELayer', 'SKBlock', 'CBAM']
7
+
8
+
9
def constant_init(module, val, bias=0):
    """Fill ``module.weight`` with ``val`` and ``module.bias`` with ``bias``."""
    weight = getattr(module, 'weight', None)
    if weight is not None:
        nn.init.constant_(weight, val)
    bias_param = getattr(module, 'bias', None)
    if bias_param is not None:
        nn.init.constant_(bias_param, bias)
14
+
15
+
16
def xavier_init(module, gain=1, bias=0, distribution='normal'):
    """Xavier-initialize ``module.weight`` (normal or uniform) and set bias."""
    assert distribution in ['uniform', 'normal']
    init_fn = nn.init.xavier_uniform_ if distribution == 'uniform' else nn.init.xavier_normal_
    init_fn(module.weight, gain=gain)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
24
+
25
+
26
def normal_init(module, mean=0, std=1, bias=0):
    """Initialize weight ~ N(mean, std) and set bias to a constant."""
    nn.init.normal_(module.weight, mean, std)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
30
+
31
+
32
def uniform_init(module, a=0, b=1, bias=0):
    """Initialize weight ~ U(a, b) and set bias to a constant."""
    nn.init.uniform_(module.weight, a, b)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
36
+
37
+
38
def kaiming_init(module,
                 a=0,
                 mode='fan_out',
                 nonlinearity='relu',
                 bias=0,
                 distribution='normal'):
    """Kaiming-initialize ``module.weight`` and set ``module.bias`` constant."""
    assert distribution in ['uniform', 'normal']
    kaiming_fn = (nn.init.kaiming_uniform_ if distribution == 'uniform'
                  else nn.init.kaiming_normal_)
    kaiming_fn(module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
53
+
54
+
55
def last_zero_init(m):
    """Zero the weight and bias of the last layer of ``m``.

    For an ``nn.Sequential`` the final submodule is zeroed; otherwise ``m``
    itself is.
    """
    target = m[-1] if isinstance(m, nn.Sequential) else m
    if getattr(target, 'weight', None) is not None:
        nn.init.constant_(target.weight, 0)
    if getattr(target, 'bias', None) is not None:
        nn.init.constant_(target.bias, 0)
60
+
61
+
62
class SELayer(nn.Module):
    """Squeeze-and-Excitation with a residual path: x + x * sigmoid(gate)."""

    def __init__(self, channel, reduction=16, dropout=0.1):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid()
        )

    def forward(self, x):
        batch, channels = x.shape[0], x.shape[1]
        squeezed = self.avg_pool(x).view(batch, channels)
        gate = self.fc(squeezed).view(batch, channels, 1, 1)
        # SE-Residual: keep the identity path and add the gated features.
        return x + x * gate.expand_as(x)
79
+
80
+
81
class BasicConv(nn.Module):
    """Conv2d optionally followed by BatchNorm2d and ReLU."""

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True,
                 bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
                              dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        out = self.conv(x)
        # Apply the optional post-layers in order, skipping disabled ones.
        for layer in (self.bn, self.relu):
            if layer is not None:
                out = layer(out)
        return out
98
+
99
+
100
class Flatten(nn.Module):
    """Collapse every non-batch dimension into one: (B, ...) -> (B, -1)."""

    def forward(self, x):
        batch = x.size(0)
        return x.view(batch, -1)
103
+
104
+
105
class ChannelGate(nn.Module):
    """CBAM channel attention: gate channels by a pooled-MLP sigmoid score."""

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )
        self.pool_types = pool_types
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.maxpool = nn.AdaptiveMaxPool2d(1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Sum the MLP responses of every requested pooling branch.
        att_total = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                raw = self.mlp(self.avgpool(x))
            elif pool_type == 'max':
                raw = self.mlp(self.maxpool(x))
            att_total = raw if att_total is None else att_total + raw
        scale = self.sigmoid(att_total).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale
138
+
139
+
140
class ChannelPool(nn.Module):
    """Stack channel-wise max and mean maps into a 2-channel tensor."""

    def forward(self, x):
        max_map = torch.max(x, 1)[0].unsqueeze(1)
        mean_map = torch.mean(x, 1).unsqueeze(1)
        return torch.cat((max_map, mean_map), dim=1)
144
+
145
+
146
class SpatialGate(nn.Module):
    """CBAM spatial attention: gate positions via a 7x7 conv over pooled maps."""

    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1,
                                 padding=(kernel_size - 1) // 2, relu=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        attn = self.spatial(self.compress(x))
        # The single-channel sigmoid map broadcasts across all channels.
        return x * self.sigmoid(attn)
160
+
161
+
162
class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel gate, then optional spatial gate."""

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        out = self.ChannelGate(x)
        return out if self.no_spatial else self.SpatialGate(out)
175
+
176
+
177
class SKBlock(nn.Module):
    """Selective-Kernel block: M dilated branches fused by learned soft attention.

    Channel counts must be divisible by 32 (grouped convs).
    """

    def __init__(self, in_channels, out_channels, stride=1, M=2, r=16, L=32):
        super(SKBlock, self).__init__()
        d = max(in_channels // r, L)  # attention bottleneck width, floored at L
        self.M = M
        self.out_channels = out_channels
        self.conv = nn.ModuleList()
        for branch in range(M):
            self.conv.append(nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, stride,
                          padding=1 + branch, dilation=1 + branch, groups=32, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)))
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Sequential(nn.Conv2d(out_channels, d, 1, bias=False),
                                 nn.BatchNorm2d(d),
                                 nn.ReLU(inplace=True))
        self.fc2 = nn.Conv2d(d, out_channels * M, 1, 1, bias=False)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, input):
        batch_size = input.size(0)
        # Split: run every dilated branch.
        branch_outs = [conv(input) for conv in self.conv]
        # Fuse: sum branches and squeeze to a channel descriptor.
        fused = branch_outs[0]
        for extra in branch_outs[1:]:
            fused = fused + extra
        descriptor = self.fc1(self.global_pool(fused))
        scores = self.fc2(descriptor).reshape(batch_size, self.M, self.out_channels, -1)
        scores = self.softmax(scores)
        # Select: weight each branch by its attention score and sum.
        weights = [w.reshape(batch_size, self.out_channels, 1, 1)
                   for w in scores.chunk(self.M, dim=1)]
        result = branch_outs[0] * weights[0]
        for out, w in zip(branch_outs[1:], weights[1:]):
            result = result + out * w
        return result
216
+
217
def make_divisible(v, divisor=8, min_value=None, round_limit=.9):
    """Round ``v`` to the nearest multiple of ``divisor``, at least ``min_value``.

    If rounding down removed more than ``1 - round_limit`` of the original
    value, bump up by one divisor step.
    """
    floor = min_value or divisor
    rounded = max(floor, int(v + divisor / 2) // divisor * divisor)
    if rounded < round_limit * v:
        rounded += divisor
    return rounded
224
+
225
class LayerNorm2d(nn.LayerNorm):
    """LayerNorm over the channel dim of '2D' spatial BCHW tensors."""

    def __init__(self, num_channels):
        super().__init__(num_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Move channels last, normalize them, then restore BCHW layout.
        channels_last = x.permute(0, 2, 3, 1)
        normed = F.layer_norm(
            channels_last, self.normalized_shape, self.weight, self.bias, self.eps)
        return normed.permute(0, 3, 1, 2)
233
+
234
class ConvMLP(nn.Module):
    """1x1-conv MLP bottleneck: fc1 -> LayerNorm2d -> ReLU -> Dropout -> fc2.

    Args:
        in_channels: channels of the BCHW input.
        out_channels: channels of the output; defaults to ``in_channels``.
        hidden_channels: bottleneck width; defaults to ``in_channels``.
        drop: dropout probability applied before ``fc2``.
    """

    def __init__(self, in_channels, out_channels=None, hidden_channels=None, drop=0.25):
        super().__init__()
        # BUG FIX: the original read `in_channels or out_channels`, which always
        # evaluated to `in_channels` (a truthy int) and silently ignored the
        # caller's out_channels / hidden_channels arguments.
        out_channels = out_channels or in_channels
        hidden_channels = hidden_channels or in_channels
        self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1, bias=True)
        self.norm = LayerNorm2d(hidden_channels)
        self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1, bias=True)
        self.act = nn.ReLU()
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.norm(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        return x
252
+
253
class GlobalContext(nn.Module):
    """Global Context (GCNet-style) attention block.

    A per-image global descriptor is computed either by learned softmax
    attention over spatial positions (``use_attn=True``) or by average
    pooling, then passed through 1x1-conv bottlenecks and fused back into
    ``x`` additively (``fuse_add``) and/or multiplicatively (``fuse_scale``).

    Args:
        channel: input/output channel count.
        use_attn: use a learned 1x1-conv attention map instead of avg-pool.
        fuse_add: enable the additive fusion branch.
        fuse_scale: enable the sigmoid-scaling fusion branch.
        rd_ratio: bottleneck reduction ratio when ``rd_channels`` is None.
        rd_channels: explicit bottleneck width (overrides ``rd_ratio``).
    """

    def __init__(self,
                 channel,
                 use_attn=True,
                 fuse_add=True,
                 fuse_scale=False,
                 rd_ratio=1./8,
                 rd_channels=None
                 ):
        super().__init__()
        self.use_attn = use_attn
        self.global_cxt = nn.Conv2d(channel, 1, kernel_size=1, bias=True) if use_attn else nn.AdaptiveAvgPool2d(1)

        if rd_channels is None:
            rd_channels = make_divisible(channel*rd_ratio, divisor=1, round_limit=0.)

        if fuse_add:
            self.bottleneck_add = ConvMLP(channel, hidden_channels=rd_channels)
        else:
            self.bottleneck_add = None
        if fuse_scale:
            self.bottleneck_mul = ConvMLP(channel, hidden_channels=rd_channels)
        else:
            self.bottleneck_mul = None

        self.init_weight()

    def init_weight(self):
        if self.use_attn:
            nn.init.kaiming_normal_(self.global_cxt.weight, mode='fan_in', nonlinearity='relu')
        # Zero-init the last projections so the block starts near identity.
        if self.bottleneck_add is not None:
            nn.init.zeros_(self.bottleneck_add.fc2.weight)
        if self.bottleneck_mul is not None:
            nn.init.zeros_(self.bottleneck_mul.fc2.weight)

    def forward(self, x):
        B, C, H, W = x.shape
        if self.use_attn:
            # Softmax attention over the H*W positions, then a weighted sum.
            attn = self.global_cxt(x).reshape(B, 1, H*W).squeeze(1)
            attn = F.softmax(attn, dim=-1).unsqueeze(-1)  # shape BxH*Wx1
            query = x.reshape(B, C, H*W)  # shape BxCxH*W
            glob_cxt = torch.bmm(query, attn).unsqueeze(-1)
        else:
            glob_cxt = self.global_cxt(x)
        assert len(glob_cxt.shape) == 4

        # BUG FIX: x_fuse was unbound (NameError) when both fuse branches were
        # disabled; fall back to the identity in that case. As in the original,
        # the scale branch gates the raw input x and takes precedence over the
        # additive result when both branches are enabled.
        x_fuse = x
        if self.bottleneck_add is not None:
            x_trans = self.bottleneck_add(glob_cxt)
            x_fuse = x + x_trans
        if self.bottleneck_mul is not None:
            # torch.sigmoid replaces the deprecated F.sigmoid alias.
            x_trans = torch.sigmoid(self.bottleneck_mul(glob_cxt))
            x_fuse = x*x_trans

        return x_fuse
307
+
308
+
309
class Adaptive_Global_Model(nn.Module):
    """Global-context enrichment + column flattening: BCHW -> (B, W, inplanes).

    Assumes the feature-map height H equals ``factor``, since the per-column
    features of size C*H = inplanes*factor are projected back down to
    ``inplanes`` by a linear embedding.
    """

    def __init__(self, inplanes, factor=2, ratio=0.0625, dropout=0.1):
        super(Adaptive_Global_Model, self).__init__()
        # b, w, h, c => gc_block (b, w, h, c) => => b, w, inplanes
        self.embedding = nn.Linear(inplanes * factor, inplanes)
        # BUG FIX: GlobalContext has no `ratio` keyword (its reduction-ratio
        # parameter is named `rd_ratio`), so the old call raised a TypeError.
        self.gc_block = GlobalContext(inplanes, rd_ratio=ratio)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.gc_block(x)  # BCHW => BCHW
        x = x.permute(0, 3, 1, 2)  # BCHW => BWCH
        b, w, _, _ = x.shape
        x = x.contiguous().view(b, w, -1)  # fold (C, H) into one feature axis
        x = self.embedding(x)  # B W C
        x = self.dropout(x)
        return x
325
+
HybridViT/module/component/feature_extractor/clova_impl/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .resnet import ResNet_FeatureExtractor
2
+ from .vgg import VGG_FeatureExtractor
HybridViT/module/component/feature_extractor/clova_impl/resnet.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+ from collections import OrderedDict
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from ..addon_module.visual_attention import GlobalContext
7
+ from .....helper import clean_state_dict
8
+
9
class BasicBlock(nn.Module):
    """Standard two-conv residual block (ResNet basic block, expansion 1)."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = self._conv3x3(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = self._conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def _conv3x3(self, in_planes, out_planes, stride=1):
        """3x3 convolution with padding."""
        return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                         padding=1, bias=False)

    def zero_init_last_bn(self):
        # Zeroing the last BN lets the block start as the identity mapping.
        nn.init.zeros_(self.bn2.weight)

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += shortcut
        return self.relu(out)
47
+
48
class ResNet(nn.Module):
    """ResNet-style backbone for text recognition (clova/FAN layout).

    Height is reduced much faster than width (two square pools, then
    (2, 1)-strided pooling and convs), yielding a wide, short feature map
    suited to sequence decoding. When ``with_gcb`` is True, a GlobalContext
    block is appended to the end of every residual stage.
    """

    def __init__(self, input_channel, output_channel, block, layers, with_gcb=True, debug=False, zero_init_last_bn=False):
        super(ResNet, self).__init__()
        self.with_gcb = with_gcb

        # Stage widths [C/4, C/2, C, C] for output_channel C; the stem ends at C/8.
        self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel]
        self.inplanes = int(output_channel / 8)

        # Stem: two 3x3 convs, input -> C/16 -> C/8.
        self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16),
                                 kernel_size=3, stride=1, padding=1, bias=False)
        self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16))

        self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes,
                                 kernel_size=3, stride=1, padding=1, bias=False)
        self.bn0_2 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)

        # Stage 1: 2x2 pool, residual layer, 3x3 conv refinement.
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0])
        self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[
            0], kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.output_channel_block[0])

        # Stage 2: same pattern at double the width.
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1)
        self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[
            1], kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.output_channel_block[1])

        # Stage 3: pool halves height only (stride (2, 1)), preserving width.
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1))
        self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1)
        self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[
            2], kernel_size=3, stride=1, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.output_channel_block[2])

        # Stage 4 plus two 2x2 convs that squeeze the remaining height.
        self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1)

        self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[
            3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False)
        self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3])

        self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[
            3], kernel_size=2, stride=1, padding=0, bias=False)
        self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3])

        self.init_weights(zero_init_last_bn=zero_init_last_bn)
        self.debug = debug

    def zero_init_last_bn(self):
        # Zero the final BN so the last conv starts near identity.
        nn.init.zeros_(self.bn4_2.weight)

    def init_weights(self, zero_init_last_bn=True):
        """Kaiming-init convs and reset BNs, skipping GlobalContext submodules
        (those initialize themselves in their own constructors)."""
        initialized = ['global_cxt', 'bottleneck_add', 'bottleneck_mul']
        for n, m in self.named_modules():
            if any([d in n for d in initialized]):
                continue
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
        if zero_init_last_bn:
            for m in self.modules():
                if hasattr(m, 'zero_init_last_bn'):
                    m.zero_init_last_bn()

    def _make_layer(self, block, planes, blocks, with_gcb=False, stride=1):
        # NOTE(review): the `with_gcb` parameter here is unused — the flag that
        # is actually read is `self.with_gcb` set in __init__; confirm intent.
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion

        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        if self.with_gcb:
            # Append a GlobalContext attention block at the end of the stage.
            layers.append(GlobalContext(planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        # With self.debug set, print the tensor shape after each stage.
        if self.debug:
            print('input shape', x.shape)

        x = self.conv0_1(x)
        x = self.bn0_1(x)
        x = self.relu(x)

        if self.debug:
            print('conv1 shape', x.shape)

        x = self.conv0_2(x)
        x = self.bn0_2(x)
        x = self.relu(x)

        if self.debug:
            print('conv2 shape', x.shape)

        x = self.maxpool1(x)

        if self.debug:
            print('pool1 shape', x.shape)

        x = self.layer1(x)

        if self.debug:
            print('block1 shape', x.shape)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        if self.debug:
            print('conv3 shape', x.shape)

        x = self.maxpool2(x)

        if self.debug:
            print('pool2 shape', x.shape)

        x = self.layer2(x)

        if self.debug:
            print('block2 shape', x.shape)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        if self.debug:
            print('conv4 shape', x.shape)

        x = self.maxpool3(x)

        if self.debug:
            print('pool3 shape', x.shape)

        x = self.layer3(x)

        if self.debug:
            print('block3 shape', x.shape)

        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)

        if self.debug:
            print('conv5 shape', x.shape)

        x = self.layer4(x)

        if self.debug:
            print('block4 shape', x.shape)

        x = self.conv4_1(x)
        x = self.bn4_1(x)
        x = self.relu(x)

        if self.debug:
            print('conv6 shape', x.shape)

        x = self.conv4_2(x)
        x = self.bn4_2(x)
        x = self.relu(x)

        if self.debug:
            print('conv7 shape', x.shape)

        return x
224
+
225
class ResNet_FeatureExtractor(nn.Module):
    """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """

    def __init__(self, input_channel=3, output_channel=512, gcb=False, pretrained=False, weight_dir=None, debug=False):
        # Stage depths [1, 2, 5, 3] follow the FAN/clova ResNet variant.
        super(ResNet_FeatureExtractor, self).__init__()
        self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3], gcb, debug)
        self.in_chans = input_channel
        if pretrained:
            assert weight_dir is not None
            self.load_pretrained(weight_dir)

    def forward(self, input):
        output = self.ConvNet(input)
        return output

    def load_pretrained(self, weight_dir):
        """Load a clova-format checkpoint, stripping the 'FeatureExtraction.' prefix.

        The first conv weight is tiled across ``self.in_chans`` input channels
        so a checkpoint trained with a different channel count can still
        initialize this model.
        """
        # NOTE(review): repeat(1, self.in_chans, 1, 1) only yields the right
        # shape if the checkpoint's conv0_1 was trained with 1 input channel —
        # confirm before using this with RGB-trained weights.
        state_dict: OrderedDict = torch.load(weight_dir)
        cleaned_state_dict = clean_state_dict(state_dict)
        new_state_dict = OrderedDict()
        name: str
        param: torch.FloatTensor
        for name, param in cleaned_state_dict.items():
            if name.startswith('FeatureExtraction'):
                output_name = name.replace('FeatureExtraction.', '')
                if output_name == 'ConvNet.conv0_1.weight':
                    print('Old', param.shape)
                    new_param = param.repeat(1, self.in_chans, 1, 1)
                    print('New', new_param.shape)
                else: new_param = param
                new_state_dict[output_name] = new_param
        print("=> Loading pretrained weight for ResNet backbone")
        self.load_state_dict(new_state_dict)
257
+
258
if __name__ == '__main__':
    # Smoke test: push a random grayscale 128x480 image through the backbone
    # with per-stage shape printing enabled.
    model = ResNet_FeatureExtractor(input_channel=1, debug=True)
    a = torch.rand(1, 1, 128, 480)
    output = model(a)
    print(output.shape)
HybridViT/module/component/feature_extractor/clova_impl/vgg.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+
4
class VGG_FeatureExtractor(nn.Module):
    """ FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf) """

    def __init__(self, input_channel, output_channel=512):
        super(VGG_FeatureExtractor, self).__init__()
        # Per-stage widths, e.g. [64, 128, 256, 512] for output_channel=512.
        self.output_channel = [int(output_channel / 8), int(output_channel / 4),
                               int(output_channel / 2), output_channel]
        c0, c1, c2, c3 = self.output_channel
        self.ConvNet = nn.Sequential(
            nn.Conv2d(input_channel, c0, 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d(2, 2),                                   # c0 x H/2 x W/2
            nn.Conv2d(c0, c1, 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d(2, 2),                                   # c1 x H/4 x W/4
            nn.Conv2d(c1, c2, 3, 1, 1), nn.ReLU(True),
            nn.Conv2d(c2, c2, 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d((2, 1), (2, 1)),                         # halve height only
            nn.Conv2d(c2, c3, 3, 1, 1, bias=False),
            nn.BatchNorm2d(c3), nn.ReLU(True),
            nn.Conv2d(c3, c3, 3, 1, 1, bias=False),
            nn.BatchNorm2d(c3), nn.ReLU(True),
            nn.MaxPool2d((2, 1), (2, 1)),                         # halve height only
            nn.Conv2d(c3, c3, 2, 1, 0), nn.ReLU(True))            # final 2x2 'valid' conv

    def forward(self, input):
        return self.ConvNet(input)
HybridViT/module/component/feature_extractor/helpers.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List, Tuple
3
+
4
+ import torch.nn.functional as F
5
+
6
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
    """Symmetric padding that keeps 'same'-ish output size for this conv config."""
    return ((stride - 1) + dilation * (kernel_size - 1)) // 2
10
+
11
+
12
def get_same_padding(x: int, k: int, s: int, d: int):
    """Total TF-style 'SAME' padding needed along one dimension of size x."""
    needed = (math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x
    return max(needed, 0)
15
+
16
+
17
def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
    """True when TF 'SAME' padding can be expressed as a fixed symmetric pad."""
    return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
20
+
21
+
22
def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0):
    """Dynamically pad NCHW input ``x`` so a following conv behaves like TF 'SAME'."""
    ih, iw = x.size()[-2:]
    pad_h = get_same_padding(ih, k[0], s[0], d[0])
    pad_w = get_same_padding(iw, k[1], s[1], d[1])
    if pad_h > 0 or pad_w > 0:
        # F.pad order: left, right, top, bottom; extra pixel goes to the end.
        x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value)
    return x
29
+
30
+
31
def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
    """Resolve a padding spec ('same' | 'valid' | number) to (padding, is_dynamic)."""
    dynamic = False
    if not isinstance(padding, str):
        # Numeric/tuple padding passes through untouched.
        return padding, dynamic
    padding = padding.lower()
    if padding == 'same':
        # TF 'SAME': static symmetric pad when representable, else runtime padding.
        if is_static_pad(kernel_size, **kwargs):
            padding = get_padding(kernel_size, **kwargs)
        else:
            padding = 0
            dynamic = True
    elif padding == 'valid':
        # 'VALID' padding, same as padding=0.
        padding = 0
    else:
        # Any other string: default PyTorch-style 'same'-ish symmetric padding.
        padding = get_padding(kernel_size, **kwargs)
    return padding, dynamic
52
+
53
+
54
def adapt_input_conv(in_chans, conv_weight):
    """Adapt a pretrained first-conv weight (O, I, H, W) to ``in_chans`` inputs.

    For one channel the RGB axis is summed; for other channel counts the RGB
    weights are tiled and rescaled so activation magnitudes stay comparable.
    """
    conv_type = conv_weight.dtype
    conv_weight = conv_weight.float()  # half weights can't be summed on CPU
    O, I, J, K = conv_weight.shape
    if in_chans == 1:
        if I > 3:
            # Space-to-depth stems store I = 3 * blocks; sum each RGB triple.
            assert conv_weight.shape[1] % 3 == 0
            conv_weight = conv_weight.reshape(O, I // 3, 3, J, K).sum(dim=2, keepdim=False)
        else:
            conv_weight = conv_weight.sum(dim=1, keepdim=True)
    elif in_chans != 3:
        if I != 3:
            raise NotImplementedError('Weight format not supported by conversion.')
        # Tiling RGB weights beats random init, though better channel-specific
        # combinations may exist for particular use cases.
        repeat = int(math.ceil(in_chans / 3))
        conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
        conv_weight *= (3 / float(in_chans))
    return conv_weight.to(conv_type)
HybridViT/module/component/feature_extractor/vgg.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ __all__ = ['vgg11_bn', 'vgg13_bn', 'vgg16_bn', 'vgg19_bn']
5
+
6
+
7
class VGG(nn.Module):
    """VGG feature trunk followed by a grouped 1x1 projection to ``num_channel_out``."""

    def __init__(self, features, num_channel_out=512, init_weights=True):
        super(VGG, self).__init__()
        self.features = features
        self.num_out_features = 512  # channel count produced by `features`

        self.lastlayer = nn.Sequential(
            nn.Conv2d(self.num_out_features, num_channel_out, kernel_size=1,
                      stride=1, padding=0, groups=32, bias=False),
            nn.BatchNorm2d(num_channel_out),
            nn.ReLU(inplace=True),
        )

        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        return self.lastlayer(self.features(x))

    def _initialize_weights(self):
        """Kaiming for convs, unit BN affine params, small-normal linears."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
40
+
41
+
42
def make_layers(cfg, down_sample=8, batch_norm=False):
    """Build a VGG trunk from a config list.

    Entries: an int is a 3x3 conv width, 'M' is a 2x2 max-pool, and a dict
    maps a target down-sample rate (e.g. 4 or 8) to a pooling size.
    """
    layers = []
    in_channels = 3
    for item in cfg:
        if item == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        elif isinstance(item, dict):
            pool_size = item[down_sample]
            layers.append(nn.MaxPool2d(kernel_size=pool_size, stride=pool_size))
        else:
            layers.append(nn.Conv2d(in_channels, item, kernel_size=3, padding=1))
            if batch_norm:
                layers.append(nn.BatchNorm2d(item))
            layers.append(nn.ReLU(inplace=True))
            in_channels = item
    return nn.Sequential(*layers)
59
+
60
+
61
# VGG layer configurations (torchvision-style): ints are 3x3 conv widths,
# 'M' is a 2x2 max-pool, and dicts pick the pooling size by the target
# down-sample rate (key 4 or 8). 'A'=VGG11, 'B'=VGG13, 'D'=VGG16, 'E'=VGG19.
cfgs = {
    'A': [64, 'M', 128, 'M', 256, 256, {4: (2, 1), 8: (2, 2)}, 512, 512, {4: (2, 1), 8: (2, 1)}, 512, 512,
          {4: (2, 1), 8: (2, 1)}],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, {4: (2, 1), 8: (2, 2)}, 512, 512, {4: (2, 1), 8: (2, 1)}, 512, 512,
          {4: (2, 1), 8: (2, 1)}, ],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, {4: (2, 1), 8: (2, 2)}, 512, 512, 512, {4: (2, 1), 8: (2, 1)}, 512,
          512, 512, {4: (2, 1), 8: (2, 1)}, ],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, {4: (2, 1), 8: (2, 2)}, 512, 512, 512, 512,
          {4: (2, 1), 8: (2, 1)}, 512, 512, 512, 512, {4: (2, 1), 8: (2, 1)}, ],
}
71
+
72
+
73
def _vgg(model_path, cfg, batch_norm, pretrained, progress, num_channel_out, down_sample, **kwargs):
    """Instantiate a VGG variant, optionally loading weights from ``model_path``."""
    if pretrained:
        kwargs['init_weights'] = False  # skip random init when weights are loaded
    trunk = make_layers(cfgs[cfg], down_sample, batch_norm=batch_norm)
    model = VGG(trunk, num_channel_out, **kwargs)
    if model_path and pretrained:
        state_dict = torch.load(model_path)
        model.load_state_dict(state_dict, strict=False)
    return model
81
+
82
+
83
def vgg11_bn(model_path='', num_channel_out=512, down_sample=8, pretrained=True, progress=True, **kwargs):
    """VGG-11 with batch normalization (config 'A')."""
    return _vgg(model_path, 'A', True, pretrained, progress, num_channel_out, down_sample, **kwargs)
85
+
86
+
87
def vgg13_bn(model_path='', num_channel_out=512, down_sample=8, pretrained=True, progress=True, **kwargs):
    """VGG-13 with batch normalization (config 'B')."""
    return _vgg(model_path, 'B', True, pretrained, progress, num_channel_out, down_sample, **kwargs)
89
+
90
+
91
def vgg16_bn(model_path='', num_channel_out=512, down_sample=8, pretrained=True, progress=True, **kwargs):
    """VGG-16 with batch normalization (config 'D')."""
    return _vgg(model_path, 'D', True, pretrained, progress, num_channel_out, down_sample, **kwargs)
93
+
94
+
95
def vgg19_bn(model_path='', num_channel_out=512, down_sample=8, pretrained=True, progress=True, **kwargs):
    """VGG-19 with batch normalization (config 'E')."""
    return _vgg(model_path, 'E', True, pretrained, progress, num_channel_out, down_sample, **kwargs)
HybridViT/module/component/prediction_head/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .seq2seq import Attention
2
+ from .seq2seq_v2 import AttentionV2
3
+ from .tfm import TransformerPrediction
4
+
5
+ __all__ = ['Attention', 'AttentionV2', 'TransformerPrediction']
HybridViT/module/component/prediction_head/addon_module/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .attention1D import *
2
+ from .attention2D import *
3
+ from .position_encoding import *
HybridViT/module/component/prediction_head/addon_module/attention1D.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ torch.autograd.set_detect_anomaly(True)
6
+
7
class LuongAttention(nn.Module):
    """Luong-style (multiplicative) attention decoder step.

    Runs one LSTMCell step on the embedded previous token, scores the
    encoder states against the fresh hidden state, and classifies over the
    concatenation of the attention context and the hidden state.
    """

    def __init__(self, input_size, hidden_size, num_embeddings, num_classes, method='dot'):
        super(LuongAttention, self).__init__()
        self.attn = LuongAttentionCell(hidden_size, method)
        self.rnn = nn.LSTMCell(num_embeddings, hidden_size)
        self.hidden_size = hidden_size
        self.generator = nn.Linear(2 * hidden_size, num_classes)

    def forward(self, prev_hidden, batch_H, embed_text):
        # Luong attention attends *after* advancing the decoder state.
        state = self.rnn(embed_text, prev_hidden)
        h_t = state[0]

        # Alignment scores over the encoder sequence -> normalized weights.
        scores = self.attn(h_t, batch_H)
        alpha = F.softmax(scores, dim=1)

        # Weighted sum of encoder states: (B, 1, S) @ (B, S, C) -> (B, C).
        context = torch.bmm(alpha.unsqueeze(1), batch_H).squeeze(1)

        # Classify tanh([context; h_t]).
        logits = self.generator(torch.tanh(torch.cat([context, h_t], 1)))
        return logits, state, alpha
29
+
30
class LuongAttentionCell(nn.Module):
    """Alignment scoring for Luong attention: 'dot', 'general' or 'concat'."""

    def __init__(self, hidden_size, method='dot'):
        super(LuongAttentionCell, self).__init__()
        self.method = method
        self.hidden_size = hidden_size

        # Only the parametric scoring variants carry learnable weights.
        if method == "general":
            self.fc = nn.Linear(hidden_size, hidden_size, bias=False)
        elif method == "concat":
            self.fc = nn.Linear(hidden_size, hidden_size, bias=False)
            self.weight = nn.Parameter(torch.FloatTensor(1, hidden_size))

    def forward(self, decoder_hidden, encoder_outputs):
        # (B, H) -> (B, 1, H) so the batched matmuls line up.
        query = decoder_hidden.unsqueeze(1)

        if self.method == "dot":
            # score = enc . h, no learned parameters involved.
            return encoder_outputs.bmm(query.permute(0, 2, 1)).squeeze(-1)

        if self.method == "general":
            # score = enc . (W h)
            projected = self.fc(query)
            return encoder_outputs.bmm(projected.permute(0, 2, 1)).squeeze(-1)

        if self.method == "concat":
            # score = v . tanh(W (h + enc))  — additive-style variant.
            combined = torch.tanh(self.fc(query + encoder_outputs))
            return combined.bmm(self.weight.unsqueeze(-1).repeat(combined.shape[0], 1, 1)).squeeze(-1)
62
+
63
class BahdanauAttentionCell(nn.Module):
    """Additive (Bahdanau) alignment: score = v . tanh(W_i enc + W_h h)."""

    def __init__(self, input_dim, hidden_dim):
        super(BahdanauAttentionCell, self).__init__()
        self.i2h = nn.Linear(input_dim, hidden_dim, bias=False)
        self.h2h = nn.Linear(hidden_dim, hidden_dim)
        self.score = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, decoder_hidden, encoder_output):
        # Project encoder states and the decoder h-state into a shared space;
        # broadcasting over the sequence dimension adds them per time step.
        keys = self.i2h(encoder_output)
        query = self.h2h(decoder_hidden[0]).unsqueeze(1)
        return self.score(torch.tanh(keys + query))
75
+
76
class BahdanauAttention(nn.Module):
    """One-step attention decoder with additive (Bahdanau) alignment.

    Attends with the *previous* hidden state, then feeds [context; embedding]
    through an LSTMCell and classifies from the new hidden state.
    """

    def __init__(self, input_size=100, hidden_size=256, num_embeddings=10, num_classes=10):
        super(BahdanauAttention, self).__init__()
        self.attn = BahdanauAttentionCell(input_size, hidden_size)
        self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.generator = nn.Linear(hidden_size, num_classes)

    def set_mem(self, prev_attn):
        # Plain Bahdanau attention keeps no alignment memory; the hook exists
        # so callers can treat all attention variants uniformly.
        pass

    def reset_mem(self):
        # Intentionally a no-op; see set_mem.
        pass

    def forward(self, prev_hidden, batch_H, embed_text):
        # Score the encoder states against the previous decoder state.
        scores = self.attn(prev_hidden, batch_H)
        alpha = F.softmax(scores, dim=1)

        # (B, 1, S) @ (B, S, C) -> (B, C): attention context vector.
        context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1)

        # Bahdanau: the RNN consumes [context; embedded previous symbol].
        state = self.rnn(torch.cat([context, embed_text], 1), prev_hidden)
        return self.generator(state[0]), state, alpha
101
+
102
class ConstituentAttentionCell(nn.Module):
    """Placeholder for a constituent-attention cell; no layers or forward() yet."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
105
+
106
+
107
class ConstituentCoverageAttentionCell(ConstituentAttentionCell):
    """Placeholder subclass; inherits the (empty) constituent cell unchanged."""
    pass
109
+
110
class LocationAwareAttentionCell(nn.Module):
    """Additive attention whose score also sees the previous alignment through
    a 1-D convolution (location-aware attention, Chorowski et al. 2015)."""

    def __init__(self, kernel_size, kernel_dim, hidden_dim, input_dim):
        super().__init__()
        # Symmetric padding keeps the convolved alignment the same length.
        self.loc_conv = nn.Conv1d(1, kernel_dim, kernel_size=2 * kernel_size + 1, padding=kernel_size, bias=True)
        self.loc_proj = nn.Linear(kernel_dim, hidden_dim)
        self.query_proj = nn.Linear(hidden_dim, hidden_dim)
        self.key_proj = nn.Linear(input_dim, hidden_dim)
        self.score = nn.Linear(hidden_dim, 1)

    def forward(self, decoder_hidden, encoder_output, last_alignment):
        n_batch = encoder_output.shape[0]
        n_steps = encoder_output.shape[1]
        n_hidden = decoder_hidden[0].shape[1]

        keys = self.key_proj(encoder_output)
        query = self.query_proj(decoder_hidden[0]).unsqueeze(1)

        # First decoding step: start from an all-zero alignment history.
        if last_alignment is None:
            last_alignment = decoder_hidden[0].new_zeros(n_batch, n_steps, 1)

        # (B, S, 1) -> (B, 1, S) for Conv1d, then back to (B, S, hidden).
        location = self.loc_conv(last_alignment.permute(0, 2, 1))
        location = self.loc_proj(location.transpose(1, 2))

        # Debug-time shape checks (stripped under python -O).
        assert len(location.shape) == 3
        assert location.shape[0] == n_batch, f'{location.shape[0]}-{n_batch}'
        assert location.shape[1] == n_steps
        assert location.shape[2] == n_hidden

        return self.score(torch.tanh(keys + query + location))
143
+
144
class CoverageAttention(nn.Module):
    """Location-aware attention that feeds the *cumulative* alignment
    (coverage vector, set via set_mem) back into the scoring cell.

    Unlike the decoder-step variants, forward returns (context, alpha)
    without running an RNN step.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, kernel_dim, temperature=1.0, smoothing=False):
        super().__init__()
        self.smoothing = smoothing
        self.temperature = temperature
        self.prev_attn = None
        self.attn = LocationAwareAttentionCell(kernel_size, kernel_dim, hidden_dim, input_dim)

    def set_mem(self, prev_attn):
        # Store the accumulated alignment; used as the location feature next step.
        self.prev_attn = prev_attn

    def reset_mem(self):
        # Call at sequence start so the first step sees no alignment history.
        self.prev_attn = None

    def forward(self, prev_hidden, batch_H):
        e = self.attn(prev_hidden, batch_H, self.prev_attn)

        if self.smoothing:
            # Fix: F.sigmoid takes no `dim` argument — the original raised a
            # TypeError. Smoothing is an element-wise sigmoid followed by
            # normalization over the sequence dimension (dim=1); the original
            # summed over the trailing size-1 dim, which made alpha all ones.
            e = torch.sigmoid(e)
            alpha = e / e.sum(dim=1, keepdim=True)
        else:
            e = e / self.temperature
            alpha = F.softmax(e, dim=1)

        context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1)  # batch_size x num_channel

        return context, alpha
171
+
172
class LocationAwareAttention(BahdanauAttention):
    """Bahdanau-style decoder step whose scorer also sees the previous
    alignment (supplied via set_mem); supports temperature scaling and
    sigmoid smoothing of the attention weights."""

    def __init__(self, kernel_size, kernel_dim, temperature=1.0, smoothing=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.smoothing = smoothing
        self.temperature = temperature
        self.prev_attn = None
        # Replace the parent's plain additive cell with the location-aware one.
        self.attn = LocationAwareAttentionCell(kernel_size, kernel_dim, self.hidden_size, self.input_size)

    def set_mem(self, prev_attn):
        # Remember the last alignment (or its running sum, for coverage mode).
        self.prev_attn = prev_attn

    def reset_mem(self):
        # Call at sequence start so the first step sees no alignment history.
        self.prev_attn = None

    def forward(self, prev_hidden, batch_H, embed_text):
        e = self.attn(prev_hidden, batch_H, self.prev_attn)

        if self.smoothing:
            # Fix: F.sigmoid accepts no `dim` kwarg — the original raised a
            # TypeError. Normalize over the sequence dimension (dim=1); the
            # original divided by the trailing size-1 dim's sum, yielding 1s.
            e = torch.sigmoid(e)
            alpha = e / e.sum(dim=1, keepdim=True)
        else:
            e = e / self.temperature
            alpha = F.softmax(e, dim=1)

        context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1)  # batch_size x num_channel; batch_H: B x S x C
        concat_context = torch.cat([context, embed_text], 1)  # batch_size x (num_channel + num_embedding)
        cur_hidden = self.rnn(concat_context, prev_hidden)
        output = self.generator(cur_hidden[0])

        return output, cur_hidden, alpha
202
+
203
+
204
+ # class MaskAttention(nn.Module):
205
+ # def __init__(self):
206
+ # super().__init__()
207
+
208
+ # class CoverageAttentionCell(nn.Module):
209
+ # def __init__(self, )
210
+
211
+ # class CoverageAttention(nn.Module):
212
+ # """
213
+ # http://home.ustc.edu.cn/~xysszjs/paper/PR2017.pdf
214
+ # """
215
+ # def __init__(self, input_size, hidden_size, num_embedding):
216
+ # super().__init__()
217
+
218
+ # def forward(self, prev_hidden, batch_H, embed_text):
HybridViT/module/component/prediction_head/addon_module/attention2D.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from .attention1D import LocationAwareAttention
4
+ """
5
+ NOTE :
6
+ """
7
class SARAttention(nn.Module):
    """2-D attention decoder step in the SAR style (Show, Attend and Read).

    NOTE(review): forward() stops after computing the raw attention map —
    masking, normalization and the context/output computation are still TODO,
    so it currently returns None. `holistic_feature` is accepted but unused.
    """

    def __init__(self,
                 input_size,
                 attention_size,
                 backbone_size,
                 output_size,
                 ):
        # Fix: nn.Module.__init__ must run before submodules are assigned;
        # the original omitted it, so construction raised AttributeError.
        super().__init__()
        # NOTE(review): conv1x1_1 expects `output_size` input channels while
        # hidden_2 has `input_size` channels — this only works when the two
        # are equal. TODO confirm intended sizes.
        self.conv1x1_1 = nn.Conv2d(output_size, attention_size, kernel_size=1, stride=1)
        self.conv3x3 = nn.Conv2d(backbone_size, attention_size, kernel_size=3, stride=1, padding=1)
        self.conv1x1_2 = nn.Conv2d(attention_size, 1, kernel_size=1, stride=1)

        self.rnn_decoder_1 = nn.LSTMCell(input_size, input_size)
        self.rnn_decoder_2 = nn.LSTMCell(input_size, input_size)

    def forward(
        self,
        dec_input,
        feature_map,
        holistic_feature,
        hidden_1,
        cell_1,
        hidden_2,
        cell_2
    ):
        _, _, H_feat, W_feat = feature_map.size()
        # Two stacked LSTM cells advance the decoder state.
        hidden_1, cell_1 = self.rnn_decoder_1(dec_input, (hidden_1, cell_1))
        hidden_2, cell_2 = self.rnn_decoder_2(hidden_1, (hidden_2, cell_2))

        # Broadcast the decoder state over the spatial grid as the query.
        hidden_2_tile = hidden_2.view(hidden_2.size(0), hidden_2.size(1), 1, 1)
        attn_query = self.conv1x1_1(hidden_2_tile)
        attn_query = attn_query.expand(-1, -1, H_feat, W_feat)

        attn_key = self.conv3x3(feature_map)
        attn_weight = torch.tanh(torch.add(attn_query, attn_key, alpha=1))
        attn_weight = self.conv1x1_2(attn_weight)  # shape B, 1, H, W

        # TO DO: apply mask for attention weight
46
+
47
+
48
class LocationAwareAttentionCell2D(nn.Module):
    """Location-aware additive attention over a 2-D feature map.

    `encoder_output` is expected channel-last, (B, H, W, input_dim); the score
    is returned flattened over the grid as (B, H*W, 1) so the caller can
    softmax over all spatial positions at once (matching the 1-D cell's
    (B, S, 1) contract).
    """

    def __init__(self, kernel_size, kernel_dim, hidden_dim, input_dim):
        super().__init__()
        # Symmetric padding keeps the convolved alignment the same H x W size.
        self.loc_conv = nn.Conv2d(1, kernel_dim, kernel_size=2*kernel_size+1, padding=kernel_size, bias=True)
        self.loc_proj = nn.Linear(kernel_dim, hidden_dim)
        self.query_proj = nn.Linear(hidden_dim, hidden_dim)
        self.key_proj = nn.Linear(input_dim, hidden_dim)
        self.score = nn.Linear(hidden_dim, 1)

    def forward(self, decoder_hidden, encoder_output, last_alignment):
        batch_size, enc_h, enc_w, hidden_dim = encoder_output.shape[0], encoder_output.shape[1], encoder_output.shape[2], decoder_hidden[0].shape[1]

        # Keys flattened over the grid: (B, H*W, hidden).
        encoder_proj = self.key_proj(encoder_output).reshape(batch_size, enc_h * enc_w, -1)
        hidden_proj = self.query_proj(decoder_hidden[0]).unsqueeze(1)  # (B, 1, hidden)

        # First decoding step: start from an all-zero alignment history.
        if last_alignment is None:
            last_alignment = decoder_hidden[0].new_zeros(batch_size, enc_h, enc_w, 1)

        # Fix: Conv2d needs channel-first input; the original permuted only
        # three of the four dims, a runtime error. (B, H, W, 1) -> (B, 1, H, W).
        loc_context = self.loc_conv(last_alignment.permute(0, 3, 1, 2))
        # (B, kernel_dim, H, W) -> (B, H, W, kernel_dim) -> project to hidden.
        loc_context = self.loc_proj(loc_context.permute(0, 2, 3, 1))

        # Debug-time shape checks (the original asserted len == 3 while
        # indexing four dims — contradictory; fixed to the 4-D layout).
        assert len(loc_context.shape) == 4
        assert loc_context.shape[0] == batch_size, f'{loc_context.shape[0]}-{batch_size}'
        assert loc_context.shape[1] == enc_h
        assert loc_context.shape[2] == enc_w
        assert loc_context.shape[3] == hidden_dim

        loc_context = loc_context.reshape(batch_size, enc_h*enc_w, hidden_dim)

        # Additive score per grid position: (B, H*W, 1).
        score = self.score(torch.tanh(
            encoder_proj
            + hidden_proj
            + loc_context
        ))
        return score
84
+
85
class LocationAwareAttention2D(LocationAwareAttention):
    """LocationAwareAttention whose scoring cell operates on 2-D feature maps.

    Only the attention cell is swapped for the 2-D variant; the decoder-step
    logic (RNN update, smoothing, coverage memory) is inherited unchanged.
    """

    def __init__(self, kernel_size, kernel_dim, temperature=1, smoothing=False, *args, **kwargs):
        super().__init__(kernel_size, kernel_dim, temperature, smoothing, *args, **kwargs)
        # Replace the parent's 1-D cell with the 2-D one (same hyper-parameters).
        self.attn = LocationAwareAttentionCell2D(kernel_size, kernel_dim, self.hidden_size, self.input_size)
HybridViT/module/component/prediction_head/addon_module/position_encoding.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ __all__ = ['WordPosEnc']
5
+
6
class WordPosEnc(nn.Module):
    """Fixed sinusoidal positional encoding added to word embeddings
    ("Attention Is All You Need", Vaswani et al., 2017)."""

    def __init__(
        self, d_model: int = 512, max_len: int = 500, temperature: float = 10000.0
    ) -> None:
        super().__init__()
        table = torch.zeros(max_len, d_model)

        # angle[pos, i] = pos / temperature^(2i / d_model)
        positions = torch.arange(0, max_len, dtype=torch.float)
        even_dims = torch.arange(0, d_model, 2, dtype=torch.float)
        scales = 1.0 / (temperature ** (even_dims / d_model))

        angles = torch.einsum("i, j -> i j", positions, scales)

        # Even channels carry sine, odd channels cosine.
        table[:, 0::2] = angles.sin()
        table[:, 1::2] = angles.cos()
        # Buffer (not a Parameter): saved with the model but never trained.
        self.register_buffer("pe", table)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the first seq_len rows of the table to x of shape (B, S, D)."""
        seq_len = x.size(1)
        return x + self.pe[None, :seq_len, :]
HybridViT/module/component/prediction_head/seq2seq.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from einops import repeat
6
+ from ...converter import AttnLabelConverter as ATTN
7
+ from .addon_module import *
8
+
9
class Attention(nn.Module):
    """Seq2seq prediction head: attentional LSTM decoder over encoder states.

    Supports several alignment variants ('luong', 'loc_aware', 'coverage',
    default Bahdanau), greedy decoding for train/test, and beam-search
    decoding for batch size 1 (forward dispatches between the two).
    """

    def __init__(self,
                 kernel_size,
                 kernel_dim,
                 input_size,
                 hidden_size,
                 num_classes,
                 embed_dim=None,
                 attn_type='coverage',
                 embed_target=False,
                 enc_init=False, #init hidden state of decoder with enc output
                 teacher_forcing=1.0,
                 droprate=0.1,
                 method='concat',
                 seqmodel='ViT',
                 viz_attn: bool = False,
                 device='cuda'
                 ):
        super(Attention, self).__init__()
        if embed_dim is None: embed_dim = input_size
        if embed_target:
            # NOTE(review): padding_idx is set to the START token, which zeroes
            # and freezes the [GO] embedding — presumably a PAD index was
            # intended; confirm against the converter.
            self.embedding = nn.Embedding(num_classes, embed_dim, padding_idx=ATTN.START())

        # Shared constructor kwargs for every attention-cell variant.
        common = {
            'input_size': input_size,
            'hidden_size': hidden_size,
            'num_embeddings': embed_dim if embed_target else num_classes,
            'num_classes': num_classes
        }

        # 'loc_aware' and 'coverage' share the same cell; they differ only in
        # what forward_* feeds back via set_mem (last alpha vs running sum).
        if attn_type == 'luong':
            common['method'] = method
            self.attention_cell = LuongAttention(**common)
        elif attn_type == 'loc_aware':
            self.attention_cell = LocationAwareAttention(kernel_size=kernel_size, kernel_dim=kernel_dim, **common)
        elif attn_type == 'coverage':
            self.attention_cell = LocationAwareAttention(kernel_size=kernel_size, kernel_dim=kernel_dim, **common)
        else:
            self.attention_cell = BahdanauAttention(**common)

        self.dropout = nn.Dropout(droprate)
        self.embed_target = embed_target
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.num_classes = num_classes
        self.teacher_forcing = teacher_forcing
        self.device = device
        self.attn_type = attn_type
        self.enc_init = enc_init
        self.viz_attn = viz_attn
        self.seqmodel = seqmodel

        # Projections for encoder-initialized decoder state are only created
        # when requested.
        if enc_init: self.init_hidden()

    def _embed_text(self, input_char):
        """Look up learned embeddings for target characters (embed_target mode)."""
        return self.embedding(input_char)

    def _char_to_onehot(self, input_char, onehot_dim=38):
        """One-hot encode target characters: (B,) -> (B, onehot_dim)."""
        input_char = input_char.unsqueeze(1)
        batch_size = input_char.size(0)
        one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().to(self.device)
        one_hot = one_hot.scatter_(1, input_char, 1)
        return one_hot

    def init_hidden(self):
        """Create the linear maps that turn an encoder summary into (h0, c0)."""
        self.proj_init_h = nn.Linear(self.input_size, self.hidden_size, bias=True)
        self.proj_init_c = nn.Linear(self.input_size, self.hidden_size, bias=True)

    def forward_beam(
        self,
        batch_H: torch.Tensor,
        batch_max_length=25,
        beam_size=4,
    ):
        """Beam-search decode a single sample (batch size must be 1).

        Returns (seq, score, alphas-or-None); seq is (1, L) and excludes the
        [GO] token.
        """
        batch_size = batch_H.size(0)
        assert batch_size == 1
        num_steps = batch_max_length + 1
        # Tile the single encoder sequence across the beam: (beam, S, E).
        batch_H = batch_H.squeeze(dim=0)
        batch_H = repeat(batch_H, "s e -> b s e", b = beam_size)

        if self.enc_init:
            # Summary used to seed (h0, c0): mean over time for BiLSTM,
            # first (cls-like) position otherwise.
            if self.seqmodel == 'BiLSTM':
                init_embedding = batch_H.mean(dim=1)
            else:
                init_embedding = batch_H[:, 0, :]
            h_0 = self.proj_init_h(init_embedding)
            c_0 = self.proj_init_c(init_embedding)
            hidden = (h_0, c_0)
        else:
            hidden = (torch.zeros(beam_size, self.hidden_size, dtype=torch.float32, device=self.device),
                      torch.zeros(beam_size, self.hidden_size, dtype=torch.float32, device=self.device))

        if self.attn_type == 'coverage':
            # Running sum of alignments fed back to the cell each step.
            alpha_cum = torch.zeros(beam_size, batch_H.shape[1], 1, dtype=torch.float32, device=self.device)
        self.attention_cell.reset_mem()

        # Every beam starts from the [GO] token with score 0.
        k_prev_words = torch.LongTensor([[ATTN.START()]] * beam_size).to(self.device)
        seqs = k_prev_words
        targets = k_prev_words.squeeze(dim=-1)
        top_k_scores = torch.zeros(beam_size, 1).to(self.device)

        if self.viz_attn:
            seqs_alpha = torch.ones(beam_size, 1, batch_H.shape[1]).to(self.device)

        complete_seqs = list()
        if self.viz_attn:
            complete_seqs_alpha = list()
        complete_seqs_scores = list()

        for step in range(num_steps):
            embed_text = self._char_to_onehot(targets, onehot_dim=self.num_classes) if not self.embed_target else self._embed_text(targets)
            output, hidden, alpha = self.attention_cell(hidden, batch_H, embed_text)
            output = self.dropout(output)
            vocab_size = output.shape[1]

            # Accumulate log-probabilities along each beam.
            scores = F.log_softmax(output, dim=-1)
            scores = top_k_scores.expand_as(scores) + scores
            if step == 0:
                # All beams are identical at step 0; expand from one row only.
                top_k_scores, top_k_words = scores[0].topk(beam_size, 0, True, True)
            else:
                top_k_scores, top_k_words = scores.view(-1).topk(beam_size, 0, True, True)

            # Recover (source beam, emitted token) from the flat top-k index.
            prev_word_inds = top_k_words // vocab_size
            next_word_inds = top_k_words % vocab_size

            seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)
            if self.viz_attn:
                seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].permute(0, 2, 1)],
                                       dim=1)

            # Beams that did not just emit [END] continue.
            incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if
                               next_word != ATTN.END()]

            complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))

            if len(complete_inds) > 0:
                complete_seqs.extend(seqs[complete_inds].tolist())
                if self.viz_attn:
                    complete_seqs_alpha.extend(seqs_alpha[complete_inds])
                complete_seqs_scores.extend(top_k_scores[complete_inds])

            # Shrink the beam by however many hypotheses just finished.
            beam_size = beam_size - len(complete_inds)
            if beam_size == 0:
                break

            # Re-index all per-beam state to the surviving hypotheses.
            seqs = seqs[incomplete_inds]
            if self.viz_attn:
                seqs_alpha = seqs_alpha[incomplete_inds]
            hidden = hidden[0][prev_word_inds[incomplete_inds]], \
                     hidden[1][prev_word_inds[incomplete_inds]]
            batch_H = batch_H[prev_word_inds[incomplete_inds]]
            top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
            targets = next_word_inds[incomplete_inds]

            if self.attn_type == 'coverage':
                alpha_cum = alpha_cum + alpha
                alpha_cum = alpha_cum[incomplete_inds]
                self.attention_cell.set_mem(alpha_cum)
            elif self.attn_type == 'loc_aware':
                self.attention_cell.set_mem(alpha)

        # NOTE(review): this tests the *last step's* completions only — if some
        # beams finished earlier but none did on the final step, the in-progress
        # branch is taken and the earlier completed hypotheses are ignored.
        if len(complete_inds) == 0:
            # Nothing reached [END]; fall back to the best unfinished beam.
            seq = seqs[0][1:].tolist()
            seq = torch.LongTensor(seq).unsqueeze(0)
            score = top_k_scores[0]
            if self.viz_attn:
                alphas = seqs_alpha[0][1:, ...]
                return seq, score, alphas
            else:
                return seq, score, None
        else:
            # Pick the finished hypothesis with the best length-normalized score.
            combine_lst = tuple(zip(complete_seqs, complete_seqs_scores))
            best_ind = combine_lst.index(max(combine_lst, key=lambda x: x[1] / len(x[0]))) #https://youtu.be/XXtpJxZBa2c?t=2407
            seq = complete_seqs[best_ind][1:] #not include [GO] token
            seq = torch.LongTensor(seq).unsqueeze(0)
            score = max(complete_seqs_scores)

            if self.viz_attn:
                alphas = complete_seqs_alpha[best_ind][1:, ...]
                return seq, score, alphas
            else:
                return seq, score, None

    def forward_greedy(self, batch_H, text, is_train=True, is_test=False, batch_max_length=25):
        """Greedy (or teacher-forced) decoding over the whole batch.

        Returns (preds_index, probs, None) where probs is
        (B, batch_max_length + 1, num_classes) of unnormalized logits.
        """
        batch_size = batch_H.size(0)
        num_steps = batch_max_length + 1
        if self.enc_init:
            # Seed (h0, c0) from an encoder summary (mean for BiLSTM,
            # first position otherwise).
            if self.seqmodel == 'BiLSTM':
                init_embedding = batch_H.mean(dim=1)
                encoder_hidden = batch_H
            else:
                encoder_hidden = batch_H
                init_embedding = batch_H[:, 0, :]
            h_0 = self.proj_init_h(init_embedding)
            c_0 = self.proj_init_c(init_embedding)
            hidden = (h_0, c_0)
        else:
            encoder_hidden = batch_H
            hidden = (torch.zeros(batch_size, self.hidden_size, dtype=torch.float32, device=self.device),
                      torch.zeros(batch_size, self.hidden_size, dtype=torch.float32, device=self.device))

        targets = torch.zeros(batch_size, dtype=torch.long, device=self.device) # [GO] token
        probs = torch.zeros(batch_size, num_steps, self.num_classes, dtype=torch.float32, device=self.device)

        if self.viz_attn:
            # Per-step alignments kept on the module for later visualization.
            self.alpha_stores = torch.zeros(batch_size, num_steps, encoder_hidden.shape[1], 1, dtype=torch.float32, device=self.device)
        if self.attn_type == 'coverage':
            alpha_cum = torch.zeros(batch_size, encoder_hidden.shape[1], 1, dtype=torch.float32, device=self.device)

        self.attention_cell.reset_mem()

        if is_test:
            # Tracks which batch elements have already emitted [END].
            end_flag = torch.zeros(batch_size, dtype=torch.bool, device=self.device)

        for i in range(num_steps):
            embed_text = self._char_to_onehot(targets, onehot_dim=self.num_classes) if not self.embed_target else self._embed_text(targets)
            output, hidden, alpha = self.attention_cell(hidden, encoder_hidden, embed_text)
            output = self.dropout(output)
            if self.viz_attn:
                self.alpha_stores[:, i] = alpha
            if self.attn_type == 'coverage':
                alpha_cum = alpha_cum + alpha
                self.attention_cell.set_mem(alpha_cum)
            elif self.attn_type == 'loc_aware':
                self.attention_cell.set_mem(alpha)

            probs_step = output
            probs[:, i, :] = probs_step

            if i == num_steps - 1:
                break

            # Scheduled sampling: with prob. teacher_forcing use ground truth,
            # otherwise feed back the model's own argmax prediction.
            if is_train:
                if self.teacher_forcing < random.random():
                    _, next_input = probs_step.max(1)
                    targets = next_input
                else:
                    targets = text[:, i+1]
            else:
                _, next_input = probs_step.max(1)
                targets = next_input

            # NOTE(review): next_input is undefined here when is_train and
            # is_test are both True and the teacher-forcing branch was taken —
            # presumably the two flags are mutually exclusive; confirm callers.
            if is_test:
                end_flag = end_flag | (next_input == ATTN.END())
                if end_flag.all():
                    break

        _, preds_index = probs.max(2)

        return preds_index, probs, None # batch_size x num_steps x num_classes

    def forward(self, beam_size, batch_H, text, batch_max_length, is_train=True, is_test=False):
        """Dispatch: greedy during training, beam search at inference when
        beam_size > 1 (batch size 1 only), greedy otherwise."""
        if is_train:
            return self.forward_greedy(batch_H, text, is_train, is_test, batch_max_length)
        else:
            if beam_size > 1:
                return self.forward_beam(batch_H, batch_max_length, beam_size)
            else:
                return self.forward_greedy(batch_H, text, is_train, is_test, batch_max_length)
268
+
HybridViT/module/component/prediction_head/seq2seq_v2.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from einops import repeat
6
+ from ...converter import AttnLabelConverter as ATTN
7
+ from .addon_module import *
8
+ from .seq2seq import Attention
9
+
10
+
11
class AttentionV2(Attention):
    """Attention head variant that adapts the encoder output per seq-model.

    Differences from Attention: for 'TFM' the first (cls) token is dropped
    from the attended states and used only to seed the decoder state; 'VIG'
    is accepted alongside 'BiLSTM'; unknown seqmodel values raise.
    """

    def forward_beam(
        self,
        batch_H: torch.Tensor,
        batch_max_length=25,
        beam_size=4,
    ):
        """Beam-search decode a single sample (batch size must be 1).

        Returns (seq, score, alphas-or-None); seq is (1, L) and excludes the
        [GO] token.
        """
        batch_size = batch_H.size(0)
        assert batch_size == 1
        num_steps = batch_max_length + 1
        # Tile the single encoder sequence across the beam: (beam, S, E).
        batch_H = batch_H.squeeze(dim=0)
        batch_H = repeat(batch_H, "s e -> b s e", b = beam_size)

        # Choose which encoder positions the decoder may attend to.
        encoder_hidden = None
        if self.seqmodel in ['BiLSTM', 'VIG']:
            encoder_hidden = batch_H
        elif self.seqmodel == 'TFM':
            # Drop the leading cls-style token from the attended states.
            encoder_hidden = batch_H[:, 1:, :]
        else:
            raise ValueError('seqmodel must be either BiLSTM or TFM option')

        if self.enc_init:
            # Summary used to seed (h0, c0): mean over time, or the cls token.
            init_embedding = None
            if self.seqmodel in ['BiLSTM', 'VIG']:
                init_embedding = batch_H.mean(dim=1)
            elif self.seqmodel == 'TFM':
                init_embedding = batch_H[:, 0, :]
            else:
                raise ValueError('seqmodel must be either BiLSTM or TFM option')

            assert init_embedding is not None
            h_0 = self.proj_init_h(init_embedding)
            c_0 = self.proj_init_c(init_embedding)
            hidden = (h_0, c_0)
        else:
            hidden = (torch.zeros(beam_size, self.hidden_size, dtype=torch.float32, device=self.device),
                      torch.zeros(beam_size, self.hidden_size, dtype=torch.float32, device=self.device))

        assert encoder_hidden is not None

        if self.attn_type == 'coverage':
            # Running sum of alignments fed back to the cell each step.
            alpha_cum = torch.zeros(beam_size, encoder_hidden.shape[1], 1, dtype=torch.float32, device=self.device)
        self.attention_cell.reset_mem()

        # Every beam starts from the [GO] token with score 0.
        k_prev_words = torch.LongTensor([[ATTN.START()]] * beam_size).to(self.device)
        seqs = k_prev_words
        targets = k_prev_words.squeeze(dim=-1)
        top_k_scores = torch.zeros(beam_size, 1).to(self.device)

        if self.viz_attn:
            seqs_alpha = torch.ones(beam_size, 1, encoder_hidden.shape[1]).to(self.device)

        complete_seqs = list()
        if self.viz_attn:
            complete_seqs_alpha = list()
        complete_seqs_scores = list()

        for step in range(num_steps):
            embed_text = self._char_to_onehot(targets, onehot_dim=self.num_classes) if not self.embed_target else self._embed_text(targets)
            output, hidden, alpha = self.attention_cell(hidden, encoder_hidden, embed_text)
            output = self.dropout(output)
            vocab_size = output.shape[1]

            # Accumulate log-probabilities along each beam.
            scores = F.log_softmax(output, dim=-1)
            scores = top_k_scores.expand_as(scores) + scores
            if step == 0:
                # All beams are identical at step 0; expand from one row only.
                top_k_scores, top_k_words = scores[0].topk(beam_size, 0, True, True)
            else:
                top_k_scores, top_k_words = scores.view(-1).topk(beam_size, 0, True, True)

            # Recover (source beam, emitted token) from the flat top-k index.
            prev_word_inds = top_k_words // vocab_size
            next_word_inds = top_k_words % vocab_size

            seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)
            if self.viz_attn:
                seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].permute(0, 2, 1)],
                                       dim=1)

            # Beams that did not just emit [END] continue.
            incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if
                               next_word != ATTN.END()]

            complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))

            if len(complete_inds) > 0:
                complete_seqs.extend(seqs[complete_inds].tolist())
                if self.viz_attn:
                    complete_seqs_alpha.extend(seqs_alpha[complete_inds])
                complete_seqs_scores.extend(top_k_scores[complete_inds])

            # Shrink the beam by however many hypotheses just finished.
            beam_size = beam_size - len(complete_inds)
            if beam_size == 0:
                break

            # Re-index all per-beam state to the surviving hypotheses.
            seqs = seqs[incomplete_inds]
            if self.viz_attn:
                seqs_alpha = seqs_alpha[incomplete_inds]
            hidden = hidden[0][prev_word_inds[incomplete_inds]], \
                     hidden[1][prev_word_inds[incomplete_inds]]
            encoder_hidden = encoder_hidden[prev_word_inds[incomplete_inds]]
            top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
            targets = next_word_inds[incomplete_inds]

            if self.attn_type == 'coverage':
                alpha_cum = alpha_cum + alpha
                alpha_cum = alpha_cum[incomplete_inds]
                self.attention_cell.set_mem(alpha_cum)
            elif self.attn_type == 'loc_aware':
                self.attention_cell.set_mem(alpha)

        # NOTE(review): this tests the *last step's* completions only — if some
        # beams finished earlier but none did on the final step, the in-progress
        # branch is taken and the earlier completed hypotheses are ignored.
        if len(complete_inds) == 0:
            # Nothing reached [END]; fall back to the best unfinished beam.
            seq = seqs[0][1:].tolist()
            seq = torch.LongTensor(seq).unsqueeze(0)
            score = top_k_scores[0]
            if self.viz_attn:
                alphas = seqs_alpha[0][1:, ...]
                return seq, score, alphas
            else:
                return seq, score, None
        else:
            # Pick the finished hypothesis with the best length-normalized score.
            combine_lst = tuple(zip(complete_seqs, complete_seqs_scores))
            best_ind = combine_lst.index(max(combine_lst, key=lambda x: x[1] / len(x[0]))) #https://youtu.be/XXtpJxZBa2c?t=2407
            seq = complete_seqs[best_ind][1:] #not include [GO] token
            seq = torch.LongTensor(seq).unsqueeze(0)
            score = max(complete_seqs_scores)

            if self.viz_attn:
                alphas = complete_seqs_alpha[best_ind][1:, ...]
                return seq, score, alphas
            else:
                return seq, score, None

    def forward_greedy(self, batch_H, text, is_train=True, is_test=False, batch_max_length=25):
        """Greedy (or teacher-forced) decoding over the whole batch.

        Returns (preds_index, probs, None) where probs is
        (B, batch_max_length + 1, num_classes) of unnormalized logits.
        """
        batch_size = batch_H.size(0)
        num_steps = batch_max_length + 1
        # Choose which encoder positions the decoder may attend to.
        encoder_hidden = None
        if self.seqmodel in ['BiLSTM', 'VIG']:
            encoder_hidden = batch_H
        elif self.seqmodel == 'TFM':
            # Drop the leading cls-style token from the attended states.
            encoder_hidden = batch_H[:, 1:, :]
        else:
            raise ValueError('seqmodel must be either BiLSTM or TFM option')

        if self.enc_init:
            # Summary used to seed (h0, c0): mean over time, or the cls token.
            init_embedding = None
            if self.seqmodel in ['BiLSTM', 'VIG']:
                init_embedding = batch_H.mean(dim=1)
            elif self.seqmodel == 'TFM':
                init_embedding = batch_H[:, 0, :]
            else:
                raise ValueError('seqmodel must be either BiLSTM or TFM option')
            h_0 = self.proj_init_h(init_embedding)
            c_0 = self.proj_init_c(init_embedding)
            hidden = (h_0, c_0)
        else:
            hidden = (torch.zeros(batch_size, self.hidden_size, dtype=torch.float32, device=self.device),
                      torch.zeros(batch_size, self.hidden_size, dtype=torch.float32, device=self.device))

        targets = torch.zeros(batch_size, dtype=torch.long, device=self.device) # [GO] token
        probs = torch.zeros(batch_size, num_steps, self.num_classes, dtype=torch.float32, device=self.device)

        assert encoder_hidden is not None

        if self.viz_attn:
            # Per-step alignments kept on the module for later visualization.
            self.alpha_stores = torch.zeros(batch_size, num_steps, encoder_hidden.shape[1], 1, dtype=torch.float32, device=self.device)
        if self.attn_type == 'coverage':
            alpha_cum = torch.zeros(batch_size, encoder_hidden.shape[1], 1, dtype=torch.float32, device=self.device)

        self.attention_cell.reset_mem()

        if is_test:
            # Tracks which batch elements have already emitted [END].
            end_flag = torch.zeros(batch_size, dtype=torch.bool, device=self.device)

        for i in range(num_steps):
            embed_text = self._char_to_onehot(targets, onehot_dim=self.num_classes) if not self.embed_target else self._embed_text(targets)
            output, hidden, alpha = self.attention_cell(hidden, encoder_hidden, embed_text)
            output = self.dropout(output)
            if self.viz_attn:
                self.alpha_stores[:, i] = alpha
            if self.attn_type == 'coverage':
                alpha_cum = alpha_cum + alpha
                self.attention_cell.set_mem(alpha_cum)
            elif self.attn_type == 'loc_aware':
                self.attention_cell.set_mem(alpha)

            probs_step = output
            probs[:, i, :] = probs_step

            if i == num_steps - 1:
                break

            # Scheduled sampling: with prob. teacher_forcing use ground truth,
            # otherwise feed back the model's own argmax prediction.
            if is_train:
                if self.teacher_forcing < random.random():
                    _, next_input = probs_step.max(1)
                    targets = next_input
                else:
                    targets = text[:, i+1]
            else:
                _, next_input = probs_step.max(1)
                targets = next_input

            # NOTE(review): next_input is undefined here when is_train and
            # is_test are both True and the teacher-forcing branch was taken —
            # presumably the two flags are mutually exclusive; confirm callers.
            if is_test:
                end_flag = end_flag | (next_input == ATTN.END())
                if end_flag.all():
                    break

        _, preds_index = probs.max(2)

        return preds_index, probs, None # batch_size x num_steps x num_classes
HybridViT/module/component/prediction_head/tfm.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from einops import rearrange, repeat
6
+ from torch import FloatTensor, LongTensor
7
+ from .addon_module import WordPosEnc
8
+ from ...converter.tfm_converter import TFMLabelConverter as TFM
9
+ from ....beam import Beam
10
+
11
+ def _build_transformer_decoder(
12
+ d_model: int,
13
+ nhead: int,
14
+ num_decoder_layers: int,
15
+ dim_feedforward: int,
16
+ dropout: float,
17
+ ) -> nn.TransformerDecoder:
18
+ decoder_layer = nn.TransformerDecoderLayer(
19
+ d_model=d_model,
20
+ nhead=nhead,
21
+ dim_feedforward=dim_feedforward,
22
+ dropout=dropout,
23
+ )
24
+
25
+ decoder = nn.TransformerDecoder(decoder_layer, num_decoder_layers)
26
+
27
+ for p in decoder.parameters():
28
+ if p.dim() > 1:
29
+ nn.init.xavier_uniform_(p)
30
+
31
+ return decoder
32
+
33
+
34
class TransformerPrediction(nn.Module):
    """Autoregressive Transformer decoder head for text recognition.

    Training uses full teacher forcing over the ground-truth token
    sequence; evaluation decodes greedily token-by-token, or with beam
    search when ``beam_size > 1`` (single-image batches only).
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        num_decoder_layers: int,
        dim_feedforward: int,
        dropout: float,
        num_classes: int,
        max_seq_len: int,
        padding_idx: int,
        device: str = 'cuda:1'
    ):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.padding_idx = padding_idx
        self.num_classes = num_classes
        self.device = device
        # PAD rows of the embedding stay zero and receive no gradient.
        self.word_embed = nn.Embedding(
            num_classes, d_model, padding_idx=padding_idx
        )

        self.pos_enc = WordPosEnc(d_model=d_model)
        self.d_model = d_model
        self.model = _build_transformer_decoder(
            d_model=d_model,
            nhead=nhead,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )

        self.proj = nn.Linear(d_model, num_classes)
        self.beam = Beam(
            ignore_w=TFM.PAD(),
            start_w=TFM.START(),
            stop_w=TFM.END(),
            max_len=self.max_seq_len,
            device=self.device
        )

    def reset_beam(self):
        """Discard accumulated beam-search state before decoding a new image."""
        self.beam = Beam(
            ignore_w=TFM.PAD(),
            start_w=TFM.START(),
            stop_w=TFM.END(),
            max_len=self.max_seq_len,
            device=self.device
        )

    def _build_attention_mask(self, length):
        """Float causal mask: 0.0 on/below the diagonal, -inf above it."""
        mask = torch.full(
            (length, length),
            fill_value=1,
            dtype=torch.bool,
            device=self.device
        )
        mask = torch.triu(mask).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def _embedd_tgt(self, tgt: LongTensor, tgt_len: int):
        """Embed target tokens and build the causal / padding masks.

        The padding mask is only built while training; at eval time the
        target is a growing decoded prefix that contains no PAD tokens.
        """
        tgt_mask = self._build_attention_mask(tgt_len)
        if self.training:
            tgt_pad_mask = tgt == self.padding_idx
        else:
            tgt_pad_mask = None
        tgt = self.word_embed(tgt)
        tgt = self.pos_enc(tgt * math.sqrt(self.d_model))
        return tgt, tgt_mask, tgt_pad_mask

    def forward_greedy(
        self, src: FloatTensor, tgt: LongTensor, output_weight: bool = False, is_test: bool = False
    ) -> FloatTensor:
        """Teacher-forced logits (train) or step-wise greedy decoding (eval).

        Args:
            src: encoder memory, (batch, T, d_model).
            tgt: token ids, (batch, L); at eval time the start tokens only.
            output_weight: kept for interface compatibility (unused here).
            is_test: allow early exit once every sequence emitted [END].

        Returns:
            (preds_index, logits) with logits of shape (batch, L, num_classes).
        """
        if self.training:
            _, l = tgt.size()
            tgt, tgt_mask, tgt_pad_mask = self._embedd_tgt(tgt, l)

            src = rearrange(src, "b t d -> t b d")
            tgt = rearrange(tgt, "b l d -> l b d")

            out = self.model(
                tgt=tgt,
                memory=src,
                tgt_mask=tgt_mask,
                tgt_key_padding_mask=tgt_pad_mask
            )

            out = rearrange(out, "l b d -> b l d")
            out = self.proj(out)
        else:
            out = None
            src = rearrange(src, "b t d -> t b d")

            # BUG FIX: ``src`` is (T, B, d) after the rearrange above, so the
            # batch size is src.shape[1]. The original used src.shape[0],
            # which sized this flag by sequence length instead of batch size.
            end_flag = torch.zeros(src.shape[1], dtype=torch.bool, device=self.device)

            for step in range(self.max_seq_len + 1):
                _, l = tgt.size()
                emb_tgt, tgt_mask, tgt_pad_mask = self._embedd_tgt(tgt, l)
                emb_tgt = rearrange(emb_tgt, "b l d -> l b d")

                out = self.model(
                    tgt=emb_tgt,
                    memory=src,
                    tgt_mask=tgt_mask
                )

                out = rearrange(out, "l b d -> b l d")
                out = self.proj(out)
                probs = F.softmax(out, dim=-1)
                # Greedy pick for the newest position only: (batch, 1).
                next_text = torch.argmax(probs[:, -1:, :], dim=-1)
                tgt = torch.cat([tgt, next_text], dim=-1)

                # BUG FIX: squeeze the step dimension so the comparison is
                # (batch,). The original OR-ed a (batch,) flag with a
                # (batch, 1) tensor, broadcasting to a (batch, batch) matrix.
                end_flag = end_flag | (next_text.squeeze(1) == TFM.END())
                if end_flag.all() and is_test:
                    break

        _, preds_index = out.max(dim=2)
        return preds_index, out

    def forward_beam(self,
        src: torch.FloatTensor,
        beam_size: int
    ):
        """Beam-search decoding for a single image (batch size must be 1).

        Returns (output, score): the best length-normalised hypothesis as a
        (1, L) LongTensor and its accumulated log-probability.
        """
        assert src.size(0) == 1, f'beam search should only have single source, encounter with batch size: {src.size(0)}'
        out = None
        src = src.squeeze(0)

        for step in range(self.max_seq_len + 1):
            hypotheses = self.beam.hypotheses
            hyp_num = hypotheses.size(0)
            l = hypotheses.size(1)
            assert hyp_num <= beam_size, f"hyp_num: {hyp_num}, beam_size: {beam_size}"

            emb_tgt = self.word_embed(hypotheses)
            emb_tgt = self.pos_enc(emb_tgt * math.sqrt(self.d_model))
            tgt_mask = self._build_attention_mask(l)
            emb_tgt = rearrange(emb_tgt, "b l d -> l b d")

            # Replicate the encoder memory once per live hypothesis.
            exp_src = repeat(src.squeeze(1), "s e -> s b e", b=hyp_num)

            out = self.model(
                tgt=emb_tgt,
                memory=exp_src,
                tgt_mask=tgt_mask
            )

            out = rearrange(out, "l b d -> b l d")
            out = self.proj(out)
            log_prob = F.log_softmax(out[:, step, :], dim=-1)
            new_hypotheses, new_hyp_scores = self.beam.advance(log_prob, step, beam_size=beam_size)

            if self.beam.done(beam_size):
                break

            self.beam.set_current_state(new_hypotheses)
            self.beam.set_current_score(new_hyp_scores)

        self.beam.set_hypothesis()
        # Length-normalised score selects the best finished hypothesis.
        best_hyp = max(self.beam.completed_hypotheses, key=lambda h: h.score / len(h))
        output = torch.LongTensor(best_hyp.seq).unsqueeze(0)
        score = best_hyp.score

        return output, score

    def forward(self, beam_size, batch_H, text, is_test):
        """Dispatch: teacher forcing (train), beam search, or greedy decode."""
        if self.training:
            return self.forward_greedy(batch_H, text)
        if beam_size > 1:
            return self.forward_beam(batch_H, beam_size)
        return self.forward_greedy(batch_H, text, is_test=is_test)
207
+
HybridViT/module/component/seq_modeling/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .bilstm import *
2
+ from .vit_encoder import *
HybridViT/module/component/seq_modeling/addon_module/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .patchembed import *
HybridViT/module/component/seq_modeling/addon_module/patchembed.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch.nn as nn
3
+ import torch
4
+ from torch.nn import functional as F
5
+ from timm.models.layers.helpers import to_2tuple
6
+ from typing import Tuple, Union, List
7
+
8
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and linearly embed them.

    The stored ``img_size`` is rounded up to the nearest multiple of the
    patch size; at forward time inputs are zero-padded on the bottom/right
    so the strided convolution always tiles the image exactly.
    """
    def __init__(self, img_size=(224, 224), patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(img_size, tuple)
        patch_size = to_2tuple(patch_size)
        padded_h = patch_size[0] * math.ceil(img_size[0] / patch_size[0])
        padded_w = patch_size[1] * math.ceil(img_size[1] / patch_size[1])
        self.img_size = (padded_h, padded_w)
        self.grid_size = (padded_h // patch_size[0], padded_w // patch_size[1])
        self.patch_size = patch_size
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        # A conv whose kernel equals its stride acts as a per-patch linear map.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Return (tokens, {'height','width'} of the padded input, resized_flag)."""
        _, _, height, width = x.shape
        target_h = self.patch_size[0] * math.ceil(height / self.patch_size[0])
        target_w = self.patch_size[1] * math.ceil(width / self.patch_size[1])
        x = F.pad(x, (0, target_w - width, 0, target_h - height))
        assert x.shape[2] % self.patch_size[0] == 0 and x.shape[3] % self.patch_size[1] == 0
        tokens = self.proj(x).flatten(2).transpose(1, 2)
        resized = (x.shape[2] != self.img_size[0] or x.shape[3] != self.img_size[1])
        return tokens, {'height': x.shape[2], 'width': x.shape[3]}, resized
34
+
35
class HybridEmbed(nn.Module):
    """ CNN Feature Map Embedding
    Extract feature map from CNN, pad it up to a multiple of the patch size,
    flatten, project to embedding dim.
    """
    def __init__(self, backbone, img_size: Tuple[int], patch_size: Union[List, int] = 16,
                 feature_size=None, in_chans=3, embed_dim=768):
        """
        Args:
            backbone: CNN mapping (B, in_chans, H, W) -> (B, C, h, w), or a
                list/tuple of feature maps (the last one is used).
            img_size: input image size used to probe the backbone output shape.
            patch_size: patch edge(s) on the *feature* map. BUG FIX: the
                original declared ``patch_size=Union[List, int]``, which made
                the typing object itself the default value; 16 is now a
                usable default and the type lives in the annotation.
            feature_size: (h, w) of the backbone output; inferred with a dry
                forward pass when None.
            in_chans: number of input image channels.
            embed_dim: output embedding dimension.
        """
        super().__init__()
        assert isinstance(backbone, nn.Module)
        if isinstance(patch_size, int):
            patch_size = to_2tuple(patch_size)
        else:
            patch_size = tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.backbone = backbone
        if feature_size is None:
            with torch.no_grad():
                # NOTE Most reliable way of determining output dims is to run forward pass
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
                if isinstance(o, (list, tuple)):
                    o = o[-1]  # last feature if backbone outputs list/tuple of features
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            feature_size = to_2tuple(feature_size)
            if hasattr(self.backbone, 'feature_info'):
                feature_dim = self.backbone.feature_info.channels()[-1]
            else:
                feature_dim = self.backbone.num_features

        assert feature_size[0] >= patch_size[0] and feature_size[1] >= patch_size[1]

        div_h, mod_h = divmod(feature_size[0], patch_size[0])
        div_w, mod_w = divmod(feature_size[1], patch_size[1])

        # Nominal feature size rounded up to a multiple of the patch size.
        self.feature_size = (patch_size[0]*(div_h + (1 if mod_h > 0 else 0)), patch_size[1]*(div_w + (1 if mod_w > 0 else 0)))
        assert self.feature_size[0] % patch_size[0] == 0 and self.feature_size[1] % patch_size[1] == 0
        self.grid_size = (self.feature_size[0] // patch_size[0], self.feature_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Return (tokens, (pad_W, pad_H), {'height','width'}, resized_flag)."""
        x = self.backbone(x)
        # BUG FIX: unwrap multi-output backbones *before* reading x.shape.
        # The original performed this check only after padding, where a
        # list/tuple would already have raised.
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
        f_h, f_w = x.shape[2:]

        div_h, mod_h = divmod(f_h, self.patch_size[0])
        div_w, mod_w = divmod(f_w, self.patch_size[1])

        pad_H = self.patch_size[0]*(div_h + (1 if mod_h > 0 else 0)) - f_h
        pad_W = self.patch_size[1]*(div_w + (1 if mod_w > 0 else 0)) - f_w
        # Zero-pad bottom/right so the strided conv tiles the map exactly.
        x = F.pad(x, (0, pad_W, 0, pad_H))

        assert x.shape[2] % self.patch_size[0] == 0 and x.shape[3] % self.patch_size[1] == 0

        proj_x = self.proj(x).flatten(2).transpose(1, 2)
        return proj_x, (pad_W, pad_H), {'height': x.shape[2], 'width': x.shape[3]}, (x.shape[2] != self.feature_size[0] or x.shape[3] != self.feature_size[1])
98
+
99
class HybridEmbed1D(nn.Module):
    """ CNN Feature Map Embedding which uses 1D embed patching
    from https://arxiv.org/pdf/2111.08314.pdf, which benefits the text
    recognition task. Each feature-map row is projected with a shared Conv1d
    and the rows are then averaged, yielding one token per horizontal window.
    """
    def __init__(self, backbone, img_size: Tuple[int], feature_size=None, patch_size=1, in_chans=3, embed_dim=768):
        """
        Args:
            backbone: CNN mapping (B, in_chans, H, W) -> (B, C, h, w).
            img_size: input image size used to probe the backbone output shape.
            feature_size: (h, w) of the backbone output; inferred with a dry
                forward pass when None.
            patch_size: width of the 1-D sliding window on the feature map.
            in_chans: number of input image channels.
            embed_dim: output embedding dimension.
        """
        super().__init__()
        assert isinstance(backbone, nn.Module)
        self.img_size = img_size
        self.backbone = backbone
        self.embed_dim = embed_dim
        if feature_size is None:
            with torch.no_grad():
                # NOTE Most reliable way of determining output dims is to run forward pass
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
                if isinstance(o, (list, tuple)):
                    o = o[-1]  # last feature if backbone outputs list/tuple of features
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            feature_size = to_2tuple(feature_size)
            if hasattr(self.backbone, 'feature_info'):
                feature_dim = self.backbone.feature_info.channels()[-1]
            else:
                feature_dim = self.backbone.num_features

        self.window_width = patch_size
        assert feature_size[1] >= self.window_width
        div_w, mod_w = divmod(feature_size[1], self.window_width)
        # Nominal feature width rounded up to a multiple of the window width.
        self.feature_size = (feature_size[0], self.window_width*(div_w + (1 if mod_w > 0 else 0)))
        assert self.feature_size[1] % self.window_width == 0
        self.grid_size = (1, self.feature_size[1] // self.window_width)
        self.num_patches = self.grid_size[1]
        self.proj = nn.Conv1d(feature_dim, embed_dim, kernel_size=self.window_width, stride=self.window_width, bias=True)

    def forward(self, x):
        """Return (tokens, (pad_W,), {'height','width'}, resized_flag)."""
        batch_size = x.shape[0]
        x = self.backbone(x)
        # BUG FIX: unwrap multi-output backbones *before* reading x.shape.
        # The original performed this check only after padding, too late to
        # ever take effect.
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
        f_h, f_w = x.shape[2:]
        assert f_w >= self.window_width

        div_w, mod_w = divmod(f_w, self.window_width)
        pad_W = self.window_width*(div_w + (1 if mod_w > 0 else 0)) - f_w

        x = F.pad(x, (0, pad_W))
        assert x.shape[3] % self.window_width == 0

        # PERF FIX: apply the shared Conv1d to all rows in one batched call
        # instead of a Python loop over rows. Numerics are identical: the
        # same weights act on the same (B, C, W) slices, rows are merely
        # folded into the batch dimension.
        rows = x.permute(0, 2, 1, 3).reshape(batch_size * f_h, x.shape[1], x.shape[3])
        proj_rows = self.proj(rows)
        proj_x = proj_rows.reshape(batch_size, f_h, self.embed_dim, -1)

        # Average over the height axis, then move channels last: (B, W', E).
        proj_x = proj_x.mean(dim=1).transpose(1, 2)

        return proj_x, (pad_W, ), {'height': x.shape[2], 'width': x.shape[3]}, (x.shape[2] != self.feature_size[0] or x.shape[3] != self.feature_size[1])
HybridViT/module/component/seq_modeling/bilstm.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ __all__ = ['BiLSTM_Seq_Modeling', 'BidirectionalLSTM']
4
+
5
class BidirectionalLSTM(nn.Module):
    """One bidirectional LSTM layer followed by a linear projection.

    Maps a (batch, T, input_size) sequence of visual features to a
    (batch, T, output_size) sequence of contextual features.
    """
    def __init__(self, input_size, hidden_size, output_size):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)
        # Both directions are concatenated, hence 2 * hidden_size inputs.
        self.linear = nn.Linear(hidden_size * 2, output_size)

    def forward(self, input):
        """Project the bidirectional LSTM states down to output_size."""
        # Repack weights contiguously so cuDNN does not warn/copy on GPU.
        self.rnn.flatten_parameters()
        contextual, _ = self.rnn(input)  # (batch, T, 2 * hidden_size)
        return self.linear(contextual)   # (batch, T, output_size)
20
+
21
class BiLSTM_Seq_Modeling(nn.Module):
    """Stack of ``num_layers`` BidirectionalLSTM layers.

    The first layer maps input_size -> hidden_size, intermediate layers keep
    hidden_size, and the final layer projects to output_size.
    """
    def __init__(self, num_layers, input_size, hidden_size, output_size):
        super(BiLSTM_Seq_Modeling, self).__init__()
        self.num_layers = num_layers
        stack = [BidirectionalLSTM(input_size, hidden_size, hidden_size)]
        stack.extend(
            BidirectionalLSTM(hidden_size, hidden_size, hidden_size)
            for _ in range(num_layers - 2)
        )
        stack.append(BidirectionalLSTM(hidden_size, hidden_size, output_size))
        self.lstm = nn.Sequential(*stack)

    def forward(self, input):
        """Run the whole BiLSTM stack: (batch, T, input_size) -> (batch, T, output_size)."""
        return self.lstm(input)
HybridViT/module/component/seq_modeling/vit/utils.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+ import warnings
4
+
5
+
6
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
7
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
8
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
9
+ def norm_cdf(x):
10
+ # Computes standard normal cumulative distribution function
11
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
12
+
13
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
14
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
15
+ "The distribution of values may be incorrect.",
16
+ stacklevel=2)
17
+
18
+ with torch.no_grad():
19
+ # Values are generated by using a truncated uniform distribution and
20
+ # then using the inverse CDF for the normal distribution.
21
+ # Get upper and lower cdf values
22
+ l = norm_cdf((a - mean) / std)
23
+ u = norm_cdf((b - mean) / std)
24
+
25
+ # Uniformly fill tensor with values from [l, u], then translate to
26
+ # [2l-1, 2u-1].
27
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
28
+
29
+ # Use inverse cdf transform for normal distribution to get truncated
30
+ # standard normal
31
+ tensor.erfinv_()
32
+
33
+ # Transform to proper mean, std
34
+ tensor.mul_(std * math.sqrt(2.))
35
+ tensor.add_(mean)
36
+
37
+ # Clamp to ensure it's in the proper range
38
+ tensor.clamp_(min=a, max=b)
39
+ return tensor
40
+
41
+
42
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""Fill ``tensor`` in-place with samples from a truncated normal.

    Values are drawn from :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with anything outside :math:`[a, b]` redrawn until it lies within the
    bounds. The sampler works best when
    :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        The same tensor, filled in-place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
HybridViT/module/component/seq_modeling/vit/vision_transformer.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from functools import partial
4
+ from collections import OrderedDict
5
+ from ...common import DropPath
6
+ from .utils import trunc_normal_
7
+
8
+
9
class Mlp(nn.Module):
    """Transformer feed-forward block: Linear -> act -> drop -> Linear -> drop."""
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        # Hidden/output widths default to the input width when not given.
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        # One shared dropout module, applied after each linear layer.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
+
27
+
28
class ConvFFN(nn.Module):
    """Placeholder for a convolutional feed-forward block.

    Currently identical to a bare ``nn.Module``: no layers are defined and
    the constructor simply forwards its arguments to the base class.
    """
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
31
+
32
+
33
class Attention(nn.Module):
    """Multi-head self-attention as used in ViT."""
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        # Fused projection producing q, k and v in one matmul.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads
        qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        # Index rather than tuple-unpack to keep torchscript happy.
        q, k, v = qkv[0], qkv[1], qkv[2]

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        out = torch.matmul(weights, v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(out))
59
+
60
+
61
class Block(nn.Module):
    """Pre-norm Transformer encoder block: MHSA and MLP, each with a residual."""
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # Stochastic depth on both residual branches; identity when rate is 0.
        self.drop_path = nn.Identity() if drop_path <= 0. else DropPath(drop_path)
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
78
+
79
+
80
class VisionTransformer(nn.Module):
    """ Vision Transformer

    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
        https://arxiv.org/abs/2010.11929

    NOTE(review): this base class leaves ``patch_embed`` as ``None``; it is
    only usable through subclasses that install a real patch embedding (see
    the encoders in vit_encoder.py). Calling forward() directly would fail.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module
            norm_layer: (nn.Module): normalization layer
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

        # patch_embed is deliberately unset here; subclasses replace it.
        self.patch_embed = None
        # While patch_embed is None this falls back to 128, so pos_embed is
        # only a placeholder until a subclass re-creates it at the real size.
        num_patches = getattr(self.patch_embed, 'num_patches', 128)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Representation layer (pre-logits bottleneck), optional.
        if representation_size:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(embed_dim, representation_size)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()

        # Classifier head
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal weights / zero bias for Linear; ones/zeros for LayerNorm.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        # Parameter names an optimizer setup should exclude from weight decay.
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        # Swap the classification head for a new number of classes.
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        # Tokenize, prepend [CLS], add positional embedding, run the blocks,
        # then return the normalized [CLS] token (through optional pre-logits).
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)[:, 0]
        x = self.pre_logits(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
HybridViT/module/component/seq_modeling/vit_encoder.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+ from torch.nn import functional as F
4
+ from .vit.utils import trunc_normal_
5
+ from .vit.vision_transformer import VisionTransformer
6
+ from ..feature_extractor.clova_impl import ResNet_FeatureExtractor
7
+ from .addon_module import *
8
+ from ..common.mae_posembed import get_2d_sincos_pos_embed
9
+
10
+ __all__ = ['ViTEncoder', 'ViTEncoderV2', 'ViTEncoderV3', 'TRIGBaseEncoder', 'create_vit_modeling']
11
+
12
class ViTEncoder(VisionTransformer):
    '''ViT encoder that installs a real patch embedding over the base class.

    Depending on ``hybrid_backbone`` it tokenizes either raw pixels
    (PatchEmbed) or CNN feature maps (HybridEmbed), and resizes the learned
    position embedding on the fly when the input resolution differs from
    the one configured at construction time.

    NOTE(review): __init__ indexes kwargs['hybrid_backbone'],
    kwargs['img_size'], kwargs['in_chans'], kwargs['patch_size'] and
    kwargs['embed_dim'] directly, so these must all be passed as keyword
    arguments or construction raises KeyError.
    '''
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if kwargs['hybrid_backbone'] is None:
            # Plain ViT: embed fixed-size pixel patches directly.
            self.patch_embed = PatchEmbed(
                img_size=kwargs['img_size'],
                in_chans=kwargs['in_chans'],
                patch_size=kwargs['patch_size'],
                embed_dim=kwargs['embed_dim'],
            )
        else:
            # Hybrid ViT: embed patches of the CNN feature map instead.
            self.patch_embed = HybridEmbed(
                backbone=kwargs['hybrid_backbone'],
                img_size=kwargs['img_size'],
                in_chans=kwargs['in_chans'],
                patch_size=kwargs['patch_size'],
                embed_dim=kwargs['embed_dim'],
            )
        num_patches = self.patch_embed.num_patches
        # Re-create pos_embed: the base class sized it with a placeholder count.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches+1, kwargs['embed_dim']))
        self.emb_height = self.patch_embed.grid_size[0]
        self.emb_width = self.patch_embed.grid_size[1]
        trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)

    def reset_classifier(self, num_classes):
        # Narrower override of the base-class hook (drops global_pool).
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def interpolating_pos_embedding(self, embedding, height, width):
        """
        Bicubically resize the patch position embedding to a new token grid.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        npatch = embedding.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and height == width:
            return self.pos_embed

        # The [CLS] slot is never interpolated; only the patch grid is.
        class_pos_embedding = self.pos_embed[:, 0]
        patch_pos_embedding = self.pos_embed[:, 1:]
        dim = self.pos_embed.shape[-1]

        h0 = height // self.patch_embed.patch_size[0]
        w0 = width // self.patch_embed.patch_size[1]
        # add a small number to avoid floating point error
        # https://github.com/facebookresearch/dino/issues/8

        h0 = h0 + 0.1
        w0 = w0 + 0.1

        patch_pos_embedding = nn.functional.interpolate(
            patch_pos_embedding.reshape(1, self.emb_height, self.emb_width, dim).permute(0, 3, 1, 2),
            scale_factor=(h0 / self.emb_height, w0 / self.emb_width),
            mode='bicubic',
            align_corners=False
        )
        assert int(h0) == patch_pos_embedding.shape[-2] and int(w0) == patch_pos_embedding.shape[-1]
        patch_pos_embedding = patch_pos_embedding.permute(0, 2, 3, 1).view(1, -1, dim)
        class_pos_embedding = class_pos_embedding.unsqueeze(0)

        return torch.cat((class_pos_embedding, patch_pos_embedding), dim=1)

    def forward_features(self, x):
        # Unlike the base class, this returns the full normalized token
        # sequence plus the patch embedder's padding/size metadata.
        B, C, _, _ = x.shape

        x, pad_info, size, interpolating_pos = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

        if interpolating_pos:
            # Input resolution differs from the configured one: resize pos_embed.
            x = x + self.interpolating_pos_embedding(x, size['height'], size['width'])
        else:
            x = x + self.pos_embed

        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)

        return x, pad_info, size
98
+
99
+
100
class TRIGBaseEncoder(ViTEncoder):
    '''
    1-D hybrid encoder from TRIG: https://arxiv.org/pdf/2111.08314.pdf

    Replaces the 2-D patch embedding with HybridEmbed1D, which slides a 1-D
    window along the width of the CNN feature map, so the token grid is
    effectively 1 x W and only the width axis of the position embedding is
    ever interpolated.
    '''
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.patch_embed = HybridEmbed1D(
            backbone=kwargs['hybrid_backbone'],
            img_size=kwargs['img_size'],
            in_chans=kwargs['in_chans'],
            patch_size=kwargs['patch_size'],
            embed_dim=kwargs['embed_dim'],
        )
        num_patches = self.patch_embed.num_patches
        # Re-create pos_embed for the 1-D token count.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches+1, kwargs['embed_dim']))
        # 1-D tokenization: the grid is a single row.
        self.emb_height = 1
        self.emb_width = self.patch_embed.grid_size[1]
        trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)

    def interpolating_pos_embedding(self, embedding, height, width):
        """
        Resize the (1-D) patch position embedding along the width axis only.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        npatch = embedding.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and height == width:
            return self.pos_embed

        # The [CLS] slot is never interpolated; only the patch row is.
        class_pos_embedding = self.pos_embed[:, 0]
        patch_pos_embedding = self.pos_embed[:, 1:]
        dim = self.pos_embed.shape[-1]

        w0 = width // self.patch_embed.window_width

        # add a small number to avoid floating point error
        # https://github.com/facebookresearch/dino/issues/8

        w0 = w0 + 0.1

        patch_pos_embedding = nn.functional.interpolate(
            patch_pos_embedding.reshape(1, self.emb_height, self.emb_width, dim).permute(0, 3, 1, 2),
            scale_factor=(1, w0 / self.emb_width),
            mode='bicubic',
            align_corners=False
        )

        assert int(w0) == patch_pos_embedding.shape[-1]
        patch_pos_embedding = patch_pos_embedding.permute(0, 2, 3, 1).view(1, -1, dim)
        class_pos_embedding = class_pos_embedding.unsqueeze(0)

        return torch.cat((class_pos_embedding, patch_pos_embedding), dim=1)

    def forward_features(self, x):
        # Same contract as ViTEncoder.forward_features, but pad_info is the
        # 1-tuple produced by HybridEmbed1D.
        B, _, _, _ = x.shape
        x, padinfo, size, interpolating_pos = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks

        x = torch.cat((cls_tokens, x), dim=1)  # cls_tokens is init_embedding in TRIG paper

        if interpolating_pos:
            x = x + self.interpolating_pos_embedding(x, size['height'], size['width'])
        else:
            x = x + self.pos_embed

        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)

        return x, padinfo, size
175
+
176
+
177
class ViTEncoderV2(ViTEncoder):
    """ViT encoder variant that truncates the learned positional table to
    the actual token count instead of interpolating it."""

    def forward(self, x):
        batch = x.shape[0]

        tokens, pad_info, size, _ = self.patch_embed(x)
        n_patches = tokens.shape[1]

        # Prepend one class token per sample (impl borrowed from Phil Wang).
        cls_tok = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat((cls_tok, tokens), dim=1)

        # Take only the first (n_patches + 1) positional slots.
        tokens = self.pos_drop(tokens + self.pos_embed[:, :n_patches + 1])

        for layer in self.blocks:
            tokens = layer(tokens)

        return self.norm(tokens), pad_info, size
195
+
196
class ViTEncoderV3(ViTEncoder):
    """ViT encoder with a frozen 2-D sin-cos positional table (MAE-style)
    replacing the parent's learned embedding."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Drop the parent's learned table before installing the fixed
        # (requires_grad=False) sin-cos one.
        if hasattr(self, 'pos_embed'):
            del self.pos_embed
        num_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches+1, kwargs['embed_dim']), requires_grad=False)
        self.initialize_posembed()

    def initialize_posembed(self):
        """Fill pos_embed with 2-D sin-cos values over the patch grid
        (cls_token=True reserves an extra leading slot)."""
        pos_embed = get_2d_sincos_pos_embed(
            self.pos_embed.shape[-1],
            self.patch_embed.grid_size[0],
            self.patch_embed.grid_size[1],
            cls_token=True
        )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

    def forward(self, x):
        """Embed, prepend cls token, add the (truncated) fixed positional
        table, run blocks + norm. Returns (tokens, pad_info, size)."""
        B, _, _, _ = x.shape

        x, pad_info, size, _ = self.patch_embed(x)
        _, numpatches, *_ = x.shape

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

        # Truncate the positional table to the actual token count.
        x = x + self.pos_embed[:, :(numpatches + 1)]
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)

        return x, pad_info, size
232
+
233
def create_vit_modeling(opt):
    """Instantiate the ViT-family sequence-modeling encoder described by `opt`.

    Args:
        opt: full experiment config. Reads ``opt['SequenceModeling']['params']``
            (patching style, patch size, depth, heads, hidden size, optional
            CNN backbone spec) plus ``opt['imgH']`` / ``opt['max_dimension']``.

    Returns:
        One of ViTEncoder / ViTEncoderV2 / ViTEncoderV3 / TRIGBaseEncoder.
    """
    seq_modeling = opt['SequenceModeling']['params']
    backbone_cfg = seq_modeling['backbone']

    # Default to a pure-ViT stem (no CNN). Only an explicit 'resnet' backbone
    # builds one; 'cnn' or any other/absent name yields None as before, but
    # there is no longer any path that leaves `backbone` unbound.
    backbone = None
    if backbone_cfg is not None and backbone_cfg['name'] == 'resnet':
        param_kwargs = {}
        if backbone_cfg.get('pretrained', None) is not None:
            param_kwargs['pretrained'] = backbone_cfg['pretrained']
        if backbone_cfg.get('weight_dir', None) is not None:
            param_kwargs['weight_dir'] = backbone_cfg['weight_dir']
        print('kwargs', param_kwargs)  # kept: original debug trace

        backbone = ResNet_FeatureExtractor(
            backbone_cfg['input_channel'],
            backbone_cfg['output_channel'],
            backbone_cfg['gcb'],
            **param_kwargs
        )

    # A fixed imgH pins the height while keeping the configured max width.
    max_dimension = (opt['imgH'], opt['max_dimension'][1]) if opt['imgH'] else opt['max_dimension']

    if seq_modeling['patching_style'] != '2d':
        encoder = TRIGBaseEncoder            # 1-D patching (TRIG)
    elif seq_modeling.get('fix_embed', False):
        encoder = ViTEncoderV3               # frozen sin-cos positional table
    elif not seq_modeling.get('interpolate_embed', True):
        encoder = ViTEncoderV2               # truncate positional table
    else:
        encoder = ViTEncoder                 # interpolate positional table

    return encoder(
        img_size=max_dimension,
        patch_size=seq_modeling['patch_size'],
        in_chans=seq_modeling['input_channel'],
        depth=seq_modeling['depth'],
        num_classes=0,
        embed_dim=seq_modeling['hidden_size'],
        num_heads=seq_modeling['num_heads'],
        hybrid_backbone=backbone
    )
HybridViT/module/converter/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .builder import create_converter
2
+ from .attn_converter import AttnLabelConverter
3
+ from .tfm_converter import TFMLabelConverter
HybridViT/module/converter/attn_converter.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
class AttnLabelConverter(object):
    """Text <-> index conversion for attention-based prediction heads.

    Vocabulary layout: [GO], [s], [UNK] followed by the task charset, so
    START() == 0, END() == 1, UNK() == 2. [GO] doubles as the padding index.
    """

    list_token = ['[GO]', '[s]', '[UNK]']

    def __init__(self, character, device):
        # Special tokens come first so their indices are fixed.
        self.character = AttnLabelConverter.list_token + character
        self.device = device
        self.dict = {ch: idx for idx, ch in enumerate(self.character)}
        self.ignore_idx = self.dict['[GO]']

    @staticmethod
    def START() -> int:
        return AttnLabelConverter.list_token.index('[GO]')

    @staticmethod
    def END() -> int:
        return AttnLabelConverter.list_token.index('[s]')

    @staticmethod
    def UNK() -> int:
        return AttnLabelConverter.list_token.index('[UNK]')

    def encode(self, text, batch_max_length=25):
        """Encode a batch of strings as [GO] c1 ... cn [s], [GO]-padded.

        Returns (LongTensor[batch, batch_max_length + 2], IntTensor lengths)
        where each length counts the characters plus the trailing [s].
        """
        lengths = [len(s) + 1 for s in text]  # +1 for [s] at end of sentence.
        # batch_max_length = max(length) is not allowed for multi-gpu setting.
        batch_max_length += 1
        # One extra slot for the leading [GO]; unused tail stays [GO] (0).
        batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(0)
        unk = self.dict['[UNK]']
        for row, sample in enumerate(text):
            chars = list(sample)
            if len(chars) > batch_max_length:
                chars = chars[:batch_max_length - 1]
            chars.append('[s]')
            ids = [self.dict.get(ch, unk) for ch in chars]
            batch_text[row][1:1 + len(ids)] = torch.LongTensor(ids)  # slot 0 keeps [GO]
        return (batch_text.to(self.device), torch.IntTensor(lengths).to(self.device))

    def decode(self, text_index, token_level='word'):
        """Map each index row back to a string; 'word' joins with spaces."""
        sep = ' ' if token_level == 'word' else ''
        return [
            sep.join(self.character[i] for i in text_index[row, :])
            for row in range(text_index.shape[0])
        ]

    def detokenize(self, token_ids):
        """Token ids -> per-sample token lists, truncated at the first [s]."""
        batch_tokens = []
        for seq in token_ids:
            tokens = []
            for idx in seq:
                symbol = self.character[idx]
                if symbol == '[s]':
                    break
                tokens.append(symbol)
            batch_tokens.append(tokens)
        return batch_tokens
HybridViT/module/converter/builder.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .attn_converter import AttnLabelConverter
2
+
3
def create_converter(config, device):
    """Build the label converter matching the configured prediction head.

    Args:
        config: parsed experiment config; ``config['Prediction']['name']``
            selects the head and ``config['character']`` supplies the vocab.
        device: torch device string the converter places tensors on.

    Returns:
        An ``AttnLabelConverter`` for attention-based heads.

    Raises:
        ValueError: if no converter is registered for the head. (The
        original silently fell through and crashed with an
        UnboundLocalError on ``return converter``.)
    """
    if 'Attn' in config['Prediction']['name']:
        return AttnLabelConverter(config['character'], device)
    raise ValueError(
        "No converter registered for prediction head "
        f"{config['Prediction']['name']!r}"
    )
HybridViT/module/converter/tfm_converter.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
class TFMLabelConverter(object):
    """Text <-> index conversion for the transformer prediction head.

    Vocabulary layout: [PAD], [GO], [s], [UNK] followed by the task charset,
    so PAD() == 0, START() == 1, END() == 2, UNK() == 3.
    """

    list_token = ['[PAD]', '[GO]', '[s]', '[UNK]']

    def __init__(self, character, device):
        # Special tokens come first so their indices are fixed.
        self.character = TFMLabelConverter.list_token + character
        self.device = device
        self.dict = {ch: idx for idx, ch in enumerate(self.character)}
        self.ignore_idx = self.dict['[PAD]']

    @staticmethod
    def START() -> int:
        return TFMLabelConverter.list_token.index('[GO]')

    @staticmethod
    def END() -> int:
        return TFMLabelConverter.list_token.index('[s]')

    @staticmethod
    def UNK() -> int:
        return TFMLabelConverter.list_token.index('[UNK]')

    @staticmethod
    def PAD() -> int:
        return TFMLabelConverter.list_token.index('[PAD]')

    def encode(self, text, batch_max_length=25):
        """Encode a batch of strings as [GO] c1 ... cn [s] [PAD]...

        Returns (LongTensor[batch, batch_max_length + 2], IntTensor lengths)
        where each length counts the characters plus the trailing [s].
        """
        lengths = [len(s) + 1 for s in text]  # +1 for [s] at end of sentence.
        batch_max_length += 1
        unk = self.dict['[UNK]']
        batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(self.ignore_idx)
        for row, sample in enumerate(text):
            chars = list(sample)
            if len(chars) > batch_max_length:
                chars = chars[:batch_max_length - 1]
            chars.append('[s]')
            ids = [self.dict.get(ch, unk) for ch in chars]
            batch_text[row][0] = self.dict['[GO]']
            batch_text[row][1:1 + len(ids)] = torch.LongTensor(ids)
        return (batch_text.to(self.device), torch.IntTensor(lengths).to(self.device))

    def decode(self, text_index, token_level='word'):
        """Map each index row back to a string; 'word' joins with spaces."""
        sep = ' ' if token_level == 'word' else ''
        return [
            sep.join(self.character[i] for i in text_index[row, :])
            for row in range(text_index.shape[0])
        ]

    def detokenize(self, token_ids):
        """Token ids -> per-sample token lists, truncated at the first [s]."""
        batch_tokens = []
        for seq in token_ids:
            tokens = []
            for idx in seq:
                symbol = self.character[idx]
                if symbol == '[s]':
                    break
                tokens.append(symbol)
            batch_tokens.append(tokens)
        return batch_tokens
75
+
76
if __name__ == '__main__':
    # Round-trip smoke test: encode three Vietnamese address strings with a
    # character-level vocabulary, then decode each row back to text.
    vocab = ['S', 'ố', ' ', '2', '5', '3', 'đ', 'ư', 'ờ', 'n', 'g', 'T', 'r', 'ầ', 'P', 'h', 'ú', ',', 'ị', 't', 'ấ', 'N', 'a', 'm', 'á', 'c', 'H', 'u', 'y', 'ệ', 'ả', 'i', 'D', 'ơ', '8', '9', 'Đ', 'B', 'ộ', 'L', 'ĩ', '6', 'Q', 'ậ', 'ì', 'ạ', 'ồ', 'C', 'í', 'M', '4', 'E', '/', 'K', 'p', '1', 'A', 'x', 'ặ', 'ễ', '0', 'â', 'à', 'ế', 'ừ', 'ê', '-', '7', 'o', 'V', 'ô', 'ã', 'G', 'ớ', 'Y', 'I', 'ề', 'ò', 'l', 'R', 'ỹ', 'ủ', 'X', "'", 'e', 'ắ', 'ổ', 'ằ', 'k', 's', '.', 'ợ', 'ù', 'ứ', 'ă', 'ỳ', 'ẵ', 'ý', 'ó', 'ẩ', 'ọ', 'J', 'ũ', 'ữ', 'ự', 'õ', 'ỉ', 'ỏ', 'v', 'd', 'Â', 'W', 'U', 'O', 'é', 'ở', 'ỷ', '(', ')', 'ử', 'è', 'ể', 'ụ', 'ỗ', 'F', 'q', 'ẻ', 'ỡ', 'b', 'ỵ', 'Ứ', '#', 'ẽ', 'Ô', 'Ê', 'Ơ', '+', 'z', 'Ấ', 'w', 'Z', '&', 'Á', '~', 'f', 'Ạ', 'Ắ', 'j', ':', 'Ă', '<', '>', 'ẹ', '_', 'À', 'Ị', 'Ư', 'Ễ']
    text = [
        "190B Trần Quang Khải, Phường Tân Định, Quận 1, TP Hồ Chí Minh",
        "164/2B, Quốc lộ 1A, Phường Lê Bình, Quận Cái Răng, Cần Thơ",
        "Cẩm Huy, Huyện Cẩm Xuyên, Hà Tĩnh"
    ]
    tfm_convert = TFMLabelConverter(vocab, 'cpu')
    texts, lengths = tfm_convert.encode(text, 70)
    print(texts)
    for text in texts:
        print('Encode', text)
        text = text.unsqueeze(0)  # decode expects a (batch, seq) tensor
        decode_text = tfm_convert.decode(text, 'char')
        print('Decode', decode_text)
HybridViT/recog_flow.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import Any
3
+ from collections import OrderedDict
4
+ import re
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.nn import functional as F
9
+ from PIL import Image
10
+ from timm.models.resnetv2 import ResNetV2
11
+
12
+ from .recognizers.build_model import Model
13
+ from .module.converter import AttnLabelConverter, TFMLabelConverter
14
+ from .helper import resize
15
+
16
class MathRecognition(object):
    """End-to-end math-expression recognition pipeline.

    Wraps resize preprocessing, the recognizer ``Model``, label decoding and
    LaTeX-specific postprocessing behind a single callable.
    """

    def __init__(self, config, resizer):
        """
        Args:
            config: experiment config dict; must provide 'vocab',
                'weight_path', 'Prediction', 'character', 'batch_max_length'.
            resizer: resize helper passed through to ``resize``.
        """
        self.args = config
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.args["device"] = device
        self.device = device
        self._prepare_vocab()
        self.model = self._get_model()
        self.resizer = resizer

    def _mapping_ckpt(self, state_dict):
        """Remap legacy checkpoint keys onto the current module layout
        (e.g. 'FeatureExtraction.*' -> 'featextractor.FeatureExtraction.*').
        'Transformation' weights are dropped entirely."""
        new_state_dict = OrderedDict()
        for name, param in state_dict.items():
            if name.startswith('Transformation'):
                continue
            elif name.startswith('FeatureExtraction'):
                new_state_dict[name.replace('FeatureExtraction', 'featextractor.FeatureExtraction')] = param
            elif name.startswith('SequenceModeling'):
                new_state_dict[name.replace('SequenceModeling', 'seqmodeler.SequenceModeling')] = param
            elif name.startswith('Prediction'):
                new_state_dict[name.replace('Prediction', 'predicter.Prediction')] = param
            else:
                new_state_dict[name] = param
        return new_state_dict

    def _get_model(self):
        """Build the recognizer, load remapped weights, switch to eval mode
        and move it to the selected device (DataParallel when several GPUs
        are visible)."""
        model = Model(self.args)
        state_dict = torch.load(self.args["weight_path"], map_location='cpu')
        model.load_state_dict(self._mapping_ckpt(state_dict))
        model = model.eval()

        if self.device == 'cuda':
            if torch.cuda.device_count() > 1:
                model = nn.DataParallel(model).to(self.device)
            else:
                model.to(self.device)

        return model

    def _prepare_vocab(self):
        """Extend config['character'] with the vocab file (one token per
        line) and build the matching label converter."""
        # `with` closes the file; the original's extra f.close() was removed.
        with open(self.args["vocab"], 'rt') as f:
            for line in f:
                self.args["character"] += [line.rstrip()]

        if 'Attn' in self.args['Prediction']['name']:
            self.converter = AttnLabelConverter(self.args["character"], self.device)
        else:
            self.converter = TFMLabelConverter(self.args["character"], self.device)

        self.args["num_class"] = len(self.converter.character)

    def _preprocess(self, image: 'Image.Image'):
        """Resize/normalize a PIL image into the model's input tensor."""
        return resize(self.resizer, image, self.args)

    def _postprocess(self, s: str):
        """Normalize predicted LaTeX: squeeze spaces inside the arguments of
        \\operatorname/\\mathrm-style commands and inside \\hspace/\\vspace
        lengths. (Unused locals from the original were removed.)"""
        text_reg = r'(\\(operatorname|mathrm|mathbf|mathsf|mathit|mathfrak|mathnormal)\s?\*? {.*?})'
        # Collapse e.g. "\mathrm {x y}" -> "\mathrm{xy}": substitute each
        # match, in order, with its space-stripped form.
        names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
        s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
        news = s

        # Strip spaces inside \hspace{...} / \vspace{...} arguments.
        for space in ("hspace", "vspace"):
            new_l = ""
            last = 0
            for m in re.finditer(space + " {(.*?)}", news):
                new_l += news[last:m.start(1)] + m.group(1).replace(" ", "")
                last = m.end(1)
            news = new_l + news[last:]

        return news

    def __call__(self, image: 'Image.Image', name=None, *arg: Any, **kwargs):
        """Recognize one RGB PIL image; returns the postprocessed LaTeX string."""
        assert image.mode == 'RGB', 'input image must be RGB image'
        with torch.no_grad():
            img_tensor = self._preprocess(image).to(self.device)
            # Decoder prompt: [GO]-filled index tensor of max decode length.
            text_for_pred = torch.LongTensor(1, self.args["batch_max_length"] + 1).fill_(0).to(self.device)
            preds_index, _, _ = self.model(img_tensor, text_for_pred, is_train=False, is_test=True)
            pred_str = self.converter.decode(preds_index, self.args.get('token_level', 'word'))[0]

        # Truncate at the end-of-sequence marker.
        # NOTE(review): if '[s]' is absent, find() returns -1 and the slice
        # drops the last character — confirm this is intended.
        pred_EOS = pred_str.find('[s]')
        pred_str = pred_str[:pred_EOS]

        return self._postprocess(pred_str)
HybridViT/recognizers/__init__.py ADDED
File without changes
HybridViT/recognizers/build_feat.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from ..module.component.feature_extractor.clova_impl import VGG_FeatureExtractor, ResNet_FeatureExtractor
4
+
5
+
6
class FeatExtractorBuilder(nn.Module):
    """Builds the CNN feature-extraction stage (VGG / ResNet), or an identity
    pass-through when a ViT-style sequence modeler consumes raw pixels."""

    def __init__(self, flow: dict, config):
        super().__init__()
        self.config = config
        self.flow = flow
        self.feat_name = flow['Feat']

        if self.feat_name != 'None':
            # mean_height=True squeezes the feature height to 1 by pooling;
            # otherwise the height slices are flattened and linearly projected.
            mean_height = config['FeatureExtraction']['params'].pop('mean_height', True)

            if self.feat_name == 'VGG':
                self.FeatureExtraction = VGG_FeatureExtractor(**config['FeatureExtraction']['params'])
                self.FeatureExtraction_output = config['FeatureExtraction']['params']['output_channel']
            elif self.feat_name == 'ResNet':
                self.FeatureExtraction = ResNet_FeatureExtractor(**config['FeatureExtraction']['params'])
                self.FeatureExtraction_output = config['FeatureExtraction']['params']['output_channel']

            if mean_height:
                self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1
            else:
                # NOTE(review): the *3 factor assumes the extractor leaves a
                # feature height of 3 before flattening — confirm against the
                # chosen backbone's output shape.
                self.proj_feat = nn.Linear(self.FeatureExtraction_output*3, self.FeatureExtraction_output)
        else:
            if flow['Seq'] not in ['ViT', 'MS_ViT', 'MS_ViTV2', 'MS_ViTV3', 'ViG']:
                raise Exception('No FeatureExtraction module specified')
            else:
                # ViT-family modelers patch-embed raw pixels themselves.
                self.FeatureExtraction = nn.Identity()

    def forward(self, input):
        """Extract features; for BiLSTM flows also collapse the height axis
        into a (batch, width, channel) sequence."""
        visual_feature = self.FeatureExtraction(input)

        if self.flow['Seq'] in ['BiLSTM', 'BiLSTM_3L']:
            if hasattr(self, 'AdaptiveAvgPool'):
                visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2))  # [b, c, h, w] -> [b, w, c, 1]
                visual_feature = visual_feature.squeeze(3)
            else:
                visual_feature = visual_feature.permute(0, 3, 1, 2)
                visual_feature = visual_feature.flatten(start_dim=-2)  # [b, c, h, w] -> [b, w, c*h]
                visual_feature = self.proj_feat(visual_feature)

        return visual_feature
HybridViT/recognizers/build_model.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from .build_feat import FeatExtractorBuilder
3
+ from .build_seq import SeqModelingBuilder
4
+ from .build_pred import PredictBuilder
5
+
6
class Model(nn.Module):
    """Three-stage recognizer: feature extraction -> sequence modeling ->
    prediction, each stage built from the corresponding config section."""

    def __init__(self, opt):
        super(Model, self).__init__()
        self.opt = opt

        stages = {
            'Feat': opt['FeatureExtraction']['name'],
            'Seq': opt['SequenceModeling']['name'],
            'Pred': opt['Prediction']['name'],
        }
        self.stages = stages
        # ViT-style sequence modelers embed raw pixels themselves, so a CNN
        # feature extractor must be disabled for them.
        if stages['Seq'].__contains__("Vi"): assert stages['Feat'] == 'None'

        """ FeatureExtraction """
        self.featextractor = FeatExtractorBuilder(stages, opt)
        # Output width is absent when the extractor is an identity.
        FeatureExtraction_output = getattr(self.featextractor, 'FeatureExtraction_output', None)

        """ Sequence modeling"""
        self.seqmodeler = SeqModelingBuilder(stages, opt, FeatureExtraction_output)
        SequenceModeling_output = getattr(self.seqmodeler, 'SequenceModeling_output', None)

        """ Prediction """
        self.predicter = PredictBuilder(stages, opt, SequenceModeling_output)

    def forward_encoder(self, input, *args, **kwargs):
        """Run feature extraction + sequence modeling.
        Returns (contextual_feature, output_shape, pad_info)."""
        """ Feature extraction stage """
        visual_feature = self.featextractor(input)
        """ Sequence modeling stage """
        contextual_feature, output_shape , feat_pad = self.seqmodeler(visual_feature, *args, **kwargs)
        return contextual_feature, output_shape, feat_pad

    def forward_decoder(self, contextual_feature, text, is_train=True,
                        is_test=False, rtl_text=None
                        ):
        """Run the prediction head.
        Returns (prediction, logits, decoder_attn, addition_outputs)."""
        """ Prediction stage """
        prediction, logits, decoder_attn, addition_outputs = self.predicter(contextual_feature, text, is_train, is_test, rtl_text)

        return prediction, logits, decoder_attn, addition_outputs

    def forward(self, input, text, is_train=True, is_test=False, rtl_text=None):
        """Full pipeline. Returns (prediction, logits, addition_outputs);
        when available, decoder attention is reshaped onto the feature grid
        and exposed in addition_outputs for visualization."""
        contextual_feature, output_shape, feat_pad = self.forward_encoder(input)
        prediction, logits, decoder_attn, addition_outputs = self.forward_decoder(
            contextual_feature, text=text, is_train=is_train, is_test=is_test, rtl_text=rtl_text
        )

        if decoder_attn is not None and output_shape is not None:
            # ViT + 1-D attention carries a cls token: drop it before
            # reshaping the attention back onto the (width, height) grid.
            if self.stages['Pred'] == 'Attn' and self.stages['Seq'] == 'ViT':
                decoder_attn = decoder_attn[:, 1:]
            decoder_attn = decoder_attn.reshape(-1, output_shape[0], output_shape[1])

            addition_outputs.update(
                {
                    'decoder_attn': decoder_attn,
                    'feat_width': output_shape[0],
                    'feat_height': output_shape[1],
                    'feat_pad': feat_pad,
                }
            )

        return prediction, logits, addition_outputs
66
+
67
if __name__ == '__main__':
    # Smoke test: build the model from a local experiment config and push a
    # dummy grayscale 32x224 image through it.
    import yaml
    import torch

    # NOTE(review): developer-machine absolute path — this demo only runs
    # where that config file exists.
    with open('/media/huynhtruc0309/DATA/Math_Expression/my_source/Math_Recognition/config/train/paper_experiments/report_paper/best_model_noaugment/experiment_1806.yaml', 'r') as f:
        config = yaml.safe_load(f)

    config['num_class'] = 499
    config['device'] = 'cpu'
    model = Model(config)

    a = torch.rand(1, 1, 32, 224)

    # NOTE(review): Model.forward requires a `text` argument; this call
    # passes only the image and will raise TypeError — confirm/repair the demo.
    output = model(a)

    print('pred', output[0].shape)
HybridViT/recognizers/build_pred.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from ..module.component.prediction_head import (Attention,
4
+ AttentionV2,
5
+ TransformerPrediction
6
+ )
7
+
8
+
9
class PredictBuilder(nn.Module):
    """Builds and runs the configured prediction head: a CTC linear layer,
    1-D attention decoders (Attn / Attnv2 / Multistage_Attn) or a
    transformer decoder (TFM)."""

    def __init__(self, flow, config, SequenceModeling_output):
        super().__init__()
        self.flow =flow
        self.config=config
        if flow['Pred'] == 'CTC':
            # Plain per-frame classifier over the sequence features.
            self.Prediction = nn.Linear(SequenceModeling_output, config['num_class'])

        elif flow['Pred'] == 'Attn':
            config['Prediction']['params']['num_classes'] = config['num_class']
            config['Prediction']['params']['device'] = config['device']
            self.Prediction = Attention(
                **config['Prediction']['params']
            )
        elif flow['Pred'] == 'Attnv2':
            config['Prediction']['params']['num_classes'] = config['num_class']
            config['Prediction']['params']['device'] = config['device']
            self.Prediction = AttentionV2(
                **config['Prediction']['params']
            )
        elif flow['Pred'] == 'Multistage_Attn':
            # Same head class as Attnv2; kept as a separate config branch.
            config['Prediction']['params']['num_classes'] = config['num_class']
            config['Prediction']['params']['device'] = config['device']
            self.Prediction = AttentionV2(
                **config['Prediction']['params']
            )
        elif flow['Pred'] == 'TFM':
            config['Prediction']['params']['num_classes'] = config['num_class']
            config['Prediction']['params']['device'] = config['device']
            self.Prediction = TransformerPrediction(
                **config['Prediction']['params']
            )
        else:
            raise ValueError('Prediction name is not suppported')

    def forward(self, contextual_feature, text, is_train=True, is_test=False, rtl_text=None):
        """Run the head.
        Returns (prediction, logits, decoder_attn, addition_outputs).

        NOTE(review): only the Attn/Attnv2 and TFM branches bind both
        `prediction` and `logits`. The CTC branch never binds `logits`, and
        a 'Multistage_Attn' head matches no branch at all, so both of those
        paths raise UnboundLocalError at the return below — confirm whether
        those heads are actually exercised through this builder.
        """
        beam_size = self.config.get('beam_size', 1)

        addition_outputs = {}
        decoder_attn = None

        if self.flow['Pred'] == 'CTC':
            prediction = self.Prediction(contextual_feature.contiguous())

        elif self.flow['Pred'] in ['Attn', 'Attnv2']:
            prediction, logits, decoder_attn = self.Prediction(beam_size, contextual_feature.contiguous(), text, is_train=is_train,
                                            is_test=is_test, batch_max_length=self.config['batch_max_length'])

        elif self.flow['Pred'] == 'TFM':
            prediction, logits = self.Prediction(beam_size, contextual_feature.contiguous(), text, is_test)
            # Beam state is per-call; clear it for the next invocation.
            self.Prediction.reset_beam()

        return prediction, logits, decoder_attn, addition_outputs
HybridViT/recognizers/build_seq.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from ..module.component.seq_modeling import BidirectionalLSTM, create_vit_modeling
3
+ from ..module.component.seq_modeling.bilstm import BiLSTM_Seq_Modeling
4
+ from ..module.component.common import GatedSum
5
+ from ..module.component.common import PositionalEncoding2D, PositionalEncoding1D
6
+
7
+
8
class SeqModelingBuilder(nn.Module):
    """Builds the sequence-modeling stage: 2-layer BiLSTM, 3-layer BiLSTM,
    a ViT encoder, or a pass-through with optional 2-D positions."""

    def __init__(self, flow: dict, config, FeatureExtraction_output):
        super().__init__()
        self.config = config
        self.flow = flow

        if flow['Seq'] == 'BiLSTM':
            hidden_size = config['SequenceModeling']['params']['hidden_size']
            use_pos_enc = config['SequenceModeling']['params'].get('pos_enc', False)

            if use_pos_enc:
                # Optional positional branch, merged with the BiLSTM output
                # through a learned gate in forward().
                self.image_positional_encoder = PositionalEncoding1D(hidden_size)
                self.gated = GatedSum(hidden_size)

            self.SequenceModeling = nn.Sequential(
                BidirectionalLSTM(FeatureExtraction_output, hidden_size, hidden_size),
                BidirectionalLSTM(hidden_size, hidden_size, hidden_size))

            self.SequenceModeling_output = hidden_size

        elif flow['Seq'] == 'BiLSTM_3L':
            hidden_size = config['SequenceModeling']['params']['hidden_size']
            self.SequenceModeling = BiLSTM_Seq_Modeling(3, FeatureExtraction_output, hidden_size, hidden_size)

            self.SequenceModeling_output = hidden_size

        elif flow['Seq'] == 'ViT':
            assert config['max_dimension'] is not None, "ViT encoder require exact height or max height and max width"
            # NOTE(review): this branch never sets SequenceModeling_output,
            # so Model falls back to None — confirm downstream heads read
            # their width from config instead.
            self.SequenceModeling = create_vit_modeling(config)
        else:
            print('No SequenceModeling module specified')
            if flow['Pred'] == 'TFM':
                # Transformer heads still need positions on raw CNN features.
                self.image_positional_encoder = PositionalEncoding2D(FeatureExtraction_output)

            self.SequenceModeling_output = FeatureExtraction_output

    def forward(self, visual_feature, *args, **kwargs):
        """Returns (contextual_feature, output_shape, pad_info).
        output_shape is always None here; pad_info is set by the ViT path only."""
        output_shape = None
        pad_info = None

        if self.flow['Seq'] in ['BiLSTM', 'BiLSTM_3L']:
            contextual_feature = self.SequenceModeling(visual_feature)

            if hasattr(self, 'image_positional_encoder'):
                assert len(contextual_feature.shape) == 3
                # NOTE(review): the positional branch encodes the *input*
                # features and is then gated with the BiLSTM output; the
                # permutes assume a (batch, width, channel) layout — confirm.
                contextual_feature_1 = self.image_positional_encoder(visual_feature.permute(1, 0, 2))
                contextual_feature_1 = contextual_feature_1.permute(1, 0, 2)
                contextual_feature = self.gated(contextual_feature_1, contextual_feature)

        elif self.flow['Seq'] == 'ViT':
            contextual_feature, pad_info, _ = self.SequenceModeling(visual_feature)

        return contextual_feature, output_shape, pad_info
HybridViT/resizer.py ADDED
File without changes
ScanSSD/IOU_lib/BoundingBox.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .iou_utils import *
2
+
3
class BoundingBox:
    """A single ground-truth or detected box.

    Coordinates are stored internally in absolute XYWH form plus the derived
    bottom-right corner (_x2, _y2), regardless of the input format.
    """

    def __init__(self,
                 imageName,
                 classId,
                 x,
                 y,
                 w,
                 h,
                 typeCoordinates=CoordinatesType.Absolute,
                 imgSize=None,
                 bbType=BBType.GroundTruth,
                 classConfidence=None,
                 format=BBFormat.XYWH):
        """Constructor.
        Args:
            imageName: String representing the image name.
            classId: String value representing class id.
            x: Float value representing the X upper-left coordinate of the bounding box.
            y: Float value representing the Y upper-left coordinate of the bounding box.
            w: Float value representing the width bounding box.
            h: Float value representing the height bounding box.
            typeCoordinates: (optional) Enum (Relative or Absolute) represents if the bounding box
                coordinates (x,y,w,h) are absolute or relative to size of the image. Default:'Absolute'.
            imgSize: (optional) 2D vector (width, height)=>(int, int) represents the size of the
                image of the bounding box. If typeCoordinates is 'Relative', imgSize is required.
            bbType: (optional) Enum (Groundtruth or Detection) identifies if the bounding box
                represents a ground truth or a detection. If it is a detection, the classConfidence has
                to be informed.
            classConfidence: (optional) Float value representing the confidence of the detected
                class. If detectionType is Detection, classConfidence needs to be informed.
            format: (optional) Enum (BBFormat.XYWH or BBFormat.XYX2Y2) indicating the format of the
                coordinates of the bounding boxes. BBFormat.XYWH: <left> <top> <width> <height>
                BBFormat.XYX2Y2: <left> <top> <right> <bottom>.
        """
        self._imageName = imageName
        self._typeCoordinates = typeCoordinates
        if typeCoordinates == CoordinatesType.Relative and imgSize is None:
            raise IOError(
                'Parameter \'imgSize\' is required. It is necessary to inform the image size.')
        if bbType == BBType.Detected and classConfidence is None:
            raise IOError(
                'For bbType=\'Detection\', it is necessary to inform the classConfidence value.')

        self._classConfidence = classConfidence
        self._bbType = bbType
        self._classId = classId
        self._format = format

        # If relative coordinates, convert to absolute values.
        # For relative coords: (x,y,w,h)=(X_center/img_width , Y_center/img_height)
        if (typeCoordinates == CoordinatesType.Relative):
            # convertToAbsoluteValues yields absolute corner coordinates.
            (self._x, self._y, self._w, self._h) = convertToAbsoluteValues(imgSize, (x, y, w, h))
            self._width_img = imgSize[0]
            self._height_img = imgSize[1]
            if format == BBFormat.XYWH:
                self._x2 = self._w
                self._y2 = self._h
                self._w = self._x2 - self._x
                self._h = self._y2 - self._y
            else:
                raise IOError(
                    'For relative coordinates, the format must be XYWH (x,y,width,height)')
        # For absolute coords: (x,y,w,h) = real bb coords.
        else:
            self._x = x
            self._y = y
            if format == BBFormat.XYWH:
                self._w = w
                self._h = h
                self._x2 = self._x + self._w
                self._y2 = self._y + self._h
            else:  # format == BBFormat.XYX2Y2: <left> <top> <right> <bottom>.
                self._x2 = w
                self._y2 = h
                self._w = self._x2 - self._x
                self._h = self._y2 - self._y
        if imgSize is None:
            self._width_img = None
            self._height_img = None
        else:
            self._width_img = imgSize[0]
            self._height_img = imgSize[1]

    def __str__(self):
        return "{} {} {} {} {}".format(self.getImageName(), self._x, self._y, self._w, self._h)

    def getAbsoluteBoundingBox(self, format=BBFormat.XYWH):
        """Return the box in absolute coordinates, in the requested format."""
        if format == BBFormat.XYWH:
            return (self._x, self._y, self._w, self._h)
        elif format == BBFormat.XYX2Y2:
            return (self._x, self._y, self._x2, self._y2)

    def getRelativeBoundingBox(self, imgSize=None):
        """Return (x, y, w, h) scaled to the image size.

        Uses the explicit `imgSize` when provided, otherwise the size stored
        at construction time. (Bug fix: the original branched the wrong way
        round and indexed `imgSize` exactly when it was None.)
        """
        if imgSize is None and self._width_img is None and self._height_img is None:
            raise IOError(
                'Parameter \'imgSize\' is required. It is necessary to inform the image size.')
        if imgSize is not None:
            return convertToRelativeValues((imgSize[0], imgSize[1]),
                                           (self._x, self._y, self._w, self._h))
        return convertToRelativeValues((self._width_img, self._height_img),
                                       (self._x, self._y, self._w, self._h))

    def getImageName(self):
        return self._imageName

    def getConfidence(self):
        return self._classConfidence

    def getFormat(self):
        return self._format

    def getClassId(self):
        return self._classId

    def getImageSize(self):
        return (self._width_img, self._height_img)

    def getCoordinatesType(self):
        return self._typeCoordinates

    def getBBType(self):
        return self._bbType

    @staticmethod
    def compare(det1, det2):
        """True when two boxes agree on class id, confidence, absolute
        coordinates and image size.

        Bug fixes vs. the original: it called the non-existent
        `classConfidence` attribute / `classConfidenc()` method (use
        getConfidence()), and compared det1's image size against itself
        instead of against det2's.
        """
        det1BB = det1.getAbsoluteBoundingBox()
        det2BB = det2.getAbsoluteBoundingBox()
        return (det1.getClassId() == det2.getClassId()
                and det1.getConfidence() == det2.getConfidence()
                and det1BB == det2BB
                and det1.getImageSize() == det2.getImageSize())

    @staticmethod
    def clone(boundingBox):
        """Deep-copy a box by re-running the constructor on its absolute
        XYWH coordinates plus all metadata."""
        absBB = boundingBox.getAbsoluteBoundingBox(format=BBFormat.XYWH)
        newBoundingBox = BoundingBox(
            boundingBox.getImageName(),
            boundingBox.getClassId(),
            absBB[0],
            absBB[1],
            absBB[2],
            absBB[3],
            typeCoordinates=boundingBox.getCoordinatesType(),
            imgSize=boundingBox.getImageSize(),
            bbType=boundingBox.getBBType(),
            classConfidence=boundingBox.getConfidence(),
            format=BBFormat.XYWH)
        return newBoundingBox
ScanSSD/IOU_lib/Evaluator.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ###########################################################################################
2
+ # #
3
+ # Evaluator class: Implements the most popular metrics for object detection #
4
+ # #
5
+ # Developed by: Rafael Padilla (rafael.padilla@smt.ufrj.br) #
6
+ # SMT - Signal Multimedia and Telecommunications Lab #
7
+ # COPPE - Universidade Federal do Rio de Janeiro #
8
+ # Last modification: Oct 9th 2018 #
9
+ ###########################################################################################
10
+
11
+ import os
12
+ import sys
13
+ from collections import Counter
14
+
15
+ import matplotlib.pyplot as plt
16
+ import numpy as np
17
+
18
+ from .BoundingBox import *
19
+ from .iou_utils import *
20
+
21
+
22
class Evaluator:
    """Bounding-box comparison helpers (IOU and its building blocks).

    All coordinates are absolute boxes in (x1, y1, x2, y2) order and areas
    use the inclusive-pixel convention (+1 on each side length).
    """

    @staticmethod
    def _getAllIOUs(reference, detections):
        """Pair *reference* with every box in *detections*.

        Returns a list of (iou, reference, detection) tuples sorted from
        highest to lowest IOU.
        """
        refBox = reference.getAbsoluteBoundingBox(BBFormat.XYX2Y2)
        pairs = [
            (Evaluator.iou(refBox, det.getAbsoluteBoundingBox(BBFormat.XYX2Y2)),
             reference, det)
            for det in detections
        ]
        # Highest overlap first.
        return sorted(pairs, key=lambda item: item[0], reverse=True)

    @staticmethod
    def iou(boxA, boxB):
        """Return the intersection-over-union of two (x1, y1, x2, y2) boxes."""
        # Disjoint boxes short-circuit to zero.
        if not Evaluator._boxesIntersect(boxA, boxB):
            return 0
        inter = Evaluator._getIntersectionArea(boxA, boxB)
        union = Evaluator._getUnionAreas(boxA, boxB, interArea=inter)
        result = inter / union
        assert result >= 0
        return result

    # boxA = (Ax1,Ay1,Ax2,Ay2)
    # boxB = (Bx1,By1,Bx2,By2)
    @staticmethod
    def _boxesIntersect(boxA, boxB):
        """Return True when the two boxes overlap at all."""
        if boxA[0] > boxB[2]:   # A entirely right of B
            return False
        if boxB[0] > boxA[2]:   # A entirely left of B
            return False
        if boxA[3] < boxB[1]:   # A entirely above B
            return False
        if boxA[1] > boxB[3]:   # A entirely below B
            return False
        return True

    @staticmethod
    def _getIntersectionArea(boxA, boxB):
        """Area of the overlap; assumes the boxes do intersect."""
        left = max(boxA[0], boxB[0])
        top = max(boxA[1], boxB[1])
        right = min(boxA[2], boxB[2])
        bottom = min(boxA[3], boxB[3])
        return (right - left + 1) * (bottom - top + 1)

    @staticmethod
    def _getUnionAreas(boxA, boxB, interArea=None):
        """Area of the union; computes the overlap itself when not given."""
        areaA = Evaluator._getArea(boxA)
        areaB = Evaluator._getArea(boxB)
        if interArea is None:
            interArea = Evaluator._getIntersectionArea(boxA, boxB)
        return float(areaA + areaB - interArea)

    @staticmethod
    def _getArea(box):
        """Inclusive pixel area of an (x1, y1, x2, y2) box."""
        return (box[2] - box[0] + 1) * (box[3] - box[1] + 1)
ScanSSD/IOU_lib/IOUevaluater.py ADDED
@@ -0,0 +1,433 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from zipfile import ZipFile
2
+ import os
3
+ from .Evaluator import *
4
+ from utils import *
5
+ import copy
6
+ import argparse
7
+ import sys
8
+ import ntpath
9
+ #import cStringIO
10
+ from io import BytesIO
11
+ import shutil
12
+
13
+
14
def read_file(filename, bboxes, flag):
    '''
    Parse a math-region .csv file into *bboxes*, keyed by page number.

    Each line is "page,x,y,x2,y2".  The per-page region counter restarts
    at 1 whenever the page number changes, and each region is named
    "<flag>_<counter>".

    :param filename: .csv file containing math regions
    :param bboxes: Map<page_num, List<bboxes>> filled in place
    :param flag: prefix ('gt' or 'det') used to name each region
    :return: None (results are accumulated in *bboxes*)
    '''
    prev_page = -1
    counter = 1
    # "with" guarantees the handle is closed even when a line fails to
    # parse (the original opened/closed manually and leaked it on error).
    with open(filename, "r") as fh1:
        for line in fh1:
            line = line.replace("\n", "")
            if line.replace(' ', '') == '':
                continue  # skip blank lines
            splitLine = line.split(",")
            idClass = float(splitLine[0])  # page number
            if prev_page == -1:
                prev_page = idClass
            elif idClass != prev_page:
                # New page: restart the region counter.
                counter = 1
                prev_page = idClass
            x = float(splitLine[1])
            y = float(splitLine[2])
            x2 = float(splitLine[3])
            y2 = float(splitLine[4])
            bb = BoundingBox(
                flag + "_" + str(counter),
                1,
                x,
                y,
                x2,
                y2,
                CoordinatesType.Absolute, (200, 200),
                BBType.GroundTruth,
                format=BBFormat.XYX2Y2)
            counter += 1
            if idClass not in bboxes:
                bboxes[idClass] = []
            bboxes[idClass].append(bb)
58
+
59
+
60
def extract_zipfile(zip_filename, target_dir):
    '''
    Extract a zip file into the target directory.

    :param zip_filename: full file path of the zip file
    :param target_dir: directory to extract the contents into
    :return: None
    '''
    # Renamed the context variable: the original bound it to "zip",
    # shadowing the built-in of the same name.
    with ZipFile(zip_filename, 'r') as archive:
        # extracting all the files
        print('Extracting all the files now...')
        archive.extractall(target_dir)
        print('Done!')
72
+
73
+
74
def create_doc_bboxes_map(dir_path, flag):
    '''
    Read every .csv/.math file in *dir_path* and return a map with the
    bounding boxes of each page of each document.

    :param dir_path: directory containing math-region files
    :param flag: name prefix forwarded to read_file ('gt' or 'det')
    :return: Map<PDF_name, Map<Page_number, List<BBoxes>>>
    '''
    pdf_bboxes_map = {}

    for entry in os.listdir(dir_path):
        full_path = os.path.join(dir_path, entry)
        doc_key = os.path.splitext(os.path.basename(full_path))[0]
        # Only plain .csv/.math files are considered.  NOTE(review): the
        # hidden-file check tests the full path, mirroring the original.
        if full_path.startswith(".") or not (full_path.endswith(".csv") or full_path.endswith(".math")):
            continue
        page_map = {}

        if os.path.isdir(full_path):
            continue

        try:
            read_file(full_path, page_map, flag)
        except Exception as e:
            print('exception occurred in reading file', full_path, str(e))

        pdf_bboxes_map[doc_key] = copy.deepcopy(page_map)

    return pdf_bboxes_map
104
+
105
def unique_values(input_dict):
    """Find ground-truth keys whose best detection is shared.

    *input_dict* maps a GT box name to a tuple whose second element is the
    name of its best-matching detection.  Each time a detection shows up as
    the best match of more than one GT box, the full list of GT keys that
    point at it is recorded (once per extra occurrence, as in the original
    implementation).
    """
    seen_dets = []
    clashing_keys = []
    for entry in input_dict.values():
        best_det = entry[1]
        if best_det in seen_dets:
            owners = [key for key, val in input_dict.items() if val[1] == best_det]
            clashing_keys.append(owners)
        seen_dets.append(best_det)

    return clashing_keys
117
+
118
def generate_validpairs(pairs):
    """Break key groups of size > 2 into consecutive 2-tuples.

    Groups of length two or less are kept as-is; longer groups become the
    chain of adjacent pairs (k0, k1), (k1, k2), ...  Duplicates are not
    re-added.
    """
    result = []
    for group in pairs:
        if len(group) > 2:
            for left, right in zip(group, group[1:]):
                candidate = (left, right)
                if candidate not in result:
                    result.append(candidate)
        elif group not in result:
            result.append(group)
    return result
128
+
129
def fix_preds(input_dict, keyPairs, thre):
    """Resolve GT boxes that claim the same best detection.

    For each clashing pair the GT box with the lower best IOU loses the
    shared detection and falls back to its runner-up candidate; if that
    fallback IOU drops below *thre* the GT entry is removed entirely.
    Values are (best_iou, best_det, iou_list, det_list) tuples and the
    dict is modified in place.
    """
    validPairs = generate_validpairs(keyPairs)

    for pair in validPairs:
        # Either key may already have been deleted by an earlier pair.
        if pair[0] not in input_dict or pair[1] not in input_dict:
            continue
        iou_first = input_dict[pair[0]][0]
        iou_second = input_dict[pair[1]][0]

        if iou_first >= iou_second:
            # Demote the second GT box to its runner-up detection.
            entry = input_dict[pair[1]]
            fallback = entry[2][1]
            if fallback < thre:
                del input_dict[pair[1]]
                continue
            input_dict[pair[1]] = fallback, entry[3][1], entry[2][1:], entry[3][1:]

        if iou_second > iou_first:
            # Demote the first GT box instead.
            entry = input_dict[pair[0]]
            fallback = entry[2][1]
            if fallback < thre:
                del input_dict[pair[0]]
                continue
            input_dict[pair[0]] = fallback, entry[3][1], entry[2][1:], entry[3][1:]

    return input_dict
158
+
159
def find_uni_pred(input_dict, thre):
    """Iteratively resolve shared detections until the mapping is 1-to-1."""
    clashes = unique_values(input_dict)
    if not clashes:
        return input_dict

    # fix_preds mutates and returns the same dict, so loop until no GT
    # boxes share a best detection any more.
    output_dict = input_dict
    while clashes:
        output_dict = fix_preds(output_dict, clashes, thre)
        clashes = unique_values(output_dict)

    return output_dict
170
+
171
+
172
def count_true_box(pred_dict, thre):
    """Count GT boxes matched above *thre* after one-to-one resolution.

    Drops entries whose best IOU is below the threshold, resolves
    detections claimed by several GT boxes, and returns the number of
    surviving matches together with the resolved dict.
    """
    # Remove predictions whose best IOU does not reach the threshold.
    for gt_key in list(pred_dict.keys()):
        if pred_dict[gt_key][0] < thre:
            del pred_dict[gt_key]

    # Enforce a one-to-one GT <-> detection mapping.
    final_dict = find_uni_pred(pred_dict, thre)

    return len(final_dict.keys()), final_dict
184
+
185
+
186
def IoU_page_bboxes(gt_page_bboxes_map, det_page_bboxes_map, pdf_name, outdir=None):
    '''
    Compare ground-truth and detected boxes page by page for one PDF.

    For every GT box the IOUs against all detections on the same page are
    computed (sorted high to low) and matches are counted at the coarse
    (IOU > 0.5) and fine (IOU > 0.75) thresholds.  When *outdir* is given
    a per-page evaluation report is written there as well.

    :param gt_page_bboxes_map: Map<pageNum, List<bboxes>> for ground truth bboxes
    :param det_page_bboxes_map: Map<pageNum, List<bboxes>> for detection bboxes
    :param pdf_name: document name, used in messages and report file names
    :param outdir: optional directory for per-page report files
    :return: (coarse matches, fine matches, #GT boxes, #det boxes,
              per-page coarse keys, per-page fine keys)
    '''
    evaluator = Evaluator()

    coarse_total = 0
    fine_total = 0

    gt_box_total = 0
    det_box_total = 0

    coarse_keys = {}
    fine_keys = {}

    for page_num in gt_page_bboxes_map:
        if page_num not in det_page_bboxes_map:
            print('Detections not found for page', str(page_num + 1), ' in', pdf_name)
            continue
        gt_boxes = gt_page_bboxes_map[page_num]
        det_boxes = det_page_bboxes_map[page_num]

        gt_box_total += len(gt_boxes)
        det_box_total += len(det_boxes)

        # For each GT box keep (best IOU, best det, all IOUs, all det names),
        # with the candidate lists sorted from highest to lowest IOU.
        pred_dict = {}
        for gt_box in gt_boxes:
            ious = evaluator._getAllIOUs(gt_box, det_boxes)
            ious_rounded = [round(item[0], 2) for item in ious]
            det_names = [item[2].getImageName() for item in ious]
            pred_dict[gt_box.getImageName()] = ious_rounded[0], det_names[0], ious_rounded, det_names

        coarse, coarse_dict = count_true_box(copy.deepcopy(pred_dict), 0.5)
        fine, fine_dict = count_true_box(copy.deepcopy(pred_dict), 0.75)

        coarse_keys[page_num] = coarse_dict.keys()
        fine_keys[page_num] = fine_dict.keys()

        # Accumulate correct predictions at both thresholds for this page.
        coarse_total += coarse
        fine_total += fine

        # Optionally dump a human-readable per-page report.
        if outdir:
            report_path = os.path.join(outdir, pdf_name.split(".csv")[0] + "_" + str(page_num) + "_eval.txt")
            with open(report_path, "w") as out_file:
                out_file.write('#page num ' + str(page_num) + ", gt_box:" + str(len(gt_boxes)) +
                               ", pred_box:" + str(len(det_boxes)) + "\n")
                out_file.write('\n')
                out_file.write('#COARSE DETECTION (iou>0.5):\n#number of correct prediction:' + str(coarse) + '\n#correctly detected:' +
                               str(list(coarse_dict.keys())) + '\n')
                out_file.write('\n')
                out_file.write('#FINE DETECTION (iou>0.75):\n#number of correct prediction:' + str(fine) + '\n#correctly detected:' +
                               str(list(fine_dict.keys())) + '\n')
                out_file.write('\n')
                out_file.write('#Sorted IOU scores for each GT box:\n')
                for gt_box in gt_boxes:
                    ious = evaluator._getAllIOUs(gt_box, det_boxes)
                    out_file.write(gt_box.getImageName() + ",")
                    for score, _, det in ious[:-1]:
                        out_file.write("(" + str(round(score, 2)) + " " + str(det.getImageName()) + "),")
                    out_file.write("(" + str(round(ious[-1][0], 2)) + " " + str(ious[-1][2].getImageName()) + ")\n")

    return coarse_total, fine_total, gt_box_total, det_box_total, coarse_keys, fine_keys
259
+
260
def count_box(input_dict):
    """Total number of boxes across every page of every document."""
    return sum(len(page_boxes)
               for doc in input_dict.values()
               for page_boxes in doc.values())
267
+
268
# Zip every uploading files
def archive_iou_txt(username, task_id, sub_id, userpath):
    """Archive the per-page IOU evaluation reports into a single zip.

    Looks for *userpath*/iouEval_stats and stores the archive as
    *userpath*/IOU_stats_archive/<task_id>_<sub_id>.zip.

    :param username: unused; kept for interface compatibility
    :param task_id: task identifier embedded in the archive name
    :param sub_id: submission identifier embedded in the archive name
    :param userpath: base directory holding the evaluation output
    """
    inputdir = os.path.join(userpath, 'iouEval_stats')

    if not os.path.exists(inputdir):
        print('No txt file is generated for IOU evaluation')
        # Bug fix: the original used "pass" here and fell through to
        # shutil.make_archive, which then raised on the missing directory.
        return

    dest_uploader = os.path.join(userpath, 'IOU_stats_archive')

    if not os.path.exists(dest_uploader):
        os.makedirs(dest_uploader)

    zip_file_name = '/' + task_id + '_' + sub_id
    shutil.make_archive(dest_uploader + zip_file_name, 'zip', inputdir)

    # return '/media/' + dest_uploader
287
+
288
def write_html(gtFile, resultsFile, info, scores, destFile):
    """Render the CROHME Task-3 detection summary as an HTML page.

    :param gtFile: ground-truth file path (only the basename is shown)
    :param resultsFile: submitted detection file path
    :param info: dict with 'allGTbox', 'allDet', 'correctDet_c', 'correctDet_f'
    :param scores: dict with coarse/fine F-score, precision and recall
    :param destFile: writable file-like object receiving the HTML
    """
    destFile.write('<html>')
    destFile.write('<head><link rel="stylesheet" href="//maxcdn.bootstrapcdn.com/font-awesome/4.3.0/css/font-awesome.min.css"><link href="/static/css/bootstrap.min.css" rel="stylesheet"></head>')
    destFile.write('<body>')
    destFile.write("<blockquote><b>CROHME 2019</b> <h1> Formula Detection Results ( TASK 3 )</h1><hr>")
    destFile.write(f"<b>Submitted Files</b><ul><li><b>Output:</b> {ntpath.basename(resultsFile)}</li>")
    destFile.write(f"<li><b>Ground-truth:</b> {ntpath.basename(gtFile)}</li></ul>")
    if info['allGTbox'] == 0:
        # An empty ground-truth list makes every metric meaningless.
        sys.stderr.write("Error : no sample in this GT list !\n")
        exit(-1)
    # Overall box counts.
    destFile.write(f"<p><b> Number of ground truth bounding boxes: </b>{info['allGTbox']}<br /><b> Number of detected bounding boxes: </b>{info['allDet']}")
    destFile.write("<hr>")
    # Coarse results (IOU > 0.5).
    destFile.write("<p><b> **** Coarse Detection Results (IOU>0.5) ****</b><br />")
    destFile.write(f"<ul><li><b>{scores['coarse_f']}</b> F-score</li>")
    destFile.write(f"<li>{scores['coarse_pre']} Precision</li>")
    destFile.write(f"<li>{scores['coarse_rec']} Recall</li></ul>")
    destFile.write(f"<b>{info['correctDet_c']}</b> Number of correctly detected bounding boxes</p>")
    destFile.write("<hr>")
    # Fine results (IOU > 0.75).
    destFile.write("<p><b> **** Fine Detection Results (IOU>0.75) ****</b><br />")
    destFile.write(f"<ul><li><b>{scores['fine_f']}</b> F-score</li>")
    destFile.write(f"<li>{scores['fine_pre']} Precision</li>")
    destFile.write(f"<li>{scores['fine_rec']} Recall</li></ul>")
    destFile.write(f"<b>{info['correctDet_f']}</b> Number of correctly detected bounding boxes</p>")
    destFile.write("<hr>")
    destFile.write('</body>')
    destFile.write('</html>')
319
+
320
def pre_rec_calculate(count):
    """Compute precision/recall/F-score at the coarse and fine thresholds.

    :param count: dict with 'allGTbox', 'allDet', 'correctDet_c', 'correctDet_f'
    :return: dict with fine_f/fine_pre/fine_rec and coarse_f/coarse_pre/coarse_rec,
             each rounded to four decimals (all zero when there are no detections)
    """
    if count['allDet'] == 0:
        print('No detection boxes found')
        # Bug fix: the original returned only {'fine_f', 'coarse_f'} here,
        # which made downstream readers of the precision/recall keys
        # (e.g. write_html) fail with KeyError.
        scores = {'fine_f': 0, 'fine_pre': 0, 'fine_rec': 0,
                  'coarse_f': 0, 'coarse_pre': 0, 'coarse_rec': 0}
    else:
        # Fine threshold (IOU > 0.75).
        pre_f = count['correctDet_f'] / float(count['allDet'])
        recall_f = count['correctDet_f'] / float(count['allGTbox'])
        if pre_f == 0 and recall_f == 0:
            f_f = 0
        else:
            f_f = 2 * (pre_f * recall_f) / float(pre_f + recall_f)

        # Coarse threshold (IOU > 0.5).
        pre_c = count['correctDet_c'] / float(count['allDet'])
        recall_c = count['correctDet_c'] / float(count['allGTbox'])
        if pre_c == 0 and recall_c == 0:
            f_c = 0
        else:
            f_c = 2 * (pre_c * recall_c) / float(pre_c + recall_c)
        print('')
        print('**** coarse result : threshold: 0.5 *****')
        print(' f =', f_c, ' precision =', pre_c, ' recall =', recall_c)
        print('')
        print('**** fine result : threshold: 0.75 *****')
        print(' f =', f_f, ' precision =', pre_f, ' recall =', recall_f)

        scores = {'fine_f': round(f_f, 4), 'fine_pre': round(pre_f, 4), 'fine_rec': round(recall_f, 4),
                  'coarse_f': round(f_c, 4), 'coarse_pre': round(pre_c, 4), 'coarse_rec': round(recall_c, 4)}
    return scores
349
+
350
def IOUeval(ground_truth, detections, outdir=None):
    """Evaluate detected math regions against ground truth via IOU matching.

    :param ground_truth: directory of ground-truth .csv/.math files
    :param detections: directory of detection .csv/.math files
    :param outdir: optional directory for per-page reports (wiped/recreated)
    :return: (coarse F-score, fine F-score, per-PDF detection key details)
    """
    info = dict.fromkeys(['allGTbox', 'correctDet_c', 'correctDet_f', 'allDet'], 0)

    gt_file_name = ground_truth
    det_file_name = detections

    # TODO : Mahshad change it to user directory
    if outdir:
        # Start from a clean report directory.
        if os.path.exists(outdir):
            shutil.rmtree(outdir)
        os.makedirs(outdir)

    gt_pdfs_bboxes_map = create_doc_bboxes_map(gt_file_name, 'gt')
    det_pdfs_bboxes_map = create_doc_bboxes_map(det_file_name, 'det')

    # Global box counts over every document.
    info['allGTbox'] = count_box(gt_pdfs_bboxes_map)
    info['allDet'] = count_box(det_pdfs_bboxes_map)

    pdf_info = {}
    pdf_calcs = {}
    detailed_detections = {}

    for pdf_name in gt_pdfs_bboxes_map:
        if pdf_name not in det_pdfs_bboxes_map:
            print('Detections not found for ', pdf_name)
            continue

        det_page_bboxes_map = det_pdfs_bboxes_map[pdf_name]
        gt_page_bboxes_map = gt_pdfs_bboxes_map[pdf_name]

        coarse_true_det, fine_true_det, pdf_gt_boxes, pdf_det_boxes, coarse_keys, fine_keys = \
            IoU_page_bboxes(gt_page_bboxes_map, det_page_bboxes_map, pdf_name, outdir)

        # Fold the per-document counts into the global tally.
        info['correctDet_c'] += coarse_true_det
        info['correctDet_f'] += fine_true_det

        pdf_info['correctDet_c'] = coarse_true_det
        pdf_info['correctDet_f'] = fine_true_det
        pdf_info['allGTbox'] = pdf_gt_boxes
        pdf_info['allDet'] = pdf_det_boxes

        print('For pdf: ', pdf_name)
        pdf_calcs[pdf_name] = pre_rec_calculate(pdf_info)
        detailed_detections[pdf_name] = [coarse_keys, fine_keys]

    print('\n')
    print(info)
    scores = pre_rec_calculate(info)

    print('\n PDF Level \n')
    # One line per PDF: name, coarse F, fine F.
    for pdf_name in pdf_calcs:
        print(pdf_name, '\t', pdf_calcs[pdf_name]['coarse_f'], '\t', pdf_calcs[pdf_name]['fine_f'])

    # Return the coarse and fine F-scores plus the per-PDF match details.
    return scores['coarse_f'], scores['fine_f'], detailed_detections
418
+
419
if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument("--detections", type=str, required=True, help="detections file path")
    parser.add_argument("--ground_truth", type=str, required=True, help="ground_truth file path")
    args = parser.parse_args()

    gt_file_name = args.ground_truth
    det_file_name = args.detections

    # Bug fix: IOUeval returns three values (coarse F, fine F, detailed
    # detections); the original two-name unpacking raised a ValueError.
    c_f, f_f, detailed_detections = IOUeval(gt_file_name, det_file_name, outdir='IOU_scores_pages/')
431
+
432
+
433
+
ScanSSD/IOU_lib/__init__.py ADDED
File without changes
ScanSSD/IOU_lib/iou_utils.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
+ import cv2
4
+
5
+
6
class CoordinatesType(Enum):
    """How bounding-box coordinates are expressed.

    ``Relative`` means fractions of the image size; ``Absolute`` means
    pixel values.

    Developed by: Rafael Padilla
    Last modification: Apr 28 2018
    """
    Relative = 1
    Absolute = 2
16
+
17
+
18
class BBType(Enum):
    """Origin of a bounding box: annotated ground truth or a detection.

    Developed by: Rafael Padilla
    Last modification: May 24 2018
    """
    GroundTruth = 1
    Detected = 2
27
+
28
+
29
class BBFormat(Enum):
    """Layout of the four bounding-box numbers.

    ``XYWH`` is (x, y, width, height); ``XYX2Y2`` is (x1, y1, x2, y2).

    Developed by: Rafael Padilla
    Last modification: May 24 2018
    """
    XYWH = 1
    XYX2Y2 = 2
40
+
41
+
42
# size => (width, height) of the image
# box => (X1, X2, Y1, Y2) of the bounding box
def convertToRelativeValues(size, box):
    """Convert an absolute (x1, x2, y1, y2) box to relative center form.

    Returns (cx, cy, w, h) where the center and the side lengths are
    expressed as fractions of the image width/height.
    """
    inv_w = 1. / (size[0])
    inv_h = 1. / (size[1])
    center_x = (box[1] + box[0]) / 2.0
    center_y = (box[3] + box[2]) / 2.0
    box_w = box[1] - box[0]
    box_h = box[3] - box[2]
    # x,y => (bounding_box_center)/width_of_the_image
    # w => bounding_box_width / width_of_the_image
    # h => bounding_box_height / height_of_the_image
    return (center_x * inv_w, center_y * inv_h, box_w * inv_w, box_h * inv_h)
59
+
60
+
61
# size => (width, height) of the image
# box => (centerX, centerY, w, h) of the bounding box relative to the image
def convertToAbsoluteValues(size, box):
    """Convert a relative (cx, cy, w, h) box to absolute pixel corners.

    Returns (x1, y1, x2, y2), clamped to the image bounds.  The end
    corners are derived from the *unclamped* start corners, matching the
    original implementation.
    """
    xIn = round(((2 * float(box[0]) - float(box[2])) * size[0] / 2))
    yIn = round(((2 * float(box[1]) - float(box[3])) * size[1] / 2))
    xEnd = xIn + round(float(box[2]) * size[0])
    yEnd = yIn + round(float(box[3]) * size[1])
    # Clamp every corner to the image rectangle.
    xIn = max(xIn, 0)
    yIn = max(yIn, 0)
    if xEnd >= size[0]:
        xEnd = size[0] - 1
    if yEnd >= size[1]:
        yEnd = size[1] - 1
    return (xIn, yIn, xEnd, yEnd)
79
+
80
+
81
def add_bb_into_image(image, bb, color=(255, 0, 0), thickness=2, label=None):
    """Draw *bb* on *image* (in place) and optionally a text label.

    :param image: image array in OpenCV convention
    :param bb: bounding box object providing getAbsoluteBoundingBox()
    :param color: (R, G, B) box colour
    :param thickness: rectangle line thickness in pixels
    :param label: optional text drawn on a filled banner at the box
    :return: the same image, for chaining
    """
    r, g, b = int(color[0]), int(color[1]), int(color[2])

    font = cv2.FONT_HERSHEY_SIMPLEX
    fontScale = 0.5
    fontThickness = 1

    x1, y1, x2, y2 = bb.getAbsoluteBoundingBox(BBFormat.XYX2Y2)
    x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
    # OpenCV takes the colour channels in BGR order.
    cv2.rectangle(image, (x1, y1), (x2, y2), (b, g, r), thickness)
    if label is not None:
        # Size of the rendered label text.
        (tw, th) = cv2.getTextSize(label, font, fontScale, fontThickness)[0]
        # Tentative top-left corner of the text, just above the box.
        (xin_bb, yin_bb) = (x1 + thickness, y1 - th + int(12.5 * fontScale))
        if yin_bb - th <= 0:  # text would fall outside the image
            yin_bb = y1 + th  # draw it inside the box instead
        r_Xin = x1 - int(thickness / 2)
        r_Yin = y1 - th - int(thickness / 2)
        # Filled background rectangle behind the text.
        cv2.rectangle(image, (r_Xin, r_Yin - thickness),
                      (r_Xin + tw + thickness * 3, r_Yin + th + int(12.5 * fontScale)), (b, g, r),
                      -1)
        cv2.putText(image, label, (xin_bb, yin_bb), font, fontScale, (0, 0, 0), fontThickness,
                    cv2.LINE_AA)
    return image
ScanSSD/README.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ScanSSD: Scanning Single Shot Detector for Math in Document Images
2
+
3
+
4
+ A [PyTorch](http://pytorch.org/) implementation of ScanSSD [Scanning Single Shot MultiBox Detector](https://paragmali.me/scanning-single-shot-detector-for-math-in-document-images/) by [**Parag Mali**](https://github.com/MaliParag/). It was developed using SSD implementation by [**Max deGroot**](https://github.com/amdegroot).
5
+
6
+ All credit goes to the authors of the paper and the original implementation.
7
+
8
+ ---
9
+
10
+ I have made some changes to the original implementation to make it work with the latest version of PyTorch and Python.
11
+ I have also removed some unnecessary files, in particular the ones related to dataset.