Spaces:
Runtime error
Runtime error
Commit ·
6163604
1
Parent(s): 16db6af
[Add] source
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- HybridViT/beam.py +131 -0
- HybridViT/helper.py +182 -0
- HybridViT/module/component/common/__init__.py +6 -0
- HybridViT/module/component/common/conv.py +148 -0
- HybridViT/module/component/common/droppath.py +36 -0
- HybridViT/module/component/common/gated_sum.py +36 -0
- HybridViT/module/component/common/mae_posembed.py +72 -0
- HybridViT/module/component/common/maxout.py +22 -0
- HybridViT/module/component/common/postional_encoding.py +226 -0
- HybridViT/module/component/feature_extractor/__init__.py +0 -0
- HybridViT/module/component/feature_extractor/addon_module/__init__.py +2 -0
- HybridViT/module/component/feature_extractor/addon_module/aspp.py +59 -0
- HybridViT/module/component/feature_extractor/addon_module/visual_attention.py +325 -0
- HybridViT/module/component/feature_extractor/clova_impl/__init__.py +2 -0
- HybridViT/module/component/feature_extractor/clova_impl/resnet.py +262 -0
- HybridViT/module/component/feature_extractor/clova_impl/vgg.py +27 -0
- HybridViT/module/component/feature_extractor/helpers.py +76 -0
- HybridViT/module/component/feature_extractor/vgg.py +96 -0
- HybridViT/module/component/prediction_head/__init__.py +5 -0
- HybridViT/module/component/prediction_head/addon_module/__init__.py +3 -0
- HybridViT/module/component/prediction_head/addon_module/attention1D.py +218 -0
- HybridViT/module/component/prediction_head/addon_module/attention2D.py +88 -0
- HybridViT/module/component/prediction_head/addon_module/position_encoding.py +27 -0
- HybridViT/module/component/prediction_head/seq2seq.py +268 -0
- HybridViT/module/component/prediction_head/seq2seq_v2.py +218 -0
- HybridViT/module/component/prediction_head/tfm.py +207 -0
- HybridViT/module/component/seq_modeling/__init__.py +2 -0
- HybridViT/module/component/seq_modeling/addon_module/__init__.py +1 -0
- HybridViT/module/component/seq_modeling/addon_module/patchembed.py +161 -0
- HybridViT/module/component/seq_modeling/bilstm.py +33 -0
- HybridViT/module/component/seq_modeling/vit/utils.py +59 -0
- HybridViT/module/component/seq_modeling/vit/vision_transformer.py +184 -0
- HybridViT/module/component/seq_modeling/vit_encoder.py +276 -0
- HybridViT/module/converter/__init__.py +3 -0
- HybridViT/module/converter/attn_converter.py +71 -0
- HybridViT/module/converter/builder.py +6 -0
- HybridViT/module/converter/tfm_converter.py +90 -0
- HybridViT/recog_flow.py +113 -0
- HybridViT/recognizers/__init__.py +0 -0
- HybridViT/recognizers/build_feat.py +45 -0
- HybridViT/recognizers/build_model.py +82 -0
- HybridViT/recognizers/build_pred.py +61 -0
- HybridViT/recognizers/build_seq.py +60 -0
- HybridViT/resizer.py +0 -0
- ScanSSD/IOU_lib/BoundingBox.py +164 -0
- ScanSSD/IOU_lib/Evaluator.py +87 -0
- ScanSSD/IOU_lib/IOUevaluater.py +433 -0
- ScanSSD/IOU_lib/__init__.py +0 -0
- ScanSSD/IOU_lib/iou_utils.py +113 -0
- ScanSSD/README.md +11 -0
HybridViT/beam.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import List
|
| 3 |
+
from einops import rearrange, repeat
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
class Hypothesis:
    """A single (partial or finished) beam-search hypothesis.

    Attributes:
        seq: decoded token ids.
        score: cumulative log-probability of the sequence.
        attn_weights: per-token attention weights, or None when not tracked.
    """

    seq: List[int]
    score: float
    attn_weights: Optional[List[float]]

    def __init__(
        self,
        seq_tensor: torch.LongTensor,
        score: float,
        weights: Optional[torch.FloatTensor] = None
    ) -> None:
        raw_seq = seq_tensor.tolist()

        self.seq = raw_seq
        self.score = score
        # FIX: `if weights:` truth-tests the tensor, which raises a
        # RuntimeError for tensors with more than one element. Compare
        # against None explicitly.
        if weights is not None:
            self.attn_weights = weights.tolist()
            assert len(self.attn_weights) == len(self.seq)
        else:
            self.attn_weights = None

    def __len__(self):
        # An empty sequence still reports length 1 — presumably to avoid
        # zero-length score normalization downstream.
        if len(self.seq) != 0:
            return len(self.seq)
        else:
            return 1

    def __str__(self):
        return f"seq: {self.seq}, score: {self.score}, weight: {self.attn_weights}"
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class Beam:
    """Beam-search state for autoregressive decoding.

    Holds a (num_live, max_len + 2) tensor of partial token sequences, a
    cumulative log-prob score per row, and the list of finished hypotheses.
    """

    def __init__(self,
                 start_w=1,
                 stop_w=2,
                 ignore_w=0,
                 max_len=150,
                 viz_attn=False,
                 device='cuda'
                 ):
        # Token ids for start/stop/padding — defaults presumably match the
        # converter vocabulary; verify against the caller.
        self.stop_w = stop_w
        self.start_w = start_w

        # One row per live hypothesis, pre-filled with the ignore token.
        # +2 reserves slots for the start and stop tokens around max_len
        # decoded tokens.
        self.hypotheses = torch.full(
            (1, max_len + 2),
            fill_value=ignore_w,
            dtype=torch.long,
            device=device,
        )
        if viz_attn:
            # Per-position attention weights kept for visualization.
            self.hyp_alpha = torch.ones(1, max_len + 2, dtype=torch.float, device=device)

        self.hypotheses[:, 0] = start_w
        self.hyp_scores = torch.zeros(1, dtype=torch.float, device=device)
        self.completed_hypotheses: List[Hypothesis] = []
        self.device = device
        self.viz_attn = viz_attn

    def advance(self, next_log_probs, step, beam_size):
        """Extend the live hypotheses by one token.

        Args:
            next_log_probs: (num_live, vocab_size) log-probabilities.
            step: index of the slot currently being filled (0-based; the
                start token occupies slot 0).
            beam_size: total beam width.

        Returns:
            (new_hypotheses, new_hyp_scores): rows and scores of the
            hypotheses still unfinished after this step.
        """
        vocab_size = next_log_probs.shape[1]
        # Only as many candidates as there are unfinished beam slots.
        live_hyp_num = beam_size - len(self.completed_hypotheses)
        # Broadcast each hypothesis score over the vocabulary and flatten,
        # so topk ranks all (hypothesis, token) continuations jointly.
        exp_hyp_scores = repeat(self.hyp_scores, "b -> b e", e=vocab_size)
        continuous_hyp_scores = rearrange(exp_hyp_scores + next_log_probs, "b e -> (b e)")
        top_cand_hyp_scores, top_cand_hyp_pos = torch.topk(
            continuous_hyp_scores, k=live_hyp_num
        )

        # Recover (source row, token id) from the flattened indices.
        prev_hyp_ids = top_cand_hyp_pos // vocab_size
        hyp_word_ids = top_cand_hyp_pos % vocab_size

        step += 1
        new_hypotheses = []
        new_hyp_scores = []

        for prev_hyp_id, hyp_word_id, cand_new_hyp_score in zip(
            prev_hyp_ids, hyp_word_ids, top_cand_hyp_scores
        ):
            cand_new_hyp_score = cand_new_hyp_score.detach().item()
            self.hypotheses[prev_hyp_id, step] = hyp_word_id

            if hyp_word_id == self.stop_w:
                # Finished hypothesis: keep tokens 1..step (drops the start
                # token; NOTE(review): the slice still includes the stop
                # token itself — confirm downstream strips it).
                self.completed_hypotheses.append(
                    Hypothesis(
                        seq_tensor=self.hypotheses[prev_hyp_id, 1:step+1]
                        .detach()
                        .clone(), # remove START_W at first
                        score=cand_new_hyp_score,
                    )
                )
            else:
                new_hypotheses.append(self.hypotheses[prev_hyp_id].detach().clone())
                new_hyp_scores.append(cand_new_hyp_score)

        return new_hypotheses, new_hyp_scores

    def get_incomplete_inds(self, hyp_word_ids):
        # Indices of candidates whose chosen token is not the stop token.
        return [ind for ind, next_word in enumerate(hyp_word_ids) if
                next_word != self.stop_w]

    def get_complete_inds(self, hyp_word_ids, incomplete_inds):
        # Complement of `incomplete_inds` over all candidate positions.
        return list(set(range(len(hyp_word_ids))) - set(incomplete_inds))

    def set_current_state(self, hypotheses):
        "Set the outputs for the current timestep."
        self.hypotheses = torch.stack(hypotheses, dim=0)
        return

    def set_current_score(self, hyp_scores):
        "Set the scores for the current timestep."
        self.hyp_scores = torch.tensor(
            hyp_scores, dtype=torch.float, device=self.device
        )
        return

    def done(self, beam_size):
        # Search is finished once every beam slot holds a completed hypothesis.
        return len(self.completed_hypotheses) == beam_size

    def set_hypothesis(self):
        # Fallback: if decoding ended with no completed hypothesis, promote
        # the first live row (minus its start token).
        if len(self.completed_hypotheses) == 0:
            self.completed_hypotheses.append(
                Hypothesis(
                    seq_tensor=self.hypotheses[0, 1:].detach().clone(),
                    score=self.hyp_scores[0].detach().item(),
                )
            )
        return
|
HybridViT/helper.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import random
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from typing import Dict
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import numpy as np
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import cv2
|
| 10 |
+
import math
|
| 11 |
+
import albumentations as alb
|
| 12 |
+
from albumentations.pytorch.transforms import ToTensorV2
|
| 13 |
+
from collections import OrderedDict
|
| 14 |
+
from itertools import repeat
|
| 15 |
+
import collections.abc
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# From PyTorch internals
|
| 19 |
+
def _ntuple(n):
|
| 20 |
+
def parse(x):
|
| 21 |
+
if isinstance(x, collections.abc.Iterable):
|
| 22 |
+
return x
|
| 23 |
+
return tuple(repeat(x, n))
|
| 24 |
+
return parse
|
| 25 |
+
|
| 26 |
+
to_3tuple = _ntuple(3)
|
| 27 |
+
|
| 28 |
+
def clean_state_dict(state_dict):
    """Strip the ``module.`` prefix (left by DataParallel training) from every key."""
    prefix = 'module.'
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def math_transform(mean, std, is_gray: bool):
    """Build the inference-time albumentations pipeline for math images.

    Optionally converts to grayscale first, then applies CLAHE contrast
    equalization, per-channel normalization and tensor conversion.
    """
    steps = []
    if is_gray:
        steps.append(alb.ToGray(always_apply=True))
    steps.extend([
        alb.CLAHE(clip_limit=2, tile_grid_size=(2, 2), always_apply=True),
        alb.Normalize(to_3tuple(mean), to_3tuple(std)),
        ToTensorV2(),
    ])
    return alb.Compose(steps)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def pad(img: Image.Image, divable=32):
    """Normalize, crop to content, and pad an image to multiples of `divable`.

    The image is min-max normalized, inverted when the background appears
    light (gray-channel mean > 128), cropped to the bounding box of its
    non-zero (text) pixels, and pasted onto a white canvas whose sides are
    rounded up to the next multiple of `divable`.

    Args:
        img (PIL.Image): input image.
        divable (int, optional): divisor for the output dimensions. Defaults to 32.

    Returns:
        PIL.Image: grayscale ('L') padded image.
    """
    data = np.array(img.convert('LA'))

    # NOTE(review): min-max normalization divides by (max - min); a
    # perfectly uniform image divides by zero — confirm callers never pass one.
    data = (data-data.min())/(data.max()-data.min())*255
    if data[..., 0].mean() > 128:
        gray = 255*(data[..., 0] < 128).astype(np.uint8) # To invert the text to white
    else:
        gray = 255*(data[..., 0] > 128).astype(np.uint8)
        data[..., 0] = 255-data[..., 0]

    coords = cv2.findNonZero(gray) # Find all non-zero points (text)
    a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box
    rect = data[b:b+h, a:a+w]

    # Constant alpha channel: use the gray channel directly; otherwise
    # reconstruct intensity from the inverted alpha channel.
    if rect[..., -1].var() == 0:
        im = Image.fromarray((rect[..., 0]).astype(np.uint8)).convert('L')
    else:
        im = Image.fromarray((255-rect[..., -1]).astype(np.uint8)).convert('L')
    dims = []

    # Round each side up to the next multiple of `divable`.
    for x in [w, h]:
        div, mod = divmod(x, divable)
        dims.append(divable*(div + (1 if mod > 0 else 0)))

    padded = Image.new('L', dims, 255)
    padded.paste(im, im.getbbox())

    return padded
|
| 89 |
+
|
| 90 |
+
def get_divisible_size(ori_h, ori_w, max_dimension=None, scale_factor=32):
    """Round a (height, width) pair up to the nearest multiple of `scale_factor`.

    If rounding up would exceed `max_dimension` (a (max_h, max_w) pair),
    that dimension is rounded down instead.

    Args:
        ori_h, ori_w: original height and width (may be float).
        max_dimension: optional (max_h, max_w) cap. FIX: the original
            indexed it unconditionally, so the documented default of None
            crashed whenever a dimension was not already divisible.
        scale_factor: divisor to snap to. Defaults to 32.

    Returns:
        (int, int): adjusted (height, width).
    """
    new_h, new_w = ori_h, ori_w
    if ori_h % scale_factor:
        new_h = math.ceil(ori_h/scale_factor)*scale_factor
        if max_dimension is not None and new_h > max_dimension[0]:
            new_h = math.floor(ori_h/scale_factor)*scale_factor

    if ori_w % scale_factor:
        new_w = math.ceil(ori_w/scale_factor)*scale_factor
        if max_dimension is not None and new_w > max_dimension[1]:
            new_w = math.floor(ori_w/scale_factor)*scale_factor

    return int(new_h), int(new_w)
|
| 103 |
+
|
| 104 |
+
def minmax_size(img, max_dimensions=None, min_dimensions=None, is_gray=True):
    """Constrain `img` to fit within `max_dimensions` and cover `min_dimensions`.

    Dimension pairs are (height, width) while PIL's `img.size` is
    (width, height) — hence the `[::-1]` reversals. Over-size images are
    shrunk preserving aspect ratio and snapped to divisible sizes;
    under-size grayscale images are pasted onto a white canvas.
    """
    if max_dimensions is not None:
        # A ratio > 1 on any axis means the image exceeds the allowed size.
        ratios = [a/b for a, b in zip(list(img.size)[::-1], max_dimensions)]
        if any([r > 1 for r in ratios]):
            # Shrink by the largest overshoot so both axes fit.
            size = np.array(img.size)/max(ratios)
            new_h, new_w = get_divisible_size(size[1], size[0], max_dimensions)
            img = img.resize((new_w, new_h), Image.LANCZOS)

    if min_dimensions is not None:
        # A ratio < 1 on any axis means the image is smaller than required.
        ratios = [a/b for a, b in zip(list(img.size)[::-1], min_dimensions)]
        if any([r < 1 for r in ratios]):
            new_h, new_w = img.size[1] / min(ratios), img.size[0] / min(ratios)
            new_h, new_w = get_divisible_size(new_h, new_w, max_dimensions)
            if is_gray:
                # Paste onto a white grayscale canvas of the enlarged size.
                # NOTE(review): non-gray images are left unpadded here —
                # confirm that is intentional.
                MODE = 'L'
                BACKGROUND = 255
                padded_im = Image.new(MODE, (new_w, new_h), BACKGROUND)
                padded_im.paste(img, img.getbbox())
                img = padded_im

    return img
|
| 125 |
+
|
| 126 |
+
def resize(resizer, img: Image.Image, opt: Dict):
    """Preprocess `img` into a normalized 4-D tensor for recognition.

    When `resizer` is falsy the image is constrained to the configured
    min/max dimensions and transformed once. Otherwise `resizer` is a model
    that iteratively predicts a better target width (up to 20 rounds)
    before the final transform. On a ValueError the raw image is
    transformed and padded to `max_dimension` instead.

    Args:
        resizer: optional width-prediction model (callable on a tensor).
        img (PIL.Image): input image.
        opt (dict): must contain 'imgH' and 'imgW'; when 'imgH' is None it
            must also contain 'max_dimension', 'min_dimension', 'mean',
            'std', 'rgb', 'pad' and 'device'.

    Returns:
        torch.Tensor of shape (1, C, H, W).

    NOTE(review): when opt['imgH'] is not None this function never assigns
    `new_img`, so the final assert raises NameError — confirm this code
    path is only reached with imgH=None.
    """
    # for math recognition the image is always in grayscale mode
    img = img.convert('L')
    assert isinstance(opt, Dict)
    assert "imgH" in opt
    assert "imgW" in opt
    expected_H = opt['imgH']

    if expected_H is None:
        max_dimensions = opt['max_dimension'] # can be bigger than max dim in training set
        min_dimensions = opt['min_dimension']
        # equal to min dim in training set
        test_transform = math_transform(opt['mean'], opt['std'], not opt['rgb'])
        try:
            new_img = minmax_size(pad(img) if opt['pad'] else img, max_dimensions, min_dimensions, not opt['rgb'])

            if not resizer:
                # Single-shot path: transform the size-constrained image.
                new_img = np.asarray(new_img.convert('RGB')).astype('uint8')
                new_img = test_transform(image=new_img)['image']
                if not opt['rgb']: new_img = new_img[:1]
                new_img = new_img.unsqueeze(0)
                new_img = new_img.float()
            else:
                # Iterative path: let the resizer model predict a target
                # width, rescale, and repeat until the width converges
                # (at most 20 iterations).
                with torch.no_grad():
                    input_image = new_img.convert('RGB').copy()
                    r, w, h = 1, input_image.size[0], input_image.size[1]
                    for i in range(20):
                        h = int(h * r)
                        # BILINEAR when enlarging, LANCZOS when shrinking.
                        new_img = pad(minmax_size(input_image.resize((w, h), Image.BILINEAR if r > 1 else Image.LANCZOS),
                                                  max_dimensions,
                                                  min_dimensions,
                                                  not opt['rgb']
                                                  ))
                        t = test_transform(image=np.array(new_img.convert('RGB')).astype('uint8'))['image']
                        if not opt['rgb']: t = t[:1]
                        t = t.unsqueeze(0)
                        t = t.float()
                        # Predicted width = (argmax bucket + 1) * minimum width.
                        w = (resizer(t.to(opt['device'])).argmax(-1).item()+1)*opt['min_dimension'][1]

                        if (w == new_img.size[0]):
                            break

                        r = w/new_img.size[0]

                    new_img = t
        except ValueError as e:
            # Fallback: transform the raw image and pad to max_dimension.
            print('Error:', e)
            new_img = np.asarray(img.convert('RGB')).astype('uint8')
            assert len(new_img.shape) == 3 and new_img.shape[2] == 3
            new_img = test_transform(image=new_img)['image']
            if not opt['rgb']: new_img = new_img[:1]
            new_img = new_img.unsqueeze(0)
            h, w = new_img.shape[2:]
            new_img = F.pad(new_img, (0, max_dimensions[1]-w, 0, max_dimensions[0]-h), value=1)

    assert len(new_img.shape) == 4, f'{new_img.shape}'
    return new_img
|
HybridViT/module/component/common/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .conv import *
|
| 2 |
+
from .droppath import *
|
| 3 |
+
from .gated_sum import *
|
| 4 |
+
from .maxout import *
|
| 5 |
+
from .postional_encoding import *
|
| 6 |
+
from .mae_posembed import *
|
HybridViT/module/component/common/conv.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from typing import Optional
|
| 5 |
+
import warnings
|
| 6 |
+
|
| 7 |
+
__all__ = ['ConvMLP', 'ConvModule']
|
| 8 |
+
|
| 9 |
+
class LayerNorm2d(nn.LayerNorm):
    """LayerNorm over the channel dimension of a BCHW tensor.

    The input is moved to channels-last (BHWC), normalized over the last
    axis, and moved back to BCHW.
    """
    def __init__(self, num_channels):
        super().__init__(num_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        channels_last = x.permute(0, 2, 3, 1)
        normed = F.layer_norm(
            channels_last, self.normalized_shape, self.weight, self.bias, self.eps
        )
        return normed.permute(0, 3, 1, 2)
|
| 17 |
+
|
| 18 |
+
class DepthwiseSeparableConv2d(nn.Module):
    """Depthwise conv (one filter per input channel) followed by a 1x1 pointwise conv."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super().__init__()
        # Depthwise stage: groups == in_channels, channel count unchanged.
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=bias,
        )
        # Pointwise stage: 1x1 conv mixes channels up/down to out_channels.
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, groups=groups, bias=bias)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))
|
| 37 |
+
|
| 38 |
+
class ConvMLP(nn.Module):
    """1x1-convolution MLP: fc1 -> LayerNorm2d -> ReLU -> Dropout -> fc2.

    Args:
        in_channels: input channel count.
        out_channels: output channels; defaults to in_channels.
        hidden_channels: hidden channels; defaults to in_channels.
        drop: dropout probability after the activation. Defaults to 0.25.
    """

    def __init__(self, in_channels, out_channels=None, hidden_channels=None, drop=0.25):
        super().__init__()
        # FIX: the original wrote `out_channels = in_channels or out_channels`
        # (and the same for hidden), which discarded caller-supplied values
        # whenever in_channels was nonzero. Fall back to in_channels only
        # when the caller did not specify.
        out_channels = out_channels or in_channels
        hidden_channels = hidden_channels or in_channels
        self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1, bias=True)
        self.norm = LayerNorm2d(hidden_channels)
        self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1, bias=True)
        self.act = nn.ReLU()
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.norm(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        return x
|
| 56 |
+
|
| 57 |
+
class ConvModule(nn.Module):
    """A conv + norm + activation bundle applied in a configurable `order`.

    Mirrors mmcv's ConvModule: bias defaults to 'auto' (off when a norm
    layer follows the conv), optional spectral norm, and optional explicit
    padding layers for modes torch's Conv2d does not support natively.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias='auto',
                 conv_layer: Optional[nn.Module] = nn.Conv2d,
                 norm_layer: Optional[nn.Module] = nn.BatchNorm2d,
                 act_layer: Optional[nn.Module] = nn.ReLU,
                 inplace=True,
                 with_spectral_norm=False,
                 padding_mode='zeros',
                 order=('conv', 'norm', 'act')
                 ):
        # FIX: nn.Module.__init__ must run before any submodule assignment;
        # the original omitted it and crashed on the first attribute set.
        super().__init__()
        official_padding_mode = ['zeros', 'circular']
        nonofficial_padding_mode = dict(zero=nn.ZeroPad2d, reflect=nn.ReflectionPad2d, replicate=nn.ReplicationPad2d)
        self.with_spectral_norm = with_spectral_norm
        self.with_explicit_padding = padding_mode not in official_padding_mode
        self.order = order
        assert isinstance(self.order, tuple) and len(self.order) == 3
        assert set(order) == set(['conv', 'norm', 'act'])

        self.with_norm = norm_layer is not None
        self.with_act = act_layer is not None

        # A norm layer makes the conv bias redundant.
        if bias == 'auto':
            bias = not self.with_norm
        self.with_bias = bias

        if self.with_norm and self.with_bias:
            warnings.warn('ConvModule has norm and bias at the same time')

        if self.with_explicit_padding:
            assert padding_mode in list(nonofficial_padding_mode), "Not implemented padding algorithm"
            # FIX: instantiate the pad layer with the padding size; the
            # original stored the *class* and later called it on the input
            # tensor in forward(), constructing a module instead of padding.
            self.padding_layer = nonofficial_padding_mode[padding_mode](padding)

        # reset padding to 0 for the conv itself — the explicit layer handles it
        conv_padding = 0 if self.with_explicit_padding else padding

        self.conv = conv_layer(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=conv_padding,
            dilation=dilation,
            groups=groups,
            bias=bias
        )
        # Mirror the conv's geometry attributes for external inspection.
        self.in_channels = self.conv.in_channels
        self.out_channels = self.conv.out_channels
        self.kernel_size = self.conv.kernel_size
        self.stride = self.conv.stride
        self.padding = padding
        self.dilation = self.conv.dilation
        self.transposed = self.conv.transposed
        self.output_padding = self.conv.output_padding
        self.groups = self.conv.groups

        if self.with_spectral_norm:
            self.conv = nn.utils.spectral_norm(self.conv)

        # build normalization layers
        if self.with_norm:
            # The norm layer sees out_channels when it follows the conv,
            # in_channels when it precedes it.
            if order.index('norm') > order.index('conv'):
                norm_channels = out_channels
            else:
                norm_channels = in_channels
            self.norm = norm_layer(norm_channels)

        if self.with_act:
            # FIX: the original condition was inverted — Tanh/PReLU/Sigmoid
            # take no `inplace` argument and crashed when selected, while
            # activations that do support it never received it.
            if act_layer in [nn.Tanh, nn.PReLU, nn.Sigmoid]:
                self.act = act_layer()
            else:
                self.act = act_layer(inplace=inplace)

    def forward(self, x, activate=True, norm=True):
        """Apply the configured layers in `self.order`; norm/act can be
        skipped per-call via the `norm` / `activate` flags."""
        for layer in self.order:
            if layer == 'conv':
                if self.with_explicit_padding:
                    x = self.padding_layer(x)
                x = self.conv(x)
            elif layer == 'norm' and norm and self.with_norm:
                x = self.norm(x)
            elif layer == 'act' and activate and self.with_act:
                x = self.act(x)
        return x
|
HybridViT/module/component/common/droppath.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
|
| 3 |
+
__all__ = ['DropPath']
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Stochastic depth: randomly zero entire samples of a residual branch.

    Each sample in the batch is dropped independently with probability
    `drop_prob`; kept samples are optionally rescaled by 1/keep_prob so the
    expected value is unchanged. A no-op at drop_prob == 0 or in eval mode.
    (Same mechanism as "DropConnect" in the EfficientNet codebase; named
    'drop path' here following the timm convention.)
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if scale_by_keep and keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Module wrapper around :func:`drop_path`; the drop is only active while
    the module is in training mode.
    """
    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        # Probability of zeroing a sample's entire residual branch.
        self.drop_prob = drop_prob
        # Rescale kept samples by 1/keep_prob to preserve the expectation.
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f'drop_prob={round(self.drop_prob,3):0.3f}'
|
HybridViT/module/component/common/gated_sum.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class GatedSum(torch.nn.Module):
    """
    Gated combination of two equally-shaped tensors `a` and `b`:
    ```
    f = activation(W [a; b])
    out = f * a + (1 - f) * b
    ```
    # Parameters
    input_dim : `int`, required
        The dimensionality of the input. We assume the input have shape `(..., input_dim)`.
    activation : `Activation`, optional (default = `torch.nn.Sigmoid()`)
        The activation function to use.
    """

    def __init__(self, input_dim: int, activation = torch.nn.Sigmoid()) -> None:
        super().__init__()
        self.input_dim = input_dim
        # Single scalar gate computed from the concatenation [a; b].
        self._gate = torch.nn.Linear(input_dim * 2, 1)
        self._activation = activation

    def get_input_dim(self):
        return self.input_dim

    def get_output_dim(self):
        return self.input_dim

    def forward(self, input_a: torch.Tensor, input_b: torch.Tensor) -> torch.Tensor:
        if input_a.size() != input_b.size():
            raise ValueError("The input must have the same size.")
        if input_a.size(-1) != self.input_dim:
            raise ValueError("Input size must match `input_dim`.")
        combined = torch.cat([input_a, input_b], -1)
        gate_value = self._activation(self._gate(combined))
        return gate_value * input_a + (1 - gate_value) * input_b
|
HybridViT/module/component/common/mae_posembed.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# Position embedding utils
|
| 8 |
+
# --------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
# --------------------------------------------------------
|
| 12 |
+
# 2D sine-cosine position embedding
|
| 13 |
+
# References:
|
| 14 |
+
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
|
| 15 |
+
# MoCo v3: https://github.com/facebookresearch/moco-v3
|
| 16 |
+
# --------------------------------------------------------
|
| 17 |
+
|
| 18 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size_H, grid_size_W, cls_token=False):
    """
    2D sine-cosine position embedding over a grid_size_H x grid_size_W grid.
    return:
        pos_embed: [H*W, embed_dim] or [1+H*W, embed_dim] (w/ or w/o cls_token;
        the cls row is all zeros).
    """
    grid_h = np.arange(grid_size_H, dtype=np.float32)
    grid_w = np.arange(grid_size_W, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size_H, grid_size_W])

    # FIX: removed a leftover debug print of the grid shape that polluted
    # stdout on every call.
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a 2-channel coordinate grid, splitting embed_dim evenly
    between the H and W coordinates."""
    assert embed_dim % 2 == 0

    half = embed_dim // 2
    emb_h = get_1d_sincos_pos_embed_from_grid(half, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(half, grid[1])  # (H*W, D/2)

    return np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D) — first D/2 columns are sines, last D/2 are cosines.
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    # Geometric frequency ladder 10000^(-2i/D), as in the original Transformer.
    freqs = np.arange(half, dtype=np.float32) / (embed_dim / 2.)
    freqs = 1. / 10000**freqs  # (D/2,)

    positions = pos.reshape(-1)  # (M,)
    angles = np.outer(positions, freqs)  # (M, D/2)

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
|
| 69 |
+
|
| 70 |
+
if __name__ == '__main__':
    # Smoke test: build an embedding for an 800x800 grid with a cls token
    # and print its shape; expected (1 + 800*800, 256).
    pos_embed = get_2d_sincos_pos_embed(256, 800, 800, True)
    print(pos_embed.shape)
|
HybridViT/module/component/common/maxout.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
|
| 3 |
+
class Maxout(nn.Module):
    """Max-pool over the last dimension.

    The last dimension is split into consecutive groups of ``pool_size``
    elements and each group is collapsed to its maximum, so an input of
    shape (..., D) becomes (..., D // pool_size).
    """

    def __init__(self, pool_size):
        """
        Args:
            pool_size (int): Number of elements per pool
        """
        super(Maxout, self).__init__()
        self.pool_size = pool_size

    def forward(self, x):
        *lead, last_dim = x.size()
        pooled = x.view(*lead, last_dim // self.pool_size, self.pool_size)
        return pooled.max(-1)[0]
|
| 22 |
+
|
HybridViT/module/component/common/postional_encoding.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from typing import Tuple
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
from torch.nn import functional as F
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Adaptive2DPositionalEncoding(nn.Module):
    """Implement Adaptive 2D positional encoder for SATRN, see
    `SATRN <https://arxiv.org/abs/1910.04396>`_
    Modified from https://github.com/Media-Smart/vedastr
    Licensed under the Apache License, Version 2.0 (the "License");

    Two fixed sinusoid tables (one per spatial axis) are modulated by learned,
    input-dependent scale maps before being added to the feature map.

    Args:
        d_hid (int): Dimensions of hidden layer.
        n_height (int): Max height of the 2D feature output.
        n_width (int): Max width of the 2D feature output.
        dropout (float): Dropout rate applied after adding the encodings.
    """

    def __init__(self,
                 d_hid=512,
                 n_height=100,
                 n_width=100,
                 dropout=0.1,
                 ):
        super().__init__()

        # Fixed sinusoid tables reshaped so they broadcast over (B, C, H, W).
        h_position_encoder = self._get_sinusoid_encoding_table(n_height, d_hid)
        h_position_encoder = h_position_encoder.transpose(0, 1)
        h_position_encoder = h_position_encoder.view(1, d_hid, n_height, 1)

        w_position_encoder = self._get_sinusoid_encoding_table(n_width, d_hid)
        w_position_encoder = w_position_encoder.transpose(0, 1)
        w_position_encoder = w_position_encoder.view(1, d_hid, 1, n_width)

        # Buffers: move with .to() and are saved in the state dict, but are
        # not trainable parameters.
        self.register_buffer('h_position_encoder', h_position_encoder)
        self.register_buffer('w_position_encoder', w_position_encoder)

        self.h_scale = self.scale_factor_generate(d_hid)
        self.w_scale = self.scale_factor_generate(d_hid)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=dropout)

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        """Sinusoid position encoding table, shape (n_position, d_hid)."""
        denominator = torch.Tensor([
            1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid)
            for hid_j in range(d_hid)
        ])
        denominator = denominator.view(1, -1)
        pos_tensor = torch.arange(n_position).unsqueeze(-1).float()
        sinusoid_table = pos_tensor * denominator
        sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2])  # even dims
        sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2])  # odd dims

        return sinusoid_table

    def scale_factor_generate(self, d_hid):
        """Small 1x1-conv net producing per-channel scale factors in (0, 1)."""
        scale_factor = nn.Sequential(
            nn.Conv2d(d_hid, d_hid, kernel_size=1), nn.ReLU(inplace=True),
            nn.Conv2d(d_hid, d_hid, kernel_size=1), nn.Sigmoid())

        return scale_factor

    def init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Bug fix: calculate_gain expects the lowercase name 'relu';
                # passing 'ReLU' raised ValueError (case-sensitive lookup).
                nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))

    def forward(self, x):
        b, c, h, w = x.size()

        avg_pool = self.pool(x)  # (B, C, 1, 1) global context driving the scales

        # Crop each table to the actual feature size before scaling.
        h_pos_encoding = \
            self.h_scale(avg_pool) * self.h_position_encoder[:, :, :h, :]
        w_pos_encoding = \
            self.w_scale(avg_pool) * self.w_position_encoder[:, :, :, :w]

        out = x + h_pos_encoding + w_pos_encoding

        out = self.dropout(out)

        return out
|
| 88 |
+
|
| 89 |
+
class PositionalEncoding2D(nn.Module):
    """2-D positional encodings for the feature maps produced by the encoder.

    Following https://arxiv.org/abs/2103.06450 by Sumeet Singh.
    Reference:
        https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2021-labs/blob/main/lab9/text_recognizer/models/transformer_util.py
    """

    def __init__(self, d_model: int, max_h: int = 2000, max_w: int = 2000) -> None:
        super().__init__()
        self.d_model = d_model
        assert d_model % 2 == 0, f"Embedding depth {d_model} is not even"
        # Precomputed once; registered as a buffer so it follows .to() and is
        # saved with the state dict without being trainable.
        self.register_buffer("pe", self.make_pe(d_model, max_h, max_w))  # (d_model, max_h, max_w)

    @staticmethod
    def make_pe(d_model: int, max_h: int, max_w: int) -> Tensor:
        """Concatenate a row encoding and a column encoding, each d_model // 2 deep."""
        half = d_model // 2
        rows = PositionalEncoding1D.make_pe(d_model=half, max_len=max_h)  # (max_h, 1, half)
        rows = rows.permute(2, 0, 1).expand(-1, -1, max_w)                # (half, max_h, max_w)

        cols = PositionalEncoding1D.make_pe(d_model=half, max_len=max_w)  # (max_w, 1, half)
        cols = cols.permute(2, 1, 0).expand(-1, max_h, -1)                # (half, max_h, max_w)

        return torch.cat([rows, cols], dim=0)  # (d_model, max_h, max_w)

    def forward(self, x: Tensor) -> Tensor:
        """Add the encoding, cropped to the input's spatial size.

        Args:
            x: (B, d_model, H, W)
        Returns:
            (B, d_model, H, W)
        """
        assert x.shape[1] == self.pe.shape[0]  # type: ignore
        return x + self.pe[:, : x.size(2), : x.size(3)]  # type: ignore
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class PositionalEncoding1D(nn.Module):
    """Classic Attention-is-all-you-need sinusoidal positional encoding."""

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1000) -> None:
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Precomputed table registered as a (non-trainable) buffer.
        self.register_buffer("pe", self.make_pe(d_model, max_len))  # (max_len, 1, d_model)

    @staticmethod
    def make_pe(d_model: int, max_len: int) -> Tensor:
        """Return the (max_len, 1, d_model) sin/cos table."""
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        inv_freq = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * inv_freq)  # even dims
        table[:, 1::2] = torch.cos(positions * inv_freq)  # odd dims
        return table.unsqueeze(1)

    def forward(self, x: Tensor) -> Tensor:
        """Add positional encodings and apply dropout.

        Args:
            x: (S, B, d_model)
        Returns:
            (S, B, d_model)
        """
        assert x.shape[2] == self.pe.shape[2]  # type: ignore
        return self.dropout(x + self.pe[: x.size(0)])  # type: ignore
|
| 157 |
+
|
| 158 |
+
Size_ = Tuple[int, int]

class PosConv(nn.Module):
    # PEG (conditional positional encoding) from https://arxiv.org/abs/2102.10882
    """Depth-wise 3x3 conv over the 2-D token grid; with stride 1 the result
    is added back to the tokens as a residual. The class token bypasses the
    convolution untouched."""

    def __init__(self, in_chans, embed_dim=768, stride=1):
        super(PosConv, self).__init__()
        self.proj = nn.Sequential(nn.Conv2d(in_chans, embed_dim, 3, stride, 1, bias=True, groups=embed_dim), )
        self.stride = stride

    def forward(self, x, size: Size_):
        batch, _, channels = x.shape
        cls_tok = x[:, 0]
        # Fold the (N-1) patch tokens back into a (B, C, H, W) grid.
        grid = x[:, 1:].transpose(1, 2).view(batch, channels, *size)
        out = self.proj(grid)
        if self.stride == 1:
            out = out + grid  # residual only when spatial resolution is kept
        out = out.flatten(2).transpose(1, 2)
        return torch.cat((cls_tok.unsqueeze(1), out), dim=1)

    def no_weight_decay(self):
        # NOTE(review): proj holds a single conv, so only 'proj.0.weight'
        # exists; range(4) mirrors the upstream implementation — confirm.
        return ['proj.%d.weight' % i for i in range(4)]
|
| 180 |
+
|
| 181 |
+
class PosConv1D(nn.Module):
    # PEG (conditional positional encoding) from https://arxiv.org/abs/2102.10882
    """1-D variant of PosConv: depth-wise 3-tap conv over the token sequence,
    with a residual add when stride is 1. The class token is passed through."""

    def __init__(self, in_chans, embed_dim=768, stride=1):
        super(PosConv1D, self).__init__()
        self.proj = nn.Sequential(nn.Conv1d(in_chans, embed_dim, 3, stride, 1, bias=True, groups=embed_dim), )
        self.stride = stride

    def forward(self, x, size: int):
        batch, _, channels = x.shape
        cls_tok = x[:, 0]
        seq = x[:, 1:].transpose(1, 2).view(batch, channels, size)
        out = self.proj(seq)
        if self.stride == 1:
            out = out + seq  # residual only when the sequence length is kept
        out = out.transpose(1, 2)
        return torch.cat((cls_tok.unsqueeze(1), out), dim=1)

    def no_weight_decay(self):
        # NOTE(review): proj holds a single conv, so only 'proj.0.weight'
        # exists; range(4) mirrors the upstream implementation — confirm.
        return ['proj.%d.weight' % i for i in range(4)]
|
| 201 |
+
|
| 202 |
+
def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=(), old_grid_shape=()):
    """Rescale a grid of position embeddings when loading from a state_dict.

    Adapted from
    https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224

    Args:
        posemb: source embedding, shape (1, num_tokens + H*W, C).
        posemb_new: target embedding; only its shape is consulted.
        num_tokens: number of leading special tokens (e.g. cls) copied verbatim.
        gs_new: target (H, W) grid; a square grid is assumed when empty.
        old_grid_shape: source (H, W) grid.

    Returns:
        Tensor with the shape of ``posemb_new``.
    """
    print('Resized position embedding: %s to %s' % (posemb.shape, posemb_new.shape))
    grid_tokens = posemb_new.shape[1]

    if num_tokens:
        tok_part, grid_part = posemb[:, :num_tokens], posemb[0, num_tokens:]
        grid_tokens -= num_tokens
    else:
        tok_part, grid_part = posemb[:, :0], posemb[0]

    if not len(gs_new):  # backwards compatibility: assume a square target grid
        gs_new = [int(math.sqrt(grid_tokens))] * 2
    assert len(gs_new) >= 2

    print('Position embedding grid-size from %s to %s' % (old_grid_shape, gs_new))
    # Reshape to an image-like (1, C, H, W) tensor, resample, then flatten back.
    grid_part = grid_part.reshape(1, old_grid_shape[0], old_grid_shape[1], -1).permute(0, 3, 1, 2)
    grid_part = F.interpolate(grid_part, size=gs_new, mode='bicubic', align_corners=False)
    grid_part = grid_part.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
    return torch.cat([tok_part, grid_part], dim=1)
|
HybridViT/module/component/feature_extractor/__init__.py
ADDED
|
File without changes
|
HybridViT/module/component/feature_extractor/addon_module/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .aspp import *
|
| 2 |
+
from .visual_attention import *
|
HybridViT/module/component/feature_extractor/addon_module/aspp.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
__all__ = ['ASPP']
|
| 6 |
+
|
| 7 |
+
class ASPPModule(nn.Module):
    """Single atrous branch of ASPP: one dilated conv followed by ReLU."""

    def __init__(self, inplanes, planes, kernel_size, padding, dilation):
        super(ASPPModule, self).__init__()
        self.atrous_conv = nn.Conv2d(
            inplanes, planes, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation, bias=False
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):  # skipcq: PYL-W0221
        return self.relu(self.atrous_conv(x))
|
| 19 |
+
|
| 20 |
+
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling (DeepLab-style).

    Four parallel atrous branches plus a global-average-pool branch are
    concatenated and fused by a 1x1 conv; the dilation rates depend on the
    backbone's output stride (8, 16 or 32).
    """

    def __init__(self, inplanes: int, output_stride: int, output_features: int, dropout=0.5):
        super(ASPP, self).__init__()

        rates_by_stride = {32: [1, 3, 6, 9], 16: [1, 6, 12, 18], 8: [1, 12, 24, 36]}
        if output_stride not in rates_by_stride:
            raise NotImplementedError
        dilations = rates_by_stride[output_stride]

        self.aspp1 = ASPPModule(inplanes, output_features, 1, padding=0, dilation=dilations[0])
        self.aspp2 = ASPPModule(inplanes, output_features, 3, padding=dilations[1], dilation=dilations[1])
        self.aspp3 = ASPPModule(inplanes, output_features, 3, padding=dilations[2], dilation=dilations[2])
        self.aspp4 = ASPPModule(inplanes, output_features, 3, padding=dilations[3], dilation=dilations[3])

        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(inplanes, output_features, 1, stride=1, bias=False),
            nn.ReLU(inplace=True),
        )
        self.conv1 = nn.Conv2d(output_features * 5, output_features, 1, bias=False)
        self.relu1 = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):  # skipcq: PYL-W0221
        branches = [self.aspp1(x), self.aspp2(x), self.aspp3(x), self.aspp4(x)]
        pooled = self.global_avg_pool(x)
        # Upsample the pooled branch back to the spatial size of the others.
        pooled = F.interpolate(pooled, size=branches[-1].size()[2:], mode="bilinear", align_corners=False)
        fused = self.relu1(self.conv1(torch.cat(branches + [pooled], dim=1)))
        return self.dropout(fused)
|
HybridViT/module/component/feature_extractor/addon_module/visual_attention.py
ADDED
|
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
from functools import reduce
|
| 5 |
+
|
| 6 |
+
__all__ = ['Adaptive_Global_Model', 'GlobalContext', 'SELayer', 'SKBlock', 'CBAM']
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def constant_init(module, val, bias=0):
    """Fill module.weight with ``val`` and module.bias with ``bias`` (when present)."""
    if getattr(module, 'weight', None) is not None:
        nn.init.constant_(module.weight, val)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def xavier_init(module, gain=1, bias=0, distribution='normal'):
    """Xavier-initialize module.weight (uniform or normal) and constant-fill bias."""
    assert distribution in ['uniform', 'normal']
    init_fn = nn.init.xavier_uniform_ if distribution == 'uniform' else nn.init.xavier_normal_
    init_fn(module.weight, gain=gain)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def normal_init(module, mean=0, std=1, bias=0):
    """Gaussian-initialize module.weight and constant-fill bias (when present)."""
    nn.init.normal_(module.weight, mean, std)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def uniform_init(module, a=0, b=1, bias=0):
    """Uniformly initialize module.weight in [a, b] and constant-fill bias."""
    nn.init.uniform_(module.weight, a, b)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def kaiming_init(module,
                 a=0,
                 mode='fan_out',
                 nonlinearity='relu',
                 bias=0,
                 distribution='normal'):
    """Kaiming-initialize module.weight (uniform or normal) and constant-fill bias."""
    assert distribution in ['uniform', 'normal']
    init_fn = nn.init.kaiming_uniform_ if distribution == 'uniform' else nn.init.kaiming_normal_
    init_fn(module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def last_zero_init(m):
    """Zero-initialize the last layer of a Sequential, or the module itself."""
    target = m[-1] if isinstance(m, nn.Sequential) else m
    constant_init(target, val=0)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class SELayer(nn.Module):
    """Squeeze-and-Excitation layer with a residual connection.

    Derives per-channel gates from globally pooled features and returns
    x + x * gate, so the identity path is always preserved.
    """

    def __init__(self, channel, reduction=16, dropout=0.1):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid()
        )

    def forward(self, x):
        batch, channels = x.size(0), x.size(1)
        gate = self.avg_pool(x).view(batch, channels)
        gate = self.fc(gate).view(batch, channels, 1, 1)
        return x + x * gate.expand_as(x)  # SE-Residual
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class BasicConv(nn.Module):
    """Conv2d followed by optional BatchNorm and optional ReLU."""

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True,
                 bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
                              dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        out = self.conv(x)
        # Apply whichever optional post-layers were configured.
        for layer in (self.bn, self.relu):
            if layer is not None:
                out = layer(out)
        return out
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class Flatten(nn.Module):
    """Collapse every dimension after the batch dimension into one."""

    def forward(self, x):
        batch = x.size(0)
        return x.view(batch, -1)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class ChannelGate(nn.Module):
    """CBAM channel-attention gate.

    Pools the input with each configured pool type, pushes the pooled vectors
    through a shared MLP, sums the logits, and gates every channel of ``x``
    with the resulting sigmoid weights.

    Args:
        gate_channels (int): number of input/output channels.
        reduction_ratio (int): MLP bottleneck reduction factor.
        pool_types: iterable of 'avg' and/or 'max'.
    """

    # Note: the default is a tuple rather than a list to avoid a shared
    # mutable default argument; behavior for callers is unchanged.
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=('avg', 'max')):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )
        self.pool_types = pool_types
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.maxpool = nn.AdaptiveMaxPool2d(1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                channel_att_raw = self.mlp(self.avgpool(x))
            elif pool_type == 'max':
                channel_att_raw = self.mlp(self.maxpool(x))
            else:
                # Bug fix: an unknown pool type previously reused the previous
                # iteration's logits (or raised UnboundLocalError when first).
                raise ValueError('unsupported pool_type: %r' % (pool_type,))

            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        # Broadcast the per-channel weights over the spatial dimensions.
        scale = self.sigmoid(channel_att_sum).unsqueeze(2) \
            .unsqueeze(3).expand_as(x)
        return x * scale
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class ChannelPool(nn.Module):
    """Compress channels to two maps: channel-wise max and channel-wise mean."""

    def forward(self, x):
        max_map = torch.max(x, 1)[0].unsqueeze(1)
        mean_map = torch.mean(x, 1).unsqueeze(1)
        return torch.cat((max_map, mean_map), dim=1)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class SpatialGate(nn.Module):
    """CBAM spatial attention: gates each location via a 7x7 conv over the
    channel-pooled (max/mean) maps; the sigmoid gate broadcasts over channels."""

    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        attn = self.spatial(self.compress(x))
        return x * self.sigmoid(attn)  # broadcasting over the channel dim
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class CBAM(nn.Module):
    """Convolutional Block Attention Module: a channel gate followed by an
    optional spatial gate (https://arxiv.org/abs/1807.06521)."""

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        out = self.ChannelGate(x)
        return out if self.no_spatial else self.SpatialGate(out)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class SKBlock(nn.Module):
    """Selective Kernel block (SKNet, https://arxiv.org/abs/1903.06586).

    Splits the input into M parallel grouped 3x3 conv branches with growing
    dilation, fuses them by summation, and re-weights each branch with
    channel-wise soft attention derived from the fused global context.
    NOTE(review): groups=32 requires in/out channels divisible by 32 —
    confirm all callers satisfy this.
    """

    def __init__(self, in_channels, out_channels, stride=1, M=2, r=16, L=32):
        super(SKBlock, self).__init__()
        # Width of the squeeze layer: in_channels / r, floored at L.
        d = max(in_channels // r, L)
        self.M = M  # number of parallel branches
        self.out_channels = out_channels
        self.conv = nn.ModuleList()
        for i in range(M):
            # Branch i uses dilation (with matching padding) of 1 + i so every
            # branch keeps the spatial size but sees a larger receptive field.
            self.conv.append(nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, stride, padding=1 + i, dilation=1 + i, groups=32, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)))
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Sequential(nn.Conv2d(out_channels, d, 1, bias=False),
                                 nn.BatchNorm2d(d),
                                 nn.ReLU(inplace=True))
        # Emits M sets of per-channel attention logits in one tensor.
        self.fc2 = nn.Conv2d(d, out_channels * M, 1, 1, bias=False)
        # Applied after reshaping to (B, M, C, 1) in forward, so dim=1
        # normalizes across the M branches.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, input):
        batch_size = input.size(0)
        output = []
        # the part of split
        for i, conv in enumerate(self.conv):
            # print(i,conv(input).size())
            output.append(conv(input))
        # the part of fusion
        U = reduce(lambda x, y: x + y, output)
        s = self.global_pool(U)
        z = self.fc1(s)
        a_b = self.fc2(z)
        a_b = a_b.reshape(batch_size, self.M, self.out_channels, -1)
        a_b = self.softmax(a_b)
        # the part of selection
        a_b = list(a_b.chunk(self.M, dim=1))  # split to a and b
        a_b = list(map(lambda x: x.reshape(batch_size, self.out_channels, 1, 1), a_b))
        V = list(map(lambda x, y: x * y, output, a_b))
        V = reduce(lambda x, y: x + y, V)
        return V
|
| 216 |
+
|
| 217 |
+
def make_divisible(v, divisor=8, min_value=None, round_limit=.9):
    """Round ``v`` to the nearest multiple of ``divisor``, at least ``min_value``.

    If rounding down would shrink the value below ``round_limit * v``
    (i.e. by more than ~10% with the default), bump up one divisor step.
    """
    floor = min_value or divisor
    rounded = max(floor, int(v + divisor / 2) // divisor * divisor)
    if rounded < round_limit * v:
        rounded += divisor
    return rounded
|
| 224 |
+
|
| 225 |
+
class LayerNorm2d(nn.LayerNorm):
    """LayerNorm over the channel dimension of '2D' spatial BCHW tensors."""

    def __init__(self, num_channels):
        super().__init__(num_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Move channels last, normalize over them, then restore BCHW layout.
        channels_last = x.permute(0, 2, 3, 1)
        normed = F.layer_norm(channels_last, self.normalized_shape, self.weight, self.bias, self.eps)
        return normed.permute(0, 3, 1, 2)
|
| 233 |
+
|
| 234 |
+
class ConvMLP(nn.Module):
    """1x1-conv MLP over BCHW input: fc1 -> LayerNorm2d -> ReLU -> Dropout -> fc2.

    Args:
        in_channels (int): input channel count.
        out_channels (int): output channels; defaults to ``in_channels``.
        hidden_channels (int): bottleneck channels; defaults to ``in_channels``.
        drop (float): dropout rate between activation and fc2.
    """

    def __init__(self, in_channels, out_channels=None, hidden_channels=None, drop=0.25):
        super().__init__()
        # Bug fix: the original evaluated `in_channels or out_channels`, which
        # always chose the (non-zero) in_channels and silently ignored the
        # caller-supplied out/hidden widths. Fall back to in_channels only
        # when the argument is omitted.
        out_channels = out_channels or in_channels
        hidden_channels = hidden_channels or in_channels
        self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1, bias=True)
        self.norm = LayerNorm2d(hidden_channels)
        self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1, bias=True)
        self.act = nn.ReLU()
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.norm(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        return x
|
| 252 |
+
|
| 253 |
+
class GlobalContext(nn.Module):
    """GCNet-style global context block.

    Builds a (B, C, 1, 1) global context vector — via softmax spatial
    attention when ``use_attn`` is True, otherwise via global average
    pooling — and fuses it back into the input additively and/or
    multiplicatively through bottleneck ConvMLPs.
    """

    def __init__(self,
                 channel,
                 use_attn=True,
                 fuse_add=True,
                 fuse_scale=False,
                 rd_ratio=1./8,
                 rd_channels=None
                 ):
        super().__init__()
        self.use_attn = use_attn
        # Either a 1x1 conv producing spatial attention logits, or plain GAP.
        self.global_cxt = nn.Conv2d(channel, 1, kernel_size=1, bias=True) if use_attn else nn.AdaptiveAvgPool2d(1)

        if rd_channels is None:
            rd_channels = make_divisible(channel * rd_ratio, divisor=1, round_limit=0.)

        # Bottlenecks start zero-initialized (see init_weight), so the block
        # is initially an identity mapping.
        self.bottleneck_add = ConvMLP(channel, hidden_channels=rd_channels) if fuse_add else None
        self.bottleneck_mul = ConvMLP(channel, hidden_channels=rd_channels) if fuse_scale else None

        self.init_weight()

    def init_weight(self):
        if self.use_attn:
            nn.init.kaiming_normal_(self.global_cxt.weight, mode='fan_in', nonlinearity='relu')
        if self.bottleneck_add is not None:
            nn.init.zeros_(self.bottleneck_add.fc2.weight)
        if self.bottleneck_mul is not None:
            nn.init.zeros_(self.bottleneck_mul.fc2.weight)

    def forward(self, x):
        B, C, H, W = x.shape
        if self.use_attn:
            attn = self.global_cxt(x).reshape(B, 1, H * W).squeeze(1)
            attn = F.softmax(attn, dim=-1).unsqueeze(-1)  # shape (B, H*W, 1)
            query = x.reshape(B, C, H * W)  # shape (B, C, H*W)
            glob_cxt = torch.bmm(query, attn).unsqueeze(-1)  # (B, C, 1, 1)
        else:
            glob_cxt = self.global_cxt(x)
        assert len(glob_cxt.shape) == 4

        # Bug fix: x_fuse was unbound when both fuse branches were disabled;
        # fall back to the identity in that case. As before, when both
        # branches are enabled the multiplicative one takes precedence.
        x_fuse = x
        if self.bottleneck_add is not None:
            x_fuse = x + self.bottleneck_add(glob_cxt)
        if self.bottleneck_mul is not None:
            # torch.sigmoid: F.sigmoid is deprecated.
            x_fuse = x * torch.sigmoid(self.bottleneck_mul(glob_cxt))

        return x_fuse
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
class Adaptive_Global_Model(nn.Module):
    """Global-context pooling head: (B, C, H, W) feature map -> (B, W, C) sequence.

    Args:
        inplanes (int): channel count C of the incoming feature map.
        factor (int): flattening factor; assumes the feature height H equals
            ``factor`` so that C*H == inplanes * factor — TODO confirm callers.
        ratio (float): reduction ratio forwarded to GlobalContext.
        dropout (float): dropout on the output sequence.
    """

    def __init__(self, inplanes, factor=2, ratio=0.0625, dropout=0.1):
        super(Adaptive_Global_Model, self).__init__()
        # b, w, h, c => gc_block (b, w, h, c) => => b, w, inplanes
        self.embedding = nn.Linear(inplanes * factor, inplanes)
        # Bug fix: GlobalContext has no `ratio` parameter (its keyword is
        # `rd_ratio`); the original call raised TypeError at construction.
        self.gc_block = GlobalContext(inplanes, rd_ratio=ratio)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.gc_block(x)           # (B, C, H, W) -> (B, C, H, W)
        x = x.permute(0, 3, 1, 2)      # -> (B, W, C, H)
        b, w, _, _ = x.shape
        x = x.contiguous().view(b, w, -1)  # flatten per-column features: (B, W, C*H)
        x = self.embedding(x)          # (B, W, inplanes)
        x = self.dropout(x)
        return x
|
| 325 |
+
|
HybridViT/module/component/feature_extractor/clova_impl/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .resnet import ResNet_FeatureExtractor
|
| 2 |
+
from .vgg import VGG_FeatureExtractor
|
HybridViT/module/component/feature_extractor/clova_impl/resnet.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict
|
| 2 |
+
from collections import OrderedDict
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from ..addon_module.visual_attention import GlobalContext
|
| 7 |
+
from .....helper import clean_state_dict
|
| 8 |
+
|
| 9 |
+
class BasicBlock(nn.Module):
    """Two-conv residual block (post-activation ResNet style)."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # Two 3x3 conv/BN pairs; a single ReLU module serves both activations.
        self.conv1 = self._conv3x3(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = self._conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def _conv3x3(self, in_planes, out_planes, stride=1):
        "3x3 convolution with padding"
        return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                         padding=1, bias=False)

    def zero_init_last_bn(self):
        # Zeroing the last BN scale makes the block behave as identity at init.
        nn.init.zeros_(self.bn2.weight)

    def forward(self, x):
        # Shortcut path: raw input, or projected input when a downsample is set.
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Residual addition followed by the final activation.
        out = self.relu(out + shortcut)
        return out
|
| 47 |
+
|
| 48 |
+
class ResNet(nn.Module):
    """ResNet backbone specialised for scene-text recognition.

    Unlike a classification ResNet, the later pools use stride (2, 1) so the
    height is reduced much more aggressively than the width (the reading
    axis). Stage widths are derived from ``output_channel`` as
    [C/4, C/2, C, C].

    Args:
        input_channel: channels of the input image (1 grayscale, 3 RGB).
        output_channel: channel count C of the final feature map.
        block: residual block class (e.g. BasicBlock).
        layers: blocks per stage, e.g. [1, 2, 5, 3].
        with_gcb: if True, append a GlobalContext block after every stage.
        debug: if True, print tensor shapes at each step of forward().
        zero_init_last_bn: zero each block's final BN so it starts as identity.
    """

    def __init__(self, input_channel, output_channel, block, layers, with_gcb=True, debug=False, zero_init_last_bn=False):
        super(ResNet, self).__init__()
        self.with_gcb = with_gcb

        # Stage widths: [C/4, C/2, C, C]; the stem works at C/16 then C/8.
        self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel]
        self.inplanes = int(output_channel / 8)

        # Stem: two 3x3 convs.
        self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16),
                                 kernel_size=3, stride=1, padding=1, bias=False)
        self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16))

        self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes,
                                 kernel_size=3, stride=1, padding=1, bias=False)
        self.bn0_2 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)

        # Stage 1: /2 pool, residual blocks, then a 3x3 conv "transition".
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0])
        self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[
            0], kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.output_channel_block[0])

        # Stage 2: same pattern at twice the width.
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1)
        self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[
            1], kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.output_channel_block[1])

        # Stage 3: pool halves the height only (stride (2, 1)) to keep width.
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1))
        self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1)
        self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[
            2], kernel_size=3, stride=1, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.output_channel_block[2])

        # Stage 4 plus two final 2x2 convs that collapse the height to 1.
        self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1)

        self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[
            3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False)
        self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3])

        self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[
            3], kernel_size=2, stride=1, padding=0, bias=False)
        self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3])

        self.init_weights(zero_init_last_bn=zero_init_last_bn)
        self.debug = debug

    def zero_init_last_bn(self):
        # Zero the very last BN so the final conv pair starts near-identity.
        nn.init.zeros_(self.bn4_2.weight)

    def init_weights(self, zero_init_last_bn=True):
        """Kaiming-init convs and unit-init BNs, skipping GCB submodules
        (they carry their own initialisation)."""
        initialized = ['global_cxt', 'bottleneck_add', 'bottleneck_mul']
        for n, m in self.named_modules():
            if any([d in n for d in initialized]):
                continue
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
        if zero_init_last_bn:
            for m in self.modules():
                if hasattr(m, 'zero_init_last_bn'):
                    m.zero_init_last_bn()

    def _make_layer(self, block, planes, blocks, with_gcb=False, stride=1):
        """Build one residual stage of ``blocks`` blocks (1x1 downsample when
        the channel count or stride changes).

        NOTE(review): the ``with_gcb`` parameter is ignored — the decision to
        append a GlobalContext block is taken from ``self.with_gcb`` below.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion

        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        if self.with_gcb:
            layers.append(GlobalContext(planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        # Debug prints trace the spatial shape through every stage.
        if self.debug:
            print('input shape', x.shape)

        x = self.conv0_1(x)
        x = self.bn0_1(x)
        x = self.relu(x)

        if self.debug:
            print('conv1 shape', x.shape)

        x = self.conv0_2(x)
        x = self.bn0_2(x)
        x = self.relu(x)

        if self.debug:
            print('conv2 shape', x.shape)

        x = self.maxpool1(x)

        if self.debug:
            print('pool1 shape', x.shape)

        x = self.layer1(x)

        if self.debug:
            print('block1 shape', x.shape)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        if self.debug:
            print('conv3 shape', x.shape)

        x = self.maxpool2(x)

        if self.debug:
            print('pool2 shape', x.shape)

        x = self.layer2(x)

        if self.debug:
            print('block2 shape', x.shape)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        if self.debug:
            print('conv4 shape', x.shape)

        x = self.maxpool3(x)

        if self.debug:
            print('pool3 shape', x.shape)

        x = self.layer3(x)

        if self.debug:
            print('block3 shape', x.shape)

        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)

        if self.debug:
            print('conv5 shape', x.shape)

        x = self.layer4(x)

        if self.debug:
            print('block4 shape', x.shape)

        x = self.conv4_1(x)
        x = self.bn4_1(x)
        x = self.relu(x)

        if self.debug:
            print('conv6 shape', x.shape)

        x = self.conv4_2(x)
        x = self.bn4_2(x)
        x = self.relu(x)

        if self.debug:
            print('conv7 shape', x.shape)

        return x
|
| 224 |
+
|
| 225 |
+
class ResNet_FeatureExtractor(nn.Module):
    """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """

    def __init__(self, input_channel=3, output_channel=512, gcb=False, pretrained=False, weight_dir=None, debug=False):
        super(ResNet_FeatureExtractor, self).__init__()
        # [1, 2, 5, 3] blocks per stage — the FAN-style ResNet configuration.
        self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3], gcb, debug)
        self.in_chans = input_channel
        if pretrained:
            assert weight_dir is not None
            self.load_pretrained(weight_dir)

    def forward(self, input):
        output = self.ConvNet(input)
        return output

    def load_pretrained(self, weight_dir):
        """Load backbone weights from a Clova-style recognizer checkpoint.

        Keeps only keys under ``FeatureExtraction.`` and strips that prefix so
        the names line up with this module's submodules.
        """
        state_dict: OrderedDict = torch.load(weight_dir)
        cleaned_state_dict = clean_state_dict(state_dict)
        new_state_dict = OrderedDict()
        name: str
        param: torch.FloatTensor
        for name, param in cleaned_state_dict.items():
            if name.startswith('FeatureExtraction'):
                output_name = name.replace('FeatureExtraction.', '')
                if output_name == 'ConvNet.conv0_1.weight':
                    # NOTE(review): repeating dim 1 by in_chans assumes the
                    # checkpoint's stem conv has a single input channel —
                    # confirm before reusing with RGB checkpoints.
                    print('Old', param.shape)
                    new_param = param.repeat(1, self.in_chans, 1, 1)
                    print('New', new_param.shape)
                else: new_param = param
                new_state_dict[output_name] = new_param
        print("=> Loading pretrained weight for ResNet backbone")
        self.load_state_dict(new_state_dict)
|
| 257 |
+
|
| 258 |
+
if __name__ == '__main__':
    # Smoke test: push a random grayscale 128x480 image through the backbone
    # with shape-tracing enabled and print the final feature-map shape.
    model = ResNet_FeatureExtractor(input_channel=1, debug=True)
    a = torch.rand(1, 1, 128, 480)
    output = model(a)
    print(output.shape)
|
HybridViT/module/component/feature_extractor/clova_impl/vgg.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
class VGG_FeatureExtractor(nn.Module):
    """ FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf) """

    def __init__(self, input_channel, output_channel=512):
        super(VGG_FeatureExtractor, self).__init__()
        # Stage widths derived from the target channel count, e.g. [64, 128, 256, 512].
        c0 = int(output_channel / 8)
        c1 = int(output_channel / 4)
        c2 = int(output_channel / 2)
        c3 = output_channel
        self.output_channel = [c0, c1, c2, c3]

        stages = [
            nn.Conv2d(input_channel, c0, 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d(2, 2),  # c0 x H/2 x W/2
            nn.Conv2d(c0, c1, 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d(2, 2),  # c1 x H/4 x W/4
            nn.Conv2d(c1, c2, 3, 1, 1), nn.ReLU(True),  # c2 x H/4 x W/4
            nn.Conv2d(c2, c2, 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d((2, 1), (2, 1)),  # height-only pool: c2 x H/8 x W/4
            nn.Conv2d(c2, c3, 3, 1, 1, bias=False),
            nn.BatchNorm2d(c3), nn.ReLU(True),  # c3 x H/8 x W/4
            nn.Conv2d(c3, c3, 3, 1, 1, bias=False),
            nn.BatchNorm2d(c3), nn.ReLU(True),
            nn.MaxPool2d((2, 1), (2, 1)),  # c3 x H/16 x W/4
            nn.Conv2d(c3, c3, 2, 1, 0), nn.ReLU(True),  # valid 2x2 conv trims H and W by 1
        ]
        self.ConvNet = nn.Sequential(*stages)

    def forward(self, input):
        return self.ConvNet(input)
|
HybridViT/module/component/feature_extractor/helpers.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import List, Tuple
|
| 3 |
+
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
# Calculate symmetric padding for a convolution
|
| 7 |
+
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
    """Symmetric padding that keeps a stride-1 conv output the same size."""
    effective_extent = dilation * (kernel_size - 1) + (stride - 1)
    return effective_extent // 2
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
|
| 13 |
+
def get_same_padding(x: int, k: int, s: int, d: int):
    """Total TF-'SAME' padding needed along one dimension of size x."""
    out_size = math.ceil(x / s)
    pad = (out_size - 1) * s + (k - 1) * d + 1 - x
    return pad if pad > 0 else 0
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Can SAME padding for given args be done statically?
|
| 18 |
+
def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
    """True when TF-'SAME' padding is input-size independent (stride 1 and an
    even total pad), so it can be baked into the conv statically."""
    if stride != 1:
        return False
    return (dilation * (kernel_size - 1)) % 2 == 0
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Dynamically pad input x with 'SAME' padding for conv with specified args
|
| 23 |
+
def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0):
    """Dynamically apply TF-'SAME' padding to x (NCHW) for a conv with
    kernel k, stride s, and dilation d; extra pixel goes to the bottom/right."""
    ih, iw = x.size()[-2:]
    pad_h = get_same_padding(ih, k[0], s[0], d[0])
    pad_w = get_same_padding(iw, k[1], s[1], d[1])
    if pad_h == 0 and pad_w == 0:
        return x
    left = pad_w // 2
    top = pad_h // 2
    return F.pad(x, [left, pad_w - left, top, pad_h - top], value=value)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
    """Resolve a padding spec ('same' / 'valid' / numeric) to (padding, is_dynamic).

    ``is_dynamic`` is True only for TF-'SAME' cases that depend on the input
    size and must be padded at runtime (see pad_same).
    """
    if not isinstance(padding, str):
        # Numeric padding passes through untouched.
        return padding, False

    padding = padding.lower()
    if padding == 'valid':
        # 'VALID' padding, same as padding=0
        return 0, False
    if padding == 'same':
        # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
        if is_static_pad(kernel_size, **kwargs):
            # static case, no extra overhead
            return get_padding(kernel_size, **kwargs), False
        # dynamic 'SAME' padding, has runtime/GPU memory overhead
        return 0, True
    # Default to PyTorch style 'same'-ish symmetric padding
    return get_padding(kernel_size, **kwargs), False
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def adapt_input_conv(in_chans, conv_weight):
    """Adapt a pretrained stem-conv weight (O, I, H, W) to a new input channel count.

    For 1-channel input the RGB filters are summed; for other counts the
    filters are tiled and rescaled so activations keep a similar magnitude.
    """
    orig_dtype = conv_weight.dtype
    conv_weight = conv_weight.float()  # Some weights are in torch.half, ensure it's float for sum on CPU
    out_ch, in_ch, kh, kw = conv_weight.shape
    if in_chans == 1:
        if in_ch > 3:
            # For models with space2depth stems
            assert in_ch % 3 == 0
            conv_weight = conv_weight.reshape(out_ch, in_ch // 3, 3, kh, kw).sum(dim=2)
        else:
            conv_weight = conv_weight.sum(dim=1, keepdim=True)
    elif in_chans != 3:
        if in_ch != 3:
            raise NotImplementedError('Weight format not supported by conversion.')
        # NOTE this strategy should be better than random init, but there could be other combinations of
        # the original RGB input layer weights that'd work better for specific cases.
        repeat = int(math.ceil(in_chans / 3))
        conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
        conv_weight *= (3 / float(in_chans))
    return conv_weight.to(orig_dtype)
|
HybridViT/module/component/feature_extractor/vgg.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
__all__ = ['vgg11_bn', 'vgg13_bn', 'vgg16_bn', 'vgg19_bn']
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class VGG(nn.Module):
    """VGG trunk followed by a grouped 1x1 projection head.

    ``features`` is the conv stack (see make_layers); the head maps its
    512-channel output to ``num_channel_out`` via a groups=32 pointwise conv.
    """

    def __init__(self, features, num_channel_out=512, init_weights=True):
        super(VGG, self).__init__()
        self.features = features
        self.num_out_features = 512

        self.lastlayer = nn.Sequential(
            nn.Conv2d(self.num_out_features, num_channel_out, kernel_size=1, stride=1, padding=0, groups=32, bias=False),
            nn.BatchNorm2d(num_channel_out),
            nn.ReLU(inplace=True),
        )

        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        return self.lastlayer(self.features(x))

    def _initialize_weights(self):
        # Kaiming init for convs, unit/zero for BN, small normal for linears.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def make_layers(cfg, down_sample=8, batch_norm=False):
    """Build a VGG conv trunk from a config list.

    cfg entries: int -> 3x3 conv with that many filters (+ optional BN, ReLU);
    'M' -> 2x2/2 maxpool; dict -> maxpool whose (kernel, stride) is selected
    by the requested total ``down_sample`` factor.
    """
    modules = []
    channels = 3
    for entry in cfg:
        if entry == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
        elif isinstance(entry, dict):
            pool_size = entry[down_sample]
            modules.append(nn.MaxPool2d(kernel_size=pool_size, stride=pool_size))
        else:
            conv = nn.Conv2d(channels, entry, kernel_size=3, padding=1)
            if batch_norm:
                modules.extend([conv, nn.BatchNorm2d(entry), nn.ReLU(inplace=True)])
            else:
                modules.extend([conv, nn.ReLU(inplace=True)])
            channels = entry
    return nn.Sequential(*modules)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# VGG layer configurations (keys follow torchvision: A=vgg11, B=vgg13,
# D=vgg16, E=vgg19). Ints are 3x3 conv widths, 'M' is a 2x2/2 maxpool, and
# dict entries choose the pooling kernel/stride by total down-sample factor
# (4 or 8) — later pools shrink the height more than the width, which suits
# wide text-line inputs.
cfgs = {
    'A': [64, 'M', 128, 'M', 256, 256, {4: (2, 1), 8: (2, 2)}, 512, 512, {4: (2, 1), 8: (2, 1)}, 512, 512,
          {4: (2, 1), 8: (2, 1)}],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, {4: (2, 1), 8: (2, 2)}, 512, 512, {4: (2, 1), 8: (2, 1)}, 512, 512,
          {4: (2, 1), 8: (2, 1)}, ],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, {4: (2, 1), 8: (2, 2)}, 512, 512, 512, {4: (2, 1), 8: (2, 1)}, 512,
          512, 512, {4: (2, 1), 8: (2, 1)}, ],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, {4: (2, 1), 8: (2, 2)}, 512, 512, 512, 512,
          {4: (2, 1), 8: (2, 1)}, 512, 512, 512, 512, {4: (2, 1), 8: (2, 1)}, ],
}
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _vgg(model_path, cfg, batch_norm, pretrained, progress, num_channel_out, down_sample, **kwargs):
    """Instantiate a VGG variant from ``cfgs[cfg]`` and optionally load
    weights from ``model_path`` (non-strict: the projection head has none)."""
    if pretrained:
        kwargs['init_weights'] = False  # pretrained weights replace random init
    trunk = make_layers(cfgs[cfg], down_sample, batch_norm=batch_norm)
    model = VGG(trunk, num_channel_out, **kwargs)
    if model_path and pretrained:
        model.load_state_dict(torch.load(model_path), strict=False)
    return model
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def vgg11_bn(model_path='', num_channel_out=512, down_sample=8, pretrained=True, progress=True, **kwargs):
    """VGG-11 with batch norm (config 'A'); see _vgg for argument semantics."""
    return _vgg(model_path, 'A', True, pretrained, progress, num_channel_out, down_sample, **kwargs)


def vgg13_bn(model_path='', num_channel_out=512, down_sample=8, pretrained=True, progress=True, **kwargs):
    """VGG-13 with batch norm (config 'B'); see _vgg for argument semantics."""
    return _vgg(model_path, 'B', True, pretrained, progress, num_channel_out, down_sample, **kwargs)


def vgg16_bn(model_path='', num_channel_out=512, down_sample=8, pretrained=True, progress=True, **kwargs):
    """VGG-16 with batch norm (config 'D'); see _vgg for argument semantics."""
    return _vgg(model_path, 'D', True, pretrained, progress, num_channel_out, down_sample, **kwargs)


def vgg19_bn(model_path='', num_channel_out=512, down_sample=8, pretrained=True, progress=True, **kwargs):
    """VGG-19 with batch norm (config 'E'); see _vgg for argument semantics."""
    return _vgg(model_path, 'E', True, pretrained, progress, num_channel_out, down_sample, **kwargs)
|
HybridViT/module/component/prediction_head/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .seq2seq import Attention
|
| 2 |
+
from .seq2seq_v2 import AttentionV2
|
| 3 |
+
from .tfm import TransformerPrediction
|
| 4 |
+
|
| 5 |
+
__all__ = ['Attention', 'AttentionV2', 'TransformerPrediction']
|
HybridViT/module/component/prediction_head/addon_module/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .attention1D import *
|
| 2 |
+
from .attention2D import *
|
| 3 |
+
from .position_encoding import *
|
HybridViT/module/component/prediction_head/addon_module/attention1D.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
# NOTE(review): enabling anomaly detection globally at import time slows every
# backward pass for the whole process; this looks like a debugging leftover —
# consider removing or guarding it behind a debug flag.
torch.autograd.set_detect_anomaly(True)
|
| 6 |
+
|
| 7 |
+
class LuongAttention(nn.Module):
    """One Luong-style decoding step: advance the LSTM cell first, then attend
    over the encoder states and classify [context; hidden]."""

    def __init__(self, input_size, hidden_size, num_embeddings, num_classes, method='dot'):
        super(LuongAttention, self).__init__()
        self.attn = LuongAttentionCell(hidden_size, method)
        self.rnn = nn.LSTMCell(num_embeddings, hidden_size)
        self.hidden_size = hidden_size
        self.generator = nn.Linear(2*hidden_size, num_classes)

    def forward(self, prev_hidden, batch_H, embed_text):
        # Recurrent update comes before attention (Luong ordering).
        hidden = self.rnn(embed_text, prev_hidden)

        scores = self.attn(hidden[0], batch_H)
        alpha = F.softmax(scores, dim=1)

        # Weighted sum of encoder states: batch_size x num_channel.
        context = torch.bmm(alpha.unsqueeze(1), batch_H).squeeze(1)
        combined = torch.cat([context, hidden[0]], 1)  # batch_size x (num_channel + num_embedding)
        logits = self.generator(torch.tanh(combined))

        return logits, hidden, alpha
|
| 29 |
+
|
| 30 |
+
class LuongAttentionCell(nn.Module):
    """Luong alignment scoring over encoder outputs: 'dot', 'general' or 'concat'."""

    def __init__(self, hidden_size, method='dot'):
        super(LuongAttentionCell, self).__init__()
        self.method = method
        self.hidden_size = hidden_size

        # Defining the layers/weights required depending on alignment scoring method
        if method == "general":
            self.fc = nn.Linear(hidden_size, hidden_size, bias=False)
        elif method == "concat":
            self.fc = nn.Linear(hidden_size, hidden_size, bias=False)
            self.weight = nn.Parameter(torch.FloatTensor(1, hidden_size))

    def forward(self, decoder_hidden, encoder_outputs):
        query = decoder_hidden.unsqueeze(1)  # B x 1 x H

        if self.method == "dot":
            # For the dot scoring method, no weights or linear layers are involved
            return encoder_outputs.bmm(query.permute(0, 2, 1)).squeeze(-1)

        if self.method == "general":
            # Query first passes through a weight matrix (bilinear score).
            projected = self.fc(query)
            return encoder_outputs.bmm(projected.permute(0, 2, 1)).squeeze(-1)

        if self.method == "concat":
            # Additive mix of query and keys, scored against a learned vector.
            mixed = torch.tanh(self.fc(query + encoder_outputs))
            score_vec = self.weight.unsqueeze(-1).repeat(mixed.shape[0], 1, 1)
            return mixed.bmm(score_vec).squeeze(-1)
|
| 62 |
+
|
| 63 |
+
class BahdanauAttentionCell(nn.Module):
    """Additive (Bahdanau) attention score: v^T tanh(W_i * keys + W_h * query)."""

    def __init__(self, input_dim, hidden_dim):
        super(BahdanauAttentionCell, self).__init__()
        self.i2h = nn.Linear(input_dim, hidden_dim, bias=False)
        self.h2h = nn.Linear(hidden_dim, hidden_dim)
        self.score = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, decoder_hidden, encoder_output):
        # decoder_hidden is an (h, c) LSTM state tuple; only h enters the score.
        keys = self.i2h(encoder_output)                   # B x T x H
        query = self.h2h(decoder_hidden[0]).unsqueeze(1)  # B x 1 x H
        return self.score(torch.tanh(keys + query))       # B x T x 1
|
| 75 |
+
|
| 76 |
+
class BahdanauAttention(nn.Module):
    """One additive-attention decoding step: attend first, then feed
    [context; embedding] to the LSTM cell and classify its hidden state."""

    def __init__(self, input_size=100, hidden_size=256, num_embeddings=10, num_classes=10):
        super(BahdanauAttention, self).__init__()
        self.attn = BahdanauAttentionCell(input_size, hidden_size)
        self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.generator = nn.Linear(hidden_size, num_classes)

    def set_mem(self, prev_attn):
        # No alignment memory in the vanilla variant; hook kept for API
        # parity with LocationAwareAttention.
        pass

    def reset_mem(self):
        pass

    def forward(self, prev_hidden, batch_H, embed_text):
        # batch_H: [batch_size x num_encoder_step x num_channel]
        scores = self.attn(prev_hidden, batch_H)
        alpha = F.softmax(scores, dim=1)
        context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1)  # batch_size x num_channel
        rnn_input = torch.cat([context, embed_text], 1)  # batch_size x (num_channel + num_embedding)
        cur_hidden = self.rnn(rnn_input, prev_hidden)
        return self.generator(cur_hidden[0]), cur_hidden, alpha
|
| 101 |
+
|
| 102 |
+
class ConstituentAttentionCell(nn.Module):
    """Placeholder for constituent attention — no behavior beyond nn.Module yet."""
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class ConstituentCoverageAttentionCell(ConstituentAttentionCell):
    # Placeholder subclass — not yet implemented.
    pass
|
| 109 |
+
|
| 110 |
+
class LocationAwareAttentionCell(nn.Module):
    """Location-aware additive attention score: the previous alignment is
    convolved (1-D over the sequence axis) and added as a third score term."""

    def __init__(self, kernel_size, kernel_dim, hidden_dim, input_dim):
        super().__init__()
        # 'Same'-length conv over the previous alignment (kernel 2k+1, pad k).
        self.loc_conv = nn.Conv1d(1, kernel_dim, kernel_size=2*kernel_size+1, padding=kernel_size, bias=True)
        self.loc_proj = nn.Linear(kernel_dim, hidden_dim)
        self.query_proj = nn.Linear(hidden_dim, hidden_dim)
        self.key_proj = nn.Linear(input_dim, hidden_dim)
        self.score = nn.Linear(hidden_dim, 1)

    def forward(self, decoder_hidden, encoder_output, last_alignment):
        batch_size = encoder_output.shape[0]
        seq_length = encoder_output.shape[1]
        hidden_dim = decoder_hidden[0].shape[1]

        keys = self.key_proj(encoder_output)
        query = self.query_proj(decoder_hidden[0]).unsqueeze(1)

        if last_alignment is None:
            # First decoding step: start from an all-zero alignment history.
            last_alignment = decoder_hidden[0].new_zeros(batch_size, seq_length, 1)

        # B x T x 1 -> conv over T -> B x T x kernel_dim -> project to hidden_dim.
        loc_context = self.loc_conv(last_alignment.permute(0, 2, 1))
        loc_context = self.loc_proj(loc_context.transpose(1, 2))

        assert len(loc_context.shape) == 3
        assert loc_context.shape[0] == batch_size, f'{loc_context.shape[0]}-{batch_size}'
        assert loc_context.shape[1] == seq_length
        assert loc_context.shape[2] == hidden_dim

        return self.score(torch.tanh(keys + query + loc_context))
|
| 143 |
+
|
| 144 |
+
class CoverageAttention(nn.Module):
    """Location-aware attention pooling returning only (context, alpha).

    The previous alignment is kept in ``self.prev_attn`` (set via set_mem,
    cleared via reset_mem) so the location convolution can condition on it.

    Args:
        input_dim: channel size of the encoder outputs.
        hidden_dim: decoder hidden size used by the score projections.
        kernel_size: half-width of the location convolution.
        kernel_dim: channel count of the location convolution.
        temperature: softmax temperature (>1 flattens the distribution).
        smoothing: if True, score via sigmoid + renormalisation over the
            sequence axis instead of softmax.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, kernel_dim, temperature=1.0, smoothing=False):
        super().__init__()
        self.smoothing = smoothing
        self.temperature = temperature
        self.prev_attn = None
        self.attn = LocationAwareAttentionCell(kernel_size, kernel_dim, hidden_dim, input_dim)

    def set_mem(self, prev_attn):
        self.prev_attn = prev_attn

    def reset_mem(self):
        self.prev_attn = None

    def forward(self, prev_hidden, batch_H):
        # Score every encoder position, conditioned on the previous alignment.
        e = self.attn(prev_hidden, batch_H, self.prev_attn)  # B x T x 1

        if self.smoothing:
            # BUGFIX: F.sigmoid takes no `dim` argument — the original call
            # raised TypeError. Also normalise over the sequence axis (dim=1);
            # the original divided by the trailing singleton-axis sum, which
            # reduces every weight to 1.
            e = torch.sigmoid(e)
            alpha = e / e.sum(dim=1, keepdim=True)
        else:
            e = e / self.temperature
            alpha = F.softmax(e, dim=1)

        context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1)  # batch_size x num_channel

        return context, alpha
|
| 171 |
+
|
| 172 |
+
class LocationAwareAttention(BahdanauAttention):
    """BahdanauAttention variant whose score also conditions on the previous
    alignment via a location convolution (Chorowski-style attention)."""

    def __init__(self, kernel_size, kernel_dim, temperature=1.0, smoothing=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.smoothing = smoothing
        self.temperature = temperature
        self.prev_attn = None
        # Replace the parent's additive cell with the location-aware one.
        self.attn = LocationAwareAttentionCell(kernel_size, kernel_dim, self.hidden_size, self.input_size)

    def set_mem(self, prev_attn):
        self.prev_attn = prev_attn

    def reset_mem(self):
        self.prev_attn = None

    def forward(self, prev_hidden, batch_H, embed_text):
        e = self.attn(prev_hidden, batch_H, self.prev_attn)  # B x T x 1

        if self.smoothing:
            # BUGFIX: F.sigmoid takes no `dim` argument — the original call
            # raised TypeError. Also normalise over the sequence axis (dim=1);
            # dividing by the trailing singleton-axis sum made alpha all ones.
            e = torch.sigmoid(e)
            alpha = e / e.sum(dim=1, keepdim=True)
        else:
            e = e / self.temperature
            alpha = F.softmax(e, dim=1)

        context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1)  # batch_size x num_channel , batch_H: batch_sizexseq_lengthxnum_channel, alpha:
        concat_context = torch.cat([context, embed_text], 1)  # batch_size x (num_channel + num_embedding)
        cur_hidden = self.rnn(concat_context, prev_hidden)
        output = self.generator(cur_hidden[0])

        return output, cur_hidden, alpha
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
# class MaskAttention(nn.Module):
|
| 205 |
+
# def __init__(self):
|
| 206 |
+
# super().__init__()
|
| 207 |
+
|
| 208 |
+
# class CoverageAttentionCell(nn.Module):
|
| 209 |
+
# def __init__(self, )
|
| 210 |
+
|
| 211 |
+
# class CoverageAttention(nn.Module):
|
| 212 |
+
# """
|
| 213 |
+
# http://home.ustc.edu.cn/~xysszjs/paper/PR2017.pdf
|
| 214 |
+
# """
|
| 215 |
+
# def __init__(self, input_size, hidden_size, num_embedding):
|
| 216 |
+
# super().__init__()
|
| 217 |
+
|
| 218 |
+
# def forward(self, prev_hidden, batch_H, embed_text):
|
HybridViT/module/component/prediction_head/addon_module/attention2D.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from .attention1D import LocationAwareAttention
|
| 4 |
+
"""
|
| 5 |
+
NOTE :
|
| 6 |
+
"""
|
| 7 |
+
class SARAttention(nn.Module):
    """SAR-style 2-D attention decoder step (Show, Attend and Read).

    Runs a two-layer LSTMCell decoder and scores a 2-D feature map against the
    decoder state with 1x1 / 3x3 convolutions.

    WARNING: ``forward`` is unfinished — it stops after computing the raw
    attention weights (the masking / context / output steps are still TODO)
    and therefore implicitly returns ``None``.
    """

    def __init__(self,
                 input_size,
                 attention_size,
                 backbone_size,
                 output_size,
                 ):
        # BUG FIX: super().__init__() was missing; without it, assigning the
        # nn.Conv2d / nn.LSTMCell attributes below raises
        # "AttributeError: cannot assign module before Module.__init__() call".
        super().__init__()
        # NOTE(review): conv1x1_1 is applied to hidden_2, whose channel count
        # is input_size, yet it is declared with in_channels=output_size —
        # this only works when output_size == input_size; confirm intent.
        self.conv1x1_1 = nn.Conv2d(output_size, attention_size, kernel_size=1, stride=1)
        self.conv3x3 = nn.Conv2d(backbone_size, attention_size, kernel_size=3, stride=1, padding=1)
        self.conv1x1_2 = nn.Conv2d(attention_size, 1, kernel_size=1, stride=1)

        self.rnn_decoder_1 = nn.LSTMCell(input_size, input_size)
        self.rnn_decoder_2 = nn.LSTMCell(input_size, input_size)

    def forward(
        self,
        dec_input,
        feature_map,
        holistic_feature,
        hidden_1,
        cell_1,
        hidden_2,
        cell_2
    ):
        """One decoding step.

        Args:
            dec_input: decoder input for this step, (B, input_size).
            feature_map: backbone features, (B, backbone_size, H, W).
            holistic_feature: currently unused — presumably the global image
                feature from the SAR encoder; TODO confirm and wire in.
            hidden_1, cell_1, hidden_2, cell_2: LSTMCell states.
        """
        _, _, H_feat, W_feat = feature_map.size()
        hidden_1, cell_1 = self.rnn_decoder_1(dec_input, (hidden_1, cell_1))
        hidden_2, cell_2 = self.rnn_decoder_2(hidden_1, (hidden_2, cell_2))

        # Broadcast the decoder state over the spatial grid as the query.
        hidden_2_tile = hidden_2.view(hidden_2.size(0), hidden_2.size(1), 1, 1)
        attn_query = self.conv1x1_1(hidden_2_tile)
        attn_query = attn_query.expand(-1, -1, H_feat, W_feat)

        attn_key = self.conv3x3(feature_map)
        attn_weight = torch.tanh(torch.add(attn_query, attn_key, alpha=1))
        attn_weight = self.conv1x1_2(attn_weight)  # shape B, 1, H, W

        # TODO: apply mask for attention weight, compute the attended context,
        # and return it — forward currently falls through and returns None.
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class LocationAwareAttentionCell2D(nn.Module):
    """Location-aware attention scorer over a 2-D encoder feature map.

    Scores every spatial position of the encoder output against the decoder
    hidden state, with the previous alignment map convolved in as location
    context. Returns a (B, H*W, 1) score tensor (spatial dims flattened).
    """

    def __init__(self, kernel_size, kernel_dim, hidden_dim, input_dim):
        super().__init__()
        # 2-D conv over the previous alignment map (single input channel).
        self.loc_conv = nn.Conv2d(1, kernel_dim, kernel_size=2*kernel_size+1, padding=kernel_size, bias=True)
        self.loc_proj = nn.Linear(kernel_dim, hidden_dim)
        self.query_proj = nn.Linear(hidden_dim, hidden_dim)
        self.key_proj = nn.Linear(input_dim, hidden_dim)
        self.score = nn.Linear(hidden_dim, 1)

    def forward(self, decoder_hidden, encoder_output, last_alignment):
        """Score encoder positions for one decoding step.

        Args:
            decoder_hidden: (h, c) tuple; h is (B, hidden_dim).
            encoder_output: channels-last feature map, (B, H, W, input_dim).
            last_alignment: previous alignment map (B, H, W, 1), or None on
                the first step (treated as all zeros).

        Returns:
            score: (B, H*W, 1) unnormalized attention scores.
        """
        batch_size, enc_h, enc_w, hidden_dim = encoder_output.shape[0], encoder_output.shape[1], encoder_output.shape[2], decoder_hidden[0].shape[1]

        # BUG FIX: the original left encoder_proj as a 4-D (B, H, W, hidden)
        # tensor while loc_context was flattened to (B, H*W, hidden), so the
        # sum below could not broadcast. Flatten the spatial dims up front.
        encoder_proj = self.key_proj(encoder_output).reshape(batch_size, enc_h * enc_w, hidden_dim)
        hidden_proj = self.query_proj(decoder_hidden[0]).unsqueeze(1)

        if last_alignment is None:
            last_alignment = decoder_hidden[0].new_zeros(batch_size, enc_h, enc_w, 1)

        # BUG FIX: the original applied permute(0, 2, 1) to a 4-D tensor
        # (invalid) and fed channels-last data to Conv2d. Move the alignment
        # channel into NCHW form for the conv, then back to channels-last.
        loc_context = self.loc_conv(last_alignment.permute(0, 3, 1, 2))  # (B, kernel_dim, H, W)
        loc_context = loc_context.permute(0, 2, 3, 1)                    # (B, H, W, kernel_dim)
        loc_context = self.loc_proj(loc_context)                         # (B, H, W, hidden_dim)

        # BUG FIX: the original asserts were self-contradictory
        # (len(shape) == 3 followed by shape[3] == hidden_dim).
        assert len(loc_context.shape) == 4
        assert loc_context.shape[0] == batch_size, f'{loc_context.shape[0]}-{batch_size}'
        assert loc_context.shape[1] == enc_h
        assert loc_context.shape[2] == enc_w
        assert loc_context.shape[3] == hidden_dim

        loc_context = loc_context.reshape(batch_size, enc_h*enc_w, hidden_dim)

        score = self.score(torch.tanh(
            encoder_proj
            + hidden_proj
            + loc_context
        ))
        return score
|
| 84 |
+
|
| 85 |
+
class LocationAwareAttention2D(LocationAwareAttention):
    """Location-aware attention whose scoring cell works on 2-D feature maps."""

    def __init__(self, kernel_size, kernel_dim, temperature=1, smoothing=False, *args, **kwargs):
        super().__init__(kernel_size, kernel_dim, temperature, smoothing, *args, **kwargs)
        # Swap the parent's 1-D cell for the 2-D variant; the sizes come from
        # the attributes the parent constructor has just populated.
        self.attn = LocationAwareAttentionCell2D(
            kernel_size, kernel_dim, self.hidden_size, self.input_size
        )
|
HybridViT/module/component/prediction_head/addon_module/position_encoding.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
__all__ = ['WordPosEnc']
|
| 5 |
+
|
| 6 |
+
class WordPosEnc(nn.Module):
    """Additive sinusoidal (Transformer-style) positional encoding.

    Precomputes a ``(max_len, d_model)`` table where even columns hold
    ``sin(pos * freq)`` and odd columns hold ``cos(pos * freq)``, and adds the
    first ``seq_len`` rows to the input in :meth:`forward`.
    """

    def __init__(
        self, d_model: int = 512, max_len: int = 500, temperature: float = 10000.0
    ) -> None:
        super().__init__()
        positions = torch.arange(0, max_len, dtype=torch.float)
        even_dims = torch.arange(0, d_model, 2, dtype=torch.float)
        # One frequency per even dimension: 1 / temperature^(2i / d_model).
        freqs = 1.0 / (temperature ** (even_dims / d_model))

        # Outer product: angles[p, i] = positions[p] * freqs[i].
        angles = positions[:, None] * freqs[None, :]

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = angles.sin()
        table[:, 1::2] = angles.cos()
        # Buffer (not a parameter): moves with .to()/.cuda(), not trained.
        self.register_buffer("pe", table)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encodings to ``x`` of shape (batch, seq_len, d_model)."""
        seq_len = x.size(1)
        return x + self.pe[:seq_len].unsqueeze(0)
|
HybridViT/module/component/prediction_head/seq2seq.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from einops import repeat
|
| 6 |
+
from ...converter import AttnLabelConverter as ATTN
|
| 7 |
+
from .addon_module import *
|
| 8 |
+
|
| 9 |
+
class Attention(nn.Module):
    """Attention-based sequence prediction head (seq2seq decoder).

    Wraps one of several attention-cell variants (Luong / Bahdanau /
    location-aware / coverage) and decodes either greedily (training and
    fast inference) or with beam search (batch-size-1 inference).
    """

    def __init__(self,
                 kernel_size,
                 kernel_dim,
                 input_size,
                 hidden_size,
                 num_classes,
                 embed_dim=None,
                 attn_type='coverage',
                 embed_target=False,
                 enc_init=False,  # init decoder hidden state from encoder output
                 teacher_forcing=1.0,
                 droprate=0.1,
                 method='concat',
                 seqmodel='ViT',
                 viz_attn: bool = False,
                 device='cuda'
                 ):
        super(Attention, self).__init__()
        if embed_dim is None: embed_dim = input_size
        if embed_target:
            # Learned character embedding; otherwise one-hot vectors are used.
            # NOTE(review): padding_idx is set to the START token id — confirm
            # that is intended (usually it is the PAD token).
            self.embedding = nn.Embedding(num_classes, embed_dim, padding_idx=ATTN.START())

        common = {
            'input_size': input_size,
            'hidden_size': hidden_size,
            'num_embeddings': embed_dim if embed_target else num_classes,
            'num_classes': num_classes
        }

        # 'loc_aware' and 'coverage' share the same cell class; they differ
        # only in what gets stored via set_mem() during decoding (last alpha
        # vs. cumulative alpha).
        if attn_type == 'luong':
            common['method'] = method
            self.attention_cell = LuongAttention(**common)
        elif attn_type == 'loc_aware':
            self.attention_cell = LocationAwareAttention(kernel_size=kernel_size, kernel_dim=kernel_dim, **common)
        elif attn_type == 'coverage':
            self.attention_cell = LocationAwareAttention(kernel_size=kernel_size, kernel_dim=kernel_dim, **common)
        else:
            self.attention_cell = BahdanauAttention(**common)

        self.dropout = nn.Dropout(droprate)
        self.embed_target = embed_target
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.num_classes = num_classes
        self.teacher_forcing = teacher_forcing
        self.device = device
        self.attn_type = attn_type
        self.enc_init = enc_init
        self.viz_attn = viz_attn
        self.seqmodel = seqmodel

        # Creates the encoder->decoder state projections only when needed.
        if enc_init: self.init_hidden()

    def _embed_text(self, input_char):
        """Embed previous character ids with the learned embedding table."""
        return self.embedding(input_char)

    def _char_to_onehot(self, input_char, onehot_dim=38):
        """One-hot encode previous character ids, shape (B, onehot_dim)."""
        input_char = input_char.unsqueeze(1)
        batch_size = input_char.size(0)
        one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().to(self.device)
        one_hot = one_hot.scatter_(1, input_char, 1)
        return one_hot

    def init_hidden(self):
        """Create linear projections mapping encoder features to (h0, c0)."""
        self.proj_init_h = nn.Linear(self.input_size, self.hidden_size, bias=True)
        self.proj_init_c = nn.Linear(self.input_size, self.hidden_size, bias=True)

    def forward_beam(
        self,
        batch_H: torch.Tensor,
        batch_max_length=25,
        beam_size=4,
    ):
        """Beam-search decoding; only supports batch size 1.

        Returns (seq, score, alphas) — alphas is None unless viz_attn is set.
        """
        batch_size = batch_H.size(0)
        assert batch_size == 1
        num_steps = batch_max_length + 1
        # Replicate the single encoder sequence across the beam dimension.
        batch_H = batch_H.squeeze(dim=0)
        batch_H = repeat(batch_H, "s e -> b s e", b = beam_size)

        if self.enc_init:
            # BiLSTM: initialize from the mean of encoder states;
            # otherwise: from the first token (e.g. a ViT CLS-like token).
            if self.seqmodel == 'BiLSTM':
                init_embedding = batch_H.mean(dim=1)
            else:
                init_embedding = batch_H[:, 0, :]
            h_0 = self.proj_init_h(init_embedding)
            c_0 = self.proj_init_c(init_embedding)
            hidden = (h_0, c_0)
        else:
            hidden = (torch.zeros(beam_size, self.hidden_size, dtype=torch.float32, device=self.device),
                      torch.zeros(beam_size, self.hidden_size, dtype=torch.float32, device=self.device))

        if self.attn_type == 'coverage':
            # Running sum of attention maps fed back to the cell as coverage.
            alpha_cum = torch.zeros(beam_size, batch_H.shape[1], 1, dtype=torch.float32, device=self.device)
        self.attention_cell.reset_mem()

        # Every beam starts from the [GO]/START token.
        k_prev_words = torch.LongTensor([[ATTN.START()]] * beam_size).to(self.device)
        seqs = k_prev_words
        targets = k_prev_words.squeeze(dim=-1)
        top_k_scores = torch.zeros(beam_size, 1).to(self.device)

        if self.viz_attn:
            seqs_alpha = torch.ones(beam_size, 1, batch_H.shape[1]).to(self.device)

        complete_seqs = list()
        if self.viz_attn:
            complete_seqs_alpha = list()
        complete_seqs_scores = list()

        for step in range(num_steps):
            embed_text = self._char_to_onehot(targets, onehot_dim=self.num_classes) if not self.embed_target else self._embed_text(targets)
            output, hidden, alpha = self.attention_cell(hidden, batch_H, embed_text)
            output = self.dropout(output)
            vocab_size = output.shape[1]

            # Accumulate log-probabilities along each beam.
            scores = F.log_softmax(output, dim=-1)
            scores = top_k_scores.expand_as(scores) + scores
            if step == 0:
                # All beams are identical at step 0; pick top-k from one row.
                top_k_scores, top_k_words = scores[0].topk(beam_size, 0, True, True)
            else:
                top_k_scores, top_k_words = scores.view(-1).topk(beam_size, 0, True, True)

            # Unflatten: which beam each candidate came from / which word it adds.
            prev_word_inds = top_k_words // vocab_size
            next_word_inds = top_k_words % vocab_size

            seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)
            if self.viz_attn:
                seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].permute(0, 2, 1)],
                                       dim=1)

            # Beams that did not just emit END keep going.
            incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if
                               next_word != ATTN.END()]

            complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))

            if len(complete_inds) > 0:
                complete_seqs.extend(seqs[complete_inds].tolist())
                if self.viz_attn:
                    complete_seqs_alpha.extend(seqs_alpha[complete_inds])
                complete_seqs_scores.extend(top_k_scores[complete_inds])

            # Shrink the active beam as hypotheses finish.
            beam_size = beam_size - len(complete_inds)
            if beam_size == 0:
                break

            # Reindex all per-beam state down to the surviving beams.
            seqs = seqs[incomplete_inds]
            if self.viz_attn:
                seqs_alpha = seqs_alpha[incomplete_inds]
            hidden = hidden[0][prev_word_inds[incomplete_inds]], \
                     hidden[1][prev_word_inds[incomplete_inds]]
            batch_H = batch_H[prev_word_inds[incomplete_inds]]
            top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
            targets = next_word_inds[incomplete_inds]

            if self.attn_type == 'coverage':
                alpha_cum = alpha_cum + alpha
                alpha_cum = alpha_cum[incomplete_inds]
                self.attention_cell.set_mem(alpha_cum)
            elif self.attn_type == 'loc_aware':
                self.attention_cell.set_mem(alpha)

        if len(complete_inds) == 0:
            # No hypothesis ever emitted END: fall back to the best live beam.
            seq = seqs[0][1:].tolist()
            seq = torch.LongTensor(seq).unsqueeze(0)
            score = top_k_scores[0]
            if self.viz_attn:
                alphas = seqs_alpha[0][1:, ...]
                return seq, score, alphas
            else:
                return seq, score, None
        else:
            # Pick the finished hypothesis with the best length-normalized score.
            combine_lst = tuple(zip(complete_seqs, complete_seqs_scores))
            best_ind = combine_lst.index(max(combine_lst, key=lambda x: x[1] / len(x[0])))  # https://youtu.be/XXtpJxZBa2c?t=2407
            seq = complete_seqs[best_ind][1:]  # does not include the [GO] token
            seq = torch.LongTensor(seq).unsqueeze(0)
            score = max(complete_seqs_scores)

            if self.viz_attn:
                alphas = complete_seqs_alpha[best_ind][1:, ...]
                return seq, score, alphas
            else:
                return seq, score, None

    def forward_greedy(self, batch_H, text, is_train=True, is_test=False, batch_max_length=25):
        """Greedy step-by-step decoding.

        During training, teacher forcing feeds ground-truth `text` with
        probability `self.teacher_forcing`; otherwise the argmax prediction
        is fed back. Returns (preds_index, probs, None).
        """
        batch_size = batch_H.size(0)
        num_steps = batch_max_length + 1
        if self.enc_init:
            if self.seqmodel == 'BiLSTM':
                init_embedding = batch_H.mean(dim=1)
                encoder_hidden = batch_H
            else:
                encoder_hidden = batch_H
                init_embedding = batch_H[:, 0, :]
            h_0 = self.proj_init_h(init_embedding)
            c_0 = self.proj_init_c(init_embedding)
            hidden = (h_0, c_0)
        else:
            encoder_hidden = batch_H
            hidden = (torch.zeros(batch_size, self.hidden_size, dtype=torch.float32, device=self.device),
                      torch.zeros(batch_size, self.hidden_size, dtype=torch.float32, device=self.device))

        targets = torch.zeros(batch_size, dtype=torch.long, device=self.device)  # [GO] token
        probs = torch.zeros(batch_size, num_steps, self.num_classes, dtype=torch.float32, device=self.device)

        if self.viz_attn:
            # Kept on self so callers can inspect attention after the call.
            self.alpha_stores = torch.zeros(batch_size, num_steps, encoder_hidden.shape[1], 1, dtype=torch.float32, device=self.device)
        if self.attn_type == 'coverage':
            alpha_cum = torch.zeros(batch_size, encoder_hidden.shape[1], 1, dtype=torch.float32, device=self.device)

        self.attention_cell.reset_mem()

        if is_test:
            # Per-sample flag: True once a sample has produced END.
            end_flag = torch.zeros(batch_size, dtype=torch.bool, device=self.device)

        for i in range(num_steps):
            embed_text = self._char_to_onehot(targets, onehot_dim=self.num_classes) if not self.embed_target else self._embed_text(targets)
            output, hidden, alpha = self.attention_cell(hidden, encoder_hidden, embed_text)
            output = self.dropout(output)
            if self.viz_attn:
                self.alpha_stores[:, i] = alpha
            if self.attn_type == 'coverage':
                alpha_cum = alpha_cum + alpha
                self.attention_cell.set_mem(alpha_cum)
            elif self.attn_type == 'loc_aware':
                self.attention_cell.set_mem(alpha)

            probs_step = output
            probs[:, i, :] = probs_step

            if i == num_steps - 1:
                break

            if is_train:
                # Scheduled sampling: feed model prediction when a uniform
                # draw exceeds the teacher-forcing ratio.
                if self.teacher_forcing < random.random():
                    _, next_input = probs_step.max(1)
                    targets = next_input
                else:
                    targets = text[:, i+1]
            else:
                _, next_input = probs_step.max(1)
                targets = next_input

                if is_test:
                    # Early exit once every sample in the batch has ended.
                    end_flag = end_flag | (next_input == ATTN.END())
                    if end_flag.all():
                        break

        _, preds_index = probs.max(2)

        return preds_index, probs, None  # batch_size x num_steps x num_classes

    def forward(self, beam_size, batch_H, text, batch_max_length, is_train=True, is_test=False):
        """Dispatch: greedy during training, beam search at inference when beam_size > 1."""
        if is_train:
            return self.forward_greedy(batch_H, text, is_train, is_test, batch_max_length)
        else:
            if beam_size > 1:
                return self.forward_beam(batch_H, batch_max_length, beam_size)
            else:
                return self.forward_greedy(batch_H, text, is_train, is_test, batch_max_length)
|
| 268 |
+
|
HybridViT/module/component/prediction_head/seq2seq_v2.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from einops import repeat
|
| 6 |
+
from ...converter import AttnLabelConverter as ATTN
|
| 7 |
+
from .addon_module import *
|
| 8 |
+
from .seq2seq import Attention
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class AttentionV2(Attention):
    """Variant of :class:`Attention` that adapts to the sequence model type.

    Differs from the parent by selecting which encoder states are attended:
    for 'TFM' the first token is reserved for decoder-state initialization and
    the rest form the attended memory; 'BiLSTM'/'VIG' use all states.
    """

    def forward_beam(
        self,
        batch_H: torch.Tensor,
        batch_max_length=25,
        beam_size=4,
    ):
        """Beam-search decoding; only supports batch size 1.

        Returns (seq, score, alphas) — alphas is None unless viz_attn is set.
        """
        batch_size = batch_H.size(0)
        assert batch_size == 1
        num_steps = batch_max_length + 1
        # Replicate the single encoder sequence across the beam dimension.
        batch_H = batch_H.squeeze(dim=0)
        batch_H = repeat(batch_H, "s e -> b s e", b = beam_size)

        # Choose the attended memory depending on the sequence model.
        encoder_hidden = None
        if self.seqmodel in ['BiLSTM', 'VIG']:
            encoder_hidden = batch_H
        elif self.seqmodel == 'TFM':
            # First token is reserved (used below for state init); attend the rest.
            encoder_hidden = batch_H[:, 1:, :]
        else:
            raise ValueError('seqmodel must be either BiLSTM or TFM option')

        if self.enc_init:
            init_embedding = None
            if self.seqmodel in ['BiLSTM', 'VIG']:
                init_embedding = batch_H.mean(dim=1)
            elif self.seqmodel == 'TFM':
                init_embedding = batch_H[:, 0, :]
            else:
                raise ValueError('seqmodel must be either BiLSTM or TFM option')

            assert init_embedding is not None
            h_0 = self.proj_init_h(init_embedding)
            c_0 = self.proj_init_c(init_embedding)
            hidden = (h_0, c_0)
        else:
            hidden = (torch.zeros(beam_size, self.hidden_size, dtype=torch.float32, device=self.device),
                      torch.zeros(beam_size, self.hidden_size, dtype=torch.float32, device=self.device))

        assert encoder_hidden is not None

        if self.attn_type == 'coverage':
            # Running sum of attention maps fed back to the cell as coverage.
            alpha_cum = torch.zeros(beam_size, encoder_hidden.shape[1], 1, dtype=torch.float32, device=self.device)
        self.attention_cell.reset_mem()

        # Every beam starts from the [GO]/START token.
        k_prev_words = torch.LongTensor([[ATTN.START()]] * beam_size).to(self.device)
        seqs = k_prev_words
        targets = k_prev_words.squeeze(dim=-1)
        top_k_scores = torch.zeros(beam_size, 1).to(self.device)

        if self.viz_attn:
            seqs_alpha = torch.ones(beam_size, 1, encoder_hidden.shape[1]).to(self.device)

        complete_seqs = list()
        if self.viz_attn:
            complete_seqs_alpha = list()
        complete_seqs_scores = list()

        for step in range(num_steps):
            embed_text = self._char_to_onehot(targets, onehot_dim=self.num_classes) if not self.embed_target else self._embed_text(targets)
            output, hidden, alpha = self.attention_cell(hidden, encoder_hidden, embed_text)
            output = self.dropout(output)
            vocab_size = output.shape[1]

            # Accumulate log-probabilities along each beam.
            scores = F.log_softmax(output, dim=-1)
            scores = top_k_scores.expand_as(scores) + scores
            if step == 0:
                # All beams are identical at step 0; pick top-k from one row.
                top_k_scores, top_k_words = scores[0].topk(beam_size, 0, True, True)
            else:
                top_k_scores, top_k_words = scores.view(-1).topk(beam_size, 0, True, True)

            # Unflatten: which beam each candidate came from / which word it adds.
            prev_word_inds = top_k_words // vocab_size
            next_word_inds = top_k_words % vocab_size

            seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)
            if self.viz_attn:
                seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].permute(0, 2, 1)],
                                       dim=1)

            # Beams that did not just emit END keep going.
            incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if
                               next_word != ATTN.END()]

            complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))

            if len(complete_inds) > 0:
                complete_seqs.extend(seqs[complete_inds].tolist())
                if self.viz_attn:
                    complete_seqs_alpha.extend(seqs_alpha[complete_inds])
                complete_seqs_scores.extend(top_k_scores[complete_inds])

            # Shrink the active beam as hypotheses finish.
            beam_size = beam_size - len(complete_inds)
            if beam_size == 0:
                break

            # Reindex all per-beam state down to the surviving beams.
            seqs = seqs[incomplete_inds]
            if self.viz_attn:
                seqs_alpha = seqs_alpha[incomplete_inds]
            hidden = hidden[0][prev_word_inds[incomplete_inds]], \
                     hidden[1][prev_word_inds[incomplete_inds]]
            encoder_hidden = encoder_hidden[prev_word_inds[incomplete_inds]]
            top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
            targets = next_word_inds[incomplete_inds]

            if self.attn_type == 'coverage':
                alpha_cum = alpha_cum + alpha
                alpha_cum = alpha_cum[incomplete_inds]
                self.attention_cell.set_mem(alpha_cum)
            elif self.attn_type == 'loc_aware':
                self.attention_cell.set_mem(alpha)

        if len(complete_inds) == 0:
            # No hypothesis ever emitted END: fall back to the best live beam.
            seq = seqs[0][1:].tolist()
            seq = torch.LongTensor(seq).unsqueeze(0)
            score = top_k_scores[0]
            if self.viz_attn:
                alphas = seqs_alpha[0][1:, ...]
                return seq, score, alphas
            else:
                return seq, score, None
        else:
            # Pick the finished hypothesis with the best length-normalized score.
            combine_lst = tuple(zip(complete_seqs, complete_seqs_scores))
            best_ind = combine_lst.index(max(combine_lst, key=lambda x: x[1] / len(x[0])))  # https://youtu.be/XXtpJxZBa2c?t=2407
            seq = complete_seqs[best_ind][1:]  # does not include the [GO] token
            seq = torch.LongTensor(seq).unsqueeze(0)
            score = max(complete_seqs_scores)

            if self.viz_attn:
                alphas = complete_seqs_alpha[best_ind][1:, ...]
                return seq, score, alphas
            else:
                return seq, score, None

    def forward_greedy(self, batch_H, text, is_train=True, is_test=False, batch_max_length=25):
        """Greedy decoding with seqmodel-dependent memory selection.

        Same contract as the parent: returns (preds_index, probs, None).
        """
        batch_size = batch_H.size(0)
        num_steps = batch_max_length + 1
        # Choose the attended memory depending on the sequence model.
        encoder_hidden = None
        if self.seqmodel in ['BiLSTM', 'VIG']:
            encoder_hidden = batch_H
        elif self.seqmodel == 'TFM':
            # First token is reserved (used below for state init); attend the rest.
            encoder_hidden = batch_H[:, 1:, :]
        else:
            raise ValueError('seqmodel must be either BiLSTM or TFM option')

        if self.enc_init:
            init_embedding = None
            if self.seqmodel in ['BiLSTM', 'VIG']:
                init_embedding = batch_H.mean(dim=1)
            elif self.seqmodel == 'TFM':
                init_embedding = batch_H[:, 0, :]
            else:
                raise ValueError('seqmodel must be either BiLSTM or TFM option')
            h_0 = self.proj_init_h(init_embedding)
            c_0 = self.proj_init_c(init_embedding)
            hidden = (h_0, c_0)
        else:
            hidden = (torch.zeros(batch_size, self.hidden_size, dtype=torch.float32, device=self.device),
                      torch.zeros(batch_size, self.hidden_size, dtype=torch.float32, device=self.device))

        targets = torch.zeros(batch_size, dtype=torch.long, device=self.device)  # [GO] token
        probs = torch.zeros(batch_size, num_steps, self.num_classes, dtype=torch.float32, device=self.device)

        assert encoder_hidden is not None

        if self.viz_attn:
            # Kept on self so callers can inspect attention after the call.
            self.alpha_stores = torch.zeros(batch_size, num_steps, encoder_hidden.shape[1], 1, dtype=torch.float32, device=self.device)
        if self.attn_type == 'coverage':
            alpha_cum = torch.zeros(batch_size, encoder_hidden.shape[1], 1, dtype=torch.float32, device=self.device)

        self.attention_cell.reset_mem()

        if is_test:
            # Per-sample flag: True once a sample has produced END.
            end_flag = torch.zeros(batch_size, dtype=torch.bool, device=self.device)

        for i in range(num_steps):
            embed_text = self._char_to_onehot(targets, onehot_dim=self.num_classes) if not self.embed_target else self._embed_text(targets)
            output, hidden, alpha = self.attention_cell(hidden, encoder_hidden, embed_text)
            output = self.dropout(output)
            if self.viz_attn:
                self.alpha_stores[:, i] = alpha
            if self.attn_type == 'coverage':
                alpha_cum = alpha_cum + alpha
                self.attention_cell.set_mem(alpha_cum)
            elif self.attn_type == 'loc_aware':
                self.attention_cell.set_mem(alpha)

            probs_step = output
            probs[:, i, :] = probs_step

            if i == num_steps - 1:
                break

            if is_train:
                # Scheduled sampling: feed model prediction when a uniform
                # draw exceeds the teacher-forcing ratio.
                if self.teacher_forcing < random.random():
                    _, next_input = probs_step.max(1)
                    targets = next_input
                else:
                    targets = text[:, i+1]
            else:
                _, next_input = probs_step.max(1)
                targets = next_input

                if is_test:
                    # Early exit once every sample in the batch has ended.
                    end_flag = end_flag | (next_input == ATTN.END())
                    if end_flag.all():
                        break

        _, preds_index = probs.max(2)

        return preds_index, probs, None  # batch_size x num_steps x num_classes
|
HybridViT/module/component/prediction_head/tfm.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from einops import rearrange, repeat
|
| 6 |
+
from torch import FloatTensor, LongTensor
|
| 7 |
+
from .addon_module import WordPosEnc
|
| 8 |
+
from ...converter.tfm_converter import TFMLabelConverter as TFM
|
| 9 |
+
from ....beam import Beam
|
| 10 |
+
|
| 11 |
+
def _build_transformer_decoder(
|
| 12 |
+
d_model: int,
|
| 13 |
+
nhead: int,
|
| 14 |
+
num_decoder_layers: int,
|
| 15 |
+
dim_feedforward: int,
|
| 16 |
+
dropout: float,
|
| 17 |
+
) -> nn.TransformerDecoder:
|
| 18 |
+
decoder_layer = nn.TransformerDecoderLayer(
|
| 19 |
+
d_model=d_model,
|
| 20 |
+
nhead=nhead,
|
| 21 |
+
dim_feedforward=dim_feedforward,
|
| 22 |
+
dropout=dropout,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
decoder = nn.TransformerDecoder(decoder_layer, num_decoder_layers)
|
| 26 |
+
|
| 27 |
+
for p in decoder.parameters():
|
| 28 |
+
if p.dim() > 1:
|
| 29 |
+
nn.init.xavier_uniform_(p)
|
| 30 |
+
|
| 31 |
+
return decoder
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class TransformerPrediction(nn.Module):
    """Autoregressive Transformer-decoder prediction head.

    Decodes class logits from encoder memory ``src`` either greedily
    (teacher-forced when ``self.training``) or via beam search, using the
    project's ``Beam`` helper and ``TFM`` label-converter special tokens.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        num_decoder_layers: int,
        dim_feedforward: int,
        dropout: float,
        num_classes: int,
        max_seq_len: int,
        padding_idx: int,
        device: str = 'cuda:1'  # NOTE(review): hard-coded GPU default — confirm callers always override
    ):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.padding_idx = padding_idx
        self.num_classes = num_classes
        self.device = device
        # Token embedding; PAD embeddings are zeroed / not trained.
        self.word_embed = nn.Embedding(
            num_classes, d_model, padding_idx=padding_idx
        )

        self.pos_enc = WordPosEnc(d_model=d_model)
        self.d_model = d_model
        self.model = _build_transformer_decoder(
            d_model=d_model,
            nhead=nhead,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )

        self.proj = nn.Linear(d_model, num_classes)
        # Beam state is mutated across decode steps; reset_beam() restores it.
        self.beam = Beam(
            ignore_w=TFM.PAD(),
            start_w=TFM.START(),
            stop_w=TFM.END(),
            max_len=self.max_seq_len,
            device=self.device
        )

    def reset_beam(self):
        """Discard accumulated hypotheses and start a fresh beam search."""
        self.beam = Beam(
            ignore_w=TFM.PAD(),
            start_w=TFM.START(),
            stop_w=TFM.END(),
            max_len=self.max_seq_len,
            device=self.device
        )

    def _build_attention_mask(self, length):
        """Return a (length, length) causal mask: 0 on/below diagonal, -inf above."""
        mask = torch.full(
            (length, length),
            fill_value=1,
            dtype=torch.bool,
            device=self.device
        )
        # Lower-triangular True = positions a query may attend to.
        mask = torch.triu(mask).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def _embedd_tgt(self, tgt: LongTensor, tgt_len: int):
        """Embed target ids, add positional encoding, build causal/pad masks.

        The padding mask is only produced in training mode; at inference the
        decoded prefix contains no PAD tokens.
        """
        tgt_mask = self._build_attention_mask(tgt_len)
        if self.training:
            tgt_pad_mask = tgt == self.padding_idx
        else: tgt_pad_mask = None
        tgt = self.word_embed(tgt)
        # Scale embeddings by sqrt(d_model) before positional encoding
        # (standard Transformer convention).
        tgt = self.pos_enc(tgt*math.sqrt(self.d_model))
        return tgt, tgt_mask, tgt_pad_mask

    def forward_greedy(
        self, src: FloatTensor, tgt: LongTensor, output_weight: bool = False, is_test: bool = False
    ) -> FloatTensor:
        """Teacher-forced decode (training) or step-wise greedy decode (eval).

        NOTE(review): ``output_weight`` is accepted but never used.
        Returns (preds_index, logits).
        """
        if self.training:
            _, l = tgt.size()
            tgt, tgt_mask, tgt_pad_mask = self._embedd_tgt(tgt, l)

            # nn.TransformerDecoder expects sequence-first layout.
            src = rearrange(src, "b t d -> t b d")
            tgt = rearrange(tgt, "b l d -> l b d")

            out = self.model(
                tgt=tgt,
                memory=src,
                tgt_mask=tgt_mask,
                tgt_key_padding_mask=tgt_pad_mask
            )

            out = rearrange(out, "l b d -> b l d")
            out = self.proj(out)
        else:
            # Greedy decoding: ``tgt`` is assumed to arrive pre-seeded with
            # start tokens by the caller — TODO confirm against call sites.
            out = None
            src = rearrange(src, "b t d -> t b d")

            # NOTE(review): after the rearrange above, src.shape[0] is the
            # *time* dimension, not batch — looks like a latent bug when
            # t != batch size; verify intended shape of end_flag.
            end_flag = torch.zeros(src.shape[0], dtype=torch.bool, device=self.device)

            for step in range(self.max_seq_len+1):
                b, l = tgt.size()
                emb_tgt, tgt_mask, tgt_pad_mask = self._embedd_tgt(tgt, l)
                emb_tgt = rearrange(emb_tgt, "b l d -> l b d")

                out = self.model(
                    tgt=emb_tgt,
                    memory=src,
                    tgt_mask=tgt_mask
                )

                out = rearrange(out, "l b d -> b l d")
                out = self.proj(out)
                probs = F.softmax(out, dim=-1)
                # Take argmax of the last position only and append it.
                next_text = torch.argmax(probs[:, -1:, :], dim=-1)
                tgt = torch.cat([tgt, next_text], dim=-1)

                # NOTE(review): (next_text == END) has shape (b, 1) and is
                # broadcast against end_flag — confirm this is intentional.
                end_flag = end_flag | (next_text == TFM.END())
                if end_flag.all() and is_test:
                    break

        _, preds_index = out.max(dim=2)
        return preds_index, out

    def forward_beam(self,
        src: torch.FloatTensor,
        beam_size: int
    ):
        """Beam-search decode for a single source sequence (batch size 1).

        Returns (best_sequence as 1 x L LongTensor, its raw score).
        """
        assert src.size(0) == 1, f'beam search should only have signle source, encounter with batch size: {src.size(0)}'
        out = None
        src = src.squeeze(0)

        for step in range(self.max_seq_len + 1):
            hypotheses = self.beam.hypotheses
            hyp_num = hypotheses.size(0)
            l = hypotheses.size(1)
            assert hyp_num <= beam_size, f"hyp_num: {hyp_num}, beam_size: {beam_size}"

            emb_tgt = self.word_embed(hypotheses)
            emb_tgt = self.pos_enc(emb_tgt*math.sqrt(self.d_model))
            tgt_mask = self._build_attention_mask(l)
            emb_tgt = rearrange(emb_tgt, "b l d -> l b d")

            # Replicate the single source across all live hypotheses.
            exp_src = repeat(src.squeeze(1), "s e -> s b e", b=hyp_num)

            out = self.model(
                tgt=emb_tgt,
                memory=exp_src,
                tgt_mask=tgt_mask
            )

            out = rearrange(out, "l b d -> b l d")
            out = self.proj(out)
            # Score only the current step's position.
            log_prob = F.log_softmax(out[:, step, :], dim=-1)
            new_hypotheses, new_hyp_scores = self.beam.advance(log_prob, step, beam_size=beam_size)

            if self.beam.done(beam_size):
                break

            self.beam.set_current_state(new_hypotheses)
            self.beam.set_current_score(new_hyp_scores)

        self.beam.set_hypothesis()
        # Length-normalised score selects the best completed hypothesis.
        best_hyp = max(self.beam.completed_hypotheses, key=lambda h: h.score / len(h))
        output = best_hyp.seq
        output = torch.LongTensor(output).unsqueeze(0)
        score = best_hyp.score

        return output, score

    def forward(self, beam_size, batch_H, text, is_test):
        """Dispatch: teacher-forced in training, else beam or greedy decode."""
        if self.training:
            return self.forward_greedy(batch_H, text)
        else:
            if beam_size > 1:
                return self.forward_beam(batch_H, beam_size)
            else:
                return self.forward_greedy(batch_H, text, is_test = is_test)
|
| 207 |
+
|
HybridViT/module/component/seq_modeling/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .bilstm import *
|
| 2 |
+
from .vit_encoder import *
|
HybridViT/module/component/seq_modeling/addon_module/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .patchembed import *
|
HybridViT/module/component/seq_modeling/addon_module/patchembed.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch
|
| 4 |
+
from torch.nn import functional as F
|
| 5 |
+
from timm.models.layers.helpers import to_2tuple
|
| 6 |
+
from typing import Tuple, Union, List
|
| 7 |
+
|
| 8 |
+
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and linearly embed them.

    Inputs whose spatial size is not an exact multiple of the patch size are
    zero-padded (bottom/right) up to the next multiple before projection.
    """

    def __init__(self, img_size=(224, 224), patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(img_size, tuple)
        patch_size = to_2tuple(patch_size)
        # Round each image dimension up to the next multiple of the patch size
        # (ceil-division written with integer arithmetic).
        rounded_h = -(-img_size[0] // patch_size[0]) * patch_size[0]
        rounded_w = -(-img_size[1] // patch_size[1]) * patch_size[1]
        self.img_size = (rounded_h, rounded_w)
        self.patch_size = patch_size
        self.grid_size = (rounded_h // patch_size[0], rounded_w // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        _, _, H, W = x.shape
        # Pad bottom/right so H and W become multiples of the patch size.
        pad_H = -(-H // self.patch_size[0]) * self.patch_size[0] - H
        pad_W = -(-W // self.patch_size[1]) * self.patch_size[1] - W
        x = F.pad(x, (0, pad_W, 0, pad_H))
        assert x.shape[2] % self.patch_size[0] == 0 and x.shape[3] % self.patch_size[1] == 0
        proj_x = self.proj(x).flatten(2).transpose(1, 2)  # B x num_patches x embed_dim
        resized = (x.shape[2] != self.img_size[0] or x.shape[3] != self.img_size[1])
        return proj_x, {'height': x.shape[2], 'width': x.shape[3]}, resized
|
| 34 |
+
|
| 35 |
+
class HybridEmbed(nn.Module):
    """CNN feature-map embedding.

    Runs the image through a CNN ``backbone``, pads the resulting feature map
    (bottom/right) so it divides evenly into ``patch_size`` windows, and
    projects each window to ``embed_dim`` with a strided convolution.

    forward() returns
        (tokens [B, num_patches, embed_dim],
         (pad_w, pad_h),
         {'height': H_f, 'width': W_f},
         resized_flag — True when the padded feature map differs from the
         size inferred at construction time).
    """

    def __init__(self, backbone, img_size: Tuple[int], patch_size: Union[List, int] = 16, feature_size=None, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        # Bug fix: the original signature used ``patch_size=Union[List, int]``
        # i.e. the typing object itself as the *default value*, which made the
        # default unusable (tuple(Union[...]) raises). Accept int or sequence.
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        else:
            patch_size = tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.backbone = backbone
        if feature_size is None:
            with torch.no_grad():
                # NOTE Most reliable way of determining output dims is to run forward pass
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
                if isinstance(o, (list, tuple)):
                    o = o[-1]  # last feature if backbone outputs list/tuple of features
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            # Same semantics as timm's to_2tuple: int -> (f, f), iterable -> tuple.
            feature_size = (feature_size, feature_size) if isinstance(feature_size, int) else tuple(feature_size)
            if hasattr(self.backbone, 'feature_info'):
                feature_dim = self.backbone.feature_info.channels()[-1]
            else:
                feature_dim = self.backbone.num_features

        assert feature_size[0] >= patch_size[0] and feature_size[1] >= patch_size[1]

        # Round the expected feature size up to a multiple of the patch size.
        div_h, mod_h = divmod(feature_size[0], patch_size[0])
        div_w, mod_w = divmod(feature_size[1], patch_size[1])
        self.feature_size = (patch_size[0] * (div_h + (1 if mod_h > 0 else 0)),
                             patch_size[1] * (div_w + (1 if mod_w > 0 else 0)))
        assert self.feature_size[0] % patch_size[0] == 0 and self.feature_size[1] % patch_size[1] == 0
        self.grid_size = (self.feature_size[0] // patch_size[0], self.feature_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.backbone(x)
        # Bug fix: unwrap list/tuple backbone outputs *before* touching
        # x.shape / F.pad — the original checked only after padding, where a
        # list-returning backbone would already have crashed.
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
        f_h, f_w = x.shape[2:]

        # Pad bottom/right so the feature map divides evenly into patches.
        div_h, mod_h = divmod(f_h, self.patch_size[0])
        div_w, mod_w = divmod(f_w, self.patch_size[1])
        pad_H = self.patch_size[0] * (div_h + (1 if mod_h > 0 else 0)) - f_h
        pad_W = self.patch_size[1] * (div_w + (1 if mod_w > 0 else 0)) - f_w
        x = F.pad(x, (0, pad_W, 0, pad_H))
        assert x.shape[2] % self.patch_size[0] == 0 and x.shape[3] % self.patch_size[1] == 0

        proj_x = self.proj(x).flatten(2).transpose(1, 2)  # B x num_patches x embed_dim
        return proj_x, (pad_W, pad_H), {'height': x.shape[2], 'width': x.shape[3]}, (x.shape[2] != self.feature_size[0] or x.shape[3] != self.feature_size[1])
|
| 98 |
+
|
| 99 |
+
class HybridEmbed1D(nn.Module):
    """ CNN Feature Map Embedding which using 1D embed patching
    from https://arxiv.org/pdf/2111.08314.pdf, which benefits for text recognition task.Check paper for more detail
    Extract feature map from CNN, flatten, project to embedding dim.

    Each feature-map row is projected with a shared Conv1d, then rows are
    averaged, yielding one token per horizontal window.
    """
    def __init__(self, backbone, img_size: Tuple[int], feature_size=None, patch_size=1, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        self.img_size = img_size
        self.backbone = backbone
        self.embed_dim = embed_dim
        if feature_size is None:
            with torch.no_grad():
                # NOTE Most reliable way of determining output dims is to run forward pass
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
                if isinstance(o, (list, tuple)):
                    o = o[-1]  # last feature if backbone outputs list/tuple of features
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            feature_size = to_2tuple(feature_size)
            if hasattr(self.backbone, 'feature_info'):
                feature_dim = self.backbone.feature_info.channels()[-1]
            else:
                feature_dim = self.backbone.num_features

        # patch_size is interpreted as a horizontal window width only.
        self.window_width = patch_size
        assert feature_size[1] >= self.window_width
        # Round the feature width up to a multiple of the window width.
        div_w, mod_w = divmod(feature_size[1], self.window_width)
        self.feature_size = (feature_size[0], self.window_width*(div_w + (1 if mod_w > 0 else 0)))
        assert self.feature_size[1] % self.window_width == 0
        self.grid_size = (1, self.feature_size[1] // self.window_width)
        self.num_patches = self.grid_size[1]
        self.proj = nn.Conv1d(feature_dim, embed_dim, kernel_size=self.window_width, stride=self.window_width, bias=True)

    def forward(self, x):
        batch_size = x.shape[0]
        x = self.backbone(x)
        f_h, f_w = x.shape[2:]
        assert f_w >= self.window_width

        # Right-pad so the width divides evenly into windows.
        div_w, mod_w = divmod(f_w, self.window_width)
        pad_W = self.window_width*(div_w + (1 if mod_w > 0 else 0)) - f_w

        x = F.pad(x, (0, pad_W))
        assert x.shape[3] % self.window_width == 0

        # NOTE(review): this unwrap is dead code here — a list/tuple output
        # would already have failed at x.shape[2:] above; should sit right
        # after the backbone call.
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features

        proj_x = torch.zeros(batch_size, self.embed_dim, f_h, x.shape[3]//self.window_width, device=x.device, dtype=x.dtype)

        # Apply the shared Conv1d to every feature-map row independently.
        # NOTE(review): a per-row Python loop — could presumably be batched
        # by folding the height dimension into the batch; verify before changing.
        for i in range(f_h):
            proj = self.proj(x[:, :, i, :])
            proj_x[:, :, i, :] = proj

        # Average over the height dimension, then move channels last.
        proj_x = proj_x.mean(dim=2).transpose(1, 2) #BCHW->BCW

        return proj_x, (pad_W, ), {'height': x.shape[2], 'width': x.shape[3]}, (x.shape[2] != self.feature_size[0] or x.shape[3] != self.feature_size[1])
|
HybridViT/module/component/seq_modeling/bilstm.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
__all__ = ['BiLSTM_Seq_Modeling', 'BidirectionalLSTM']
|
| 4 |
+
|
| 5 |
+
class BidirectionalLSTM(nn.Module):
    """One bidirectional LSTM layer followed by a linear projection."""

    def __init__(self, input_size, hidden_size, output_size):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)
        self.linear = nn.Linear(hidden_size * 2, output_size)

    def forward(self, input):
        """Map visual features [B, T, input_size] to contextual features
        [B, T, output_size]."""
        # Compact weights for efficient (multi-GPU) execution.
        self.rnn.flatten_parameters()
        hidden_seq, _ = self.rnn(input)   # [B, T, 2 * hidden_size]
        return self.linear(hidden_seq)    # [B, T, output_size]
|
| 20 |
+
|
| 21 |
+
class BiLSTM_Seq_Modeling(nn.Module):
    """Sequential stack of BidirectionalLSTM blocks.

    NOTE(review): the first and last layers are always created, so
    ``num_layers`` is effectively assumed to be >= 2 — confirm callers.
    """

    def __init__(self, num_layers, input_size, hidden_size, output_size):
        super(BiLSTM_Seq_Modeling, self).__init__()
        self.num_layers = num_layers
        # input -> hidden, (num_layers - 2) x hidden -> hidden, hidden -> output
        blocks = [BidirectionalLSTM(input_size, hidden_size, hidden_size)]
        blocks.extend(
            BidirectionalLSTM(hidden_size, hidden_size, hidden_size)
            for _ in range(num_layers - 2)
        )
        blocks.append(BidirectionalLSTM(hidden_size, hidden_size, output_size))
        self.lstm = nn.Sequential(*blocks)

    def forward(self, input):
        return self.lstm(input)
|
HybridViT/module/component/seq_modeling/vit/utils.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import math
|
| 3 |
+
import warnings
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
| 7 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
| 8 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
| 9 |
+
def norm_cdf(x):
|
| 10 |
+
# Computes standard normal cumulative distribution function
|
| 11 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
| 12 |
+
|
| 13 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
| 14 |
+
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
| 15 |
+
"The distribution of values may be incorrect.",
|
| 16 |
+
stacklevel=2)
|
| 17 |
+
|
| 18 |
+
with torch.no_grad():
|
| 19 |
+
# Values are generated by using a truncated uniform distribution and
|
| 20 |
+
# then using the inverse CDF for the normal distribution.
|
| 21 |
+
# Get upper and lower cdf values
|
| 22 |
+
l = norm_cdf((a - mean) / std)
|
| 23 |
+
u = norm_cdf((b - mean) / std)
|
| 24 |
+
|
| 25 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
| 26 |
+
# [2l-1, 2u-1].
|
| 27 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
| 28 |
+
|
| 29 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
| 30 |
+
# standard normal
|
| 31 |
+
tensor.erfinv_()
|
| 32 |
+
|
| 33 |
+
# Transform to proper mean, std
|
| 34 |
+
tensor.mul_(std * math.sqrt(2.))
|
| 35 |
+
tensor.add_(mean)
|
| 36 |
+
|
| 37 |
+
# Clamp to ensure it's in the proper range
|
| 38 |
+
tensor.clamp_(min=a, max=b)
|
| 39 |
+
return tensor
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""Fill ``tensor`` in place from a truncated normal distribution.

    Values are effectively drawn from :math:`\mathcal{N}(\text{mean},
    \text{std}^2)` with samples outside :math:`[a, b]` redrawn until they
    fall inside the bounds. Works best when ``a <= mean <= b``.

    Args:
        tensor: an n-dimensional ``torch.Tensor``
        mean: mean of the normal distribution
        std: standard deviation of the normal distribution
        a: minimum cutoff value
        b: maximum cutoff value

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)
|
HybridViT/module/component/seq_modeling/vit/vision_transformer.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from functools import partial
|
| 4 |
+
from collections import OrderedDict
|
| 5 |
+
from ...common import DropPath
|
| 6 |
+
from .utils import trunc_normal_
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Mlp(nn.Module):
    """Transformer feed-forward block: Linear -> act -> drop -> Linear -> drop."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        # Unspecified sizes fall back to the input width.
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class ConvFFN(nn.Module):
    """Placeholder for a convolutional feed-forward block.

    NOTE(review): not implemented — this class only forwards constructor
    arguments to ``nn.Module`` and defines no layers or ``forward``.
    """
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class Attention(nn.Module):
    """Multi-head self-attention over a sequence of token embeddings."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        # One fused projection, then split into per-head q, k, v.
        qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # torchscript-friendly unpacking

        # Scaled dot-product attention weights.
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        # Merge heads back into the channel dimension.
        ctx = torch.matmul(weights, v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(ctx))
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class Block(nn.Module):
    """Pre-norm Transformer encoder block: self-attention + MLP, each on a
    residual branch with optional stochastic depth (DropPath)."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = nn.Identity() if drop_path <= 0. else DropPath(drop_path)
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class VisionTransformer(nn.Module):
    """ Vision Transformer

    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
        https://arxiv.org/abs/2010.11929
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module
            norm_layer: (nn.Module): normalization layer
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

        # NOTE(review): patch_embed is a placeholder — presumably assigned by
        # a subclass or owner before forward() is called (forward_features
        # would raise on None otherwise). The 128 fallback for num_patches
        # fixes pos_embed's length when no embedder is set; confirm callers.
        self.patch_embed = None
        num_patches = getattr(self.patch_embed, 'num_patches', 128)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(embed_dim, representation_size)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()

        # Classifier head
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal init for Linear weights, zeros for biases;
        # LayerNorm gets weight=1 / bias=0.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameters that optimizers should exclude from weight decay."""
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        """Return the classification head module."""
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        """Replace the classification head for a new number of classes.

        NOTE(review): ``global_pool`` is accepted but unused here.
        """
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        """Embed patches, prepend the CLS token, run the Transformer blocks,
        and return the (pre-logits of the) CLS token representation."""
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        # Keep only the CLS token after the final LayerNorm.
        x = self.norm(x)[:, 0]
        x = self.pre_logits(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
|
HybridViT/module/component/seq_modeling/vit_encoder.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
from .vit.utils import trunc_normal_
|
| 5 |
+
from .vit.vision_transformer import VisionTransformer
|
| 6 |
+
from ..feature_extractor.clova_impl import ResNet_FeatureExtractor
|
| 7 |
+
from .addon_module import *
|
| 8 |
+
from ..common.mae_posembed import get_2d_sincos_pos_embed
|
| 9 |
+
|
| 10 |
+
__all__ = ['ViTEncoder', 'ViTEncoderV2', 'ViTEncoderV3', 'TRIGBaseEncoder', 'create_vit_modeling']
|
| 11 |
+
|
| 12 |
+
class ViTEncoder(VisionTransformer):
    '''
    Vision-Transformer encoder that returns the full token sequence (rather
    than only the cls token) together with the padding / size metadata
    produced by the patch-embedding layer. Supports a pure patch embedding
    or a CNN hybrid backbone, and DINO-style positional-embedding
    interpolation for inputs whose size differs from the training size.
    '''
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Replace the parent's patch embedding so that it also reports
        # pad/size information (see forward_features).
        if kwargs['hybrid_backbone'] is None:
            self.patch_embed = PatchEmbed(
                img_size=kwargs['img_size'],
                in_chans=kwargs['in_chans'],
                patch_size=kwargs['patch_size'],
                embed_dim=kwargs['embed_dim'],
            )
        else:
            self.patch_embed = HybridEmbed(
                backbone=kwargs['hybrid_backbone'],
                img_size=kwargs['img_size'],
                in_chans=kwargs['in_chans'],
                patch_size=kwargs['patch_size'],
                embed_dim=kwargs['embed_dim'],
            )
        num_patches = self.patch_embed.num_patches
        # +1 slot for the cls token
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches+1, kwargs['embed_dim']))
        self.emb_height = self.patch_embed.grid_size[0]
        self.emb_width = self.patch_embed.grid_size[1]
        trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)

    def reset_classifier(self, num_classes):
        """Replace the classification head; identity when num_classes == 0."""
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def interpolating_pos_embedding(self, embedding, height, width):
        """
        Bicubic-interpolate the learned positional embedding to match an
        input whose patch grid differs from the training grid.
        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        npatch = embedding.shape[1] - 1  # exclude the cls token
        N = self.pos_embed.shape[1] - 1
        if npatch == N and height == width:
            return self.pos_embed

        class_pos_embedding = self.pos_embed[:, 0]
        patch_pos_embedding = self.pos_embed[:, 1:]
        dim = self.pos_embed.shape[-1]

        h0 = height // self.patch_embed.patch_size[0]
        w0 = width // self.patch_embed.patch_size[1]
        # add a small number to avoid floating point error
        # https://github.com/facebookresearch/dino/issues/8

        h0 = h0 + 0.1
        w0 = w0 + 0.1

        patch_pos_embedding = nn.functional.interpolate(
            patch_pos_embedding.reshape(1, self.emb_height, self.emb_width, dim).permute(0, 3, 1, 2),
            scale_factor=(h0 / self.emb_height, w0 / self.emb_width),
            mode='bicubic',
            align_corners=False
        )
        # interpolation must land exactly on the target grid
        assert int(h0) == patch_pos_embedding.shape[-2] and int(w0) == patch_pos_embedding.shape[-1]
        patch_pos_embedding = patch_pos_embedding.permute(0, 2, 3, 1).view(1, -1, dim)
        class_pos_embedding = class_pos_embedding.unsqueeze(0)

        return torch.cat((class_pos_embedding, patch_pos_embedding), dim=1)

    def forward_features(self, x):
        """Embed, add (optionally interpolated) positions, run the blocks.

        Returns (tokens, pad_info, size) where tokens includes the cls token.
        """
        B, C, _, _ = x.shape

        # patch_embed here also returns padding info, original size and a
        # flag saying whether positional interpolation is required
        x, pad_info, size, interpolating_pos = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

        if interpolating_pos:
            x = x + self.interpolating_pos_embedding(x, size['height'], size['width'])
        else:
            x = x + self.pos_embed

        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)

        return x, pad_info, size
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class TRIGBaseEncoder(ViTEncoder):
    '''
    1-D (width-only) patching encoder following TRIG:
    https://arxiv.org/pdf/2111.08314.pdf
    '''
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Override the 2-D patch embedding built by ViTEncoder.__init__
        # with a 1-D hybrid embedding (single row of tokens).
        self.patch_embed = HybridEmbed1D(
            backbone=kwargs['hybrid_backbone'],
            img_size=kwargs['img_size'],
            in_chans=kwargs['in_chans'],
            patch_size=kwargs['patch_size'],
            embed_dim=kwargs['embed_dim'],
        )
        num_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches+1, kwargs['embed_dim']))
        self.emb_height = 1  # token grid is one row high
        self.emb_width = self.patch_embed.grid_size[1]
        trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)

    def interpolating_pos_embedding(self, embedding, height, width):
        """
        Interpolate the learned positional embedding along the width only.
        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        npatch = embedding.shape[1] - 1  # exclude the cls token
        N = self.pos_embed.shape[1] - 1
        if npatch == N and height == width:
            return self.pos_embed

        class_pos_embedding = self.pos_embed[:, 0]
        patch_pos_embedding = self.pos_embed[:, 1:]
        dim = self.pos_embed.shape[-1]

        w0 = width // self.patch_embed.window_width

        # add a small number to avoid floating point error
        # https://github.com/facebookresearch/dino/issues/8

        w0 = w0 + 0.1

        patch_pos_embedding = nn.functional.interpolate(
            patch_pos_embedding.reshape(1, self.emb_height, self.emb_width, dim).permute(0, 3, 1, 2),
            scale_factor=(1, w0 / self.emb_width),
            mode='bicubic',
            align_corners=False
        )

        assert int(w0) == patch_pos_embedding.shape[-1]
        patch_pos_embedding = patch_pos_embedding.permute(0, 2, 3, 1).view(1, -1, dim)
        class_pos_embedding = class_pos_embedding.unsqueeze(0)

        return torch.cat((class_pos_embedding, patch_pos_embedding), dim=1)

    def forward_features(self, x):
        """Embed to a 1-D token row, add positions, run transformer blocks.

        Returns (tokens, padinfo, size) with the cls token prepended.
        """
        B, _, _, _ = x.shape
        x, padinfo, size, interpolating_pos = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks

        x = torch.cat((cls_tokens, x), dim=1)  # cls_tokens is init_embedding in TRIG paper

        if interpolating_pos:
            x = x + self.interpolating_pos_embedding(x, size['height'], size['width'])
        else:
            x = x + self.pos_embed

        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)

        return x, padinfo, size
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class ViTEncoderV2(ViTEncoder):
    """ViT encoder that truncates the learned positional embedding to the
    actual token count instead of interpolating it."""

    def forward(self, x):
        batch, _, _, _ = x.shape

        tokens, pad_info, size, _ = self.patch_embed(x)
        _, num_tokens, *_ = tokens.shape
        cls_tok = self.cls_token.expand(batch, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        tokens = torch.cat((cls_tok, tokens), dim=1)

        # take only the leading slice of the positional table (cls + patches)
        tokens = tokens + self.pos_embed[:, :(num_tokens + 1)]
        tokens = self.pos_drop(tokens)

        for block in self.blocks:
            tokens = block(tokens)

        tokens = self.norm(tokens)

        return tokens, pad_info, size
|
| 195 |
+
|
| 196 |
+
class ViTEncoderV3(ViTEncoder):
    """ViT encoder with a fixed (non-trainable) 2-D sin-cos positional
    embedding in MAE style; positions are truncated, not interpolated."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Drop the learned positional embedding created by ViTEncoder and
        # replace it with a frozen buffer-like parameter.
        if hasattr(self, 'pos_embed'):
            del self.pos_embed
        num_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches+1, kwargs['embed_dim']), requires_grad=False)
        self.initialize_posembed()

    def initialize_posembed(self):
        """Fill ``pos_embed`` with a 2-D sin-cos table (with a cls-token slot)."""
        pos_embed = get_2d_sincos_pos_embed(
            self.pos_embed.shape[-1],
            self.patch_embed.grid_size[0],
            self.patch_embed.grid_size[1],
            cls_token=True
        )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

    def forward(self, x):
        """Embed, add truncated sin-cos positions, run transformer blocks.

        Returns (tokens, pad_info, size) with the cls token prepended.
        """
        B, _, _, _ = x.shape

        x, pad_info, size, _ = self.patch_embed(x)
        _, numpatches, *_ = x.shape

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

        # truncate the positional table to the actual sequence length
        x = x + self.pos_embed[:, :(numpatches + 1)]
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)

        return x, pad_info, size
|
| 232 |
+
|
| 233 |
+
def create_vit_modeling(opt):
    """Build the ViT-based sequence-modeling encoder described by ``opt``.

    Selects an optional ResNet hybrid backbone, then one of
    ViTEncoder / ViTEncoderV2 / ViTEncoderV3 / TRIGBaseEncoder depending on
    ``patching_style`` and the embedding options.

    Fix: ``backbone`` is now always initialized. The original code only
    assigned it inside ``if seq_modeling['backbone'] is not None:``, so a
    config with ``backbone: null`` crashed with UnboundLocalError at the
    ``hybrid_backbone=backbone`` call.
    """
    seq_modeling = opt['SequenceModeling']['params']
    backbone_cfg = seq_modeling['backbone']

    backbone = None
    if backbone_cfg is not None and backbone_cfg['name'] == 'resnet':
        param_kwargs = dict()
        if backbone_cfg.get('pretrained', None) is not None:
            param_kwargs['pretrained'] = backbone_cfg['pretrained']
        if backbone_cfg.get('weight_dir', None) is not None:
            param_kwargs['weight_dir'] = backbone_cfg['weight_dir']
        print('kwargs', param_kwargs)

        backbone = ResNet_FeatureExtractor(
            backbone_cfg['input_channel'],
            backbone_cfg['output_channel'],
            backbone_cfg['gcb'],
            **param_kwargs
        )
    # any other backbone name ('cnn', ...) means: no hybrid backbone

    # a fixed imgH pins the height while the width stays at its maximum
    max_dimension = (opt['imgH'], opt['max_dimension'][1]) if opt['imgH'] else opt['max_dimension']

    if seq_modeling['patching_style'] == '2d':
        if seq_modeling.get('fix_embed', False):
            encoder = ViTEncoderV3          # frozen sin-cos positions
        elif seq_modeling.get('interpolate_embed', True):
            encoder = ViTEncoder            # learned + interpolated positions
        else:
            encoder = ViTEncoderV2          # learned + truncated positions
    else:
        encoder = TRIGBaseEncoder           # 1-D patching (TRIG)

    encoder_seq_modeling = encoder(
        img_size=max_dimension,
        patch_size=seq_modeling['patch_size'],
        in_chans=seq_modeling['input_channel'],
        depth=seq_modeling['depth'],
        num_classes=0,
        embed_dim=seq_modeling['hidden_size'],
        num_heads=seq_modeling['num_heads'],
        hybrid_backbone=backbone
    )

    return encoder_seq_modeling
|
HybridViT/module/converter/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .builder import create_converter
|
| 2 |
+
from .attn_converter import AttnLabelConverter
|
| 3 |
+
from .tfm_converter import TFMLabelConverter
|
HybridViT/module/converter/attn_converter.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
class AttnLabelConverter(object):
    """ Convert between text-label and text-index for attention decoders.

    Index layout: [GO]=0, [s]=1, [UNK]=2, then the characters of
    ``character`` in order. Encoded rows look like
    ``[GO] c1 ... cn [s] [GO] [GO] ...`` (padded with [GO]).
    """

    list_token = ['[GO]', '[s]', '[UNK]']
    def __init__(self, character, device):
        # character: vocabulary symbols without the special tokens
        list_character = character
        self.character = AttnLabelConverter.list_token + list_character

        self.device = device
        self.dict = {}
        for i, char in enumerate(self.character):
            self.dict[char] = i
        self.ignore_idx = self.dict['[GO]']  # padding/ignore index for the loss

    @staticmethod
    def START() -> int:
        """Index of the [GO] (start) token."""
        return AttnLabelConverter.list_token.index('[GO]')

    @staticmethod
    def END() -> int:
        """Index of the [s] (end-of-sequence) token."""
        return AttnLabelConverter.list_token.index('[s]')

    @staticmethod
    def UNK() -> int:
        """Index of the [UNK] token."""
        return AttnLabelConverter.list_token.index('[UNK]')

    def encode(self, text, batch_max_length=25):
        """Encode a batch of strings into a (batch, batch_max_length+2) index
        tensor and a length tensor (chars + 1 for [s])."""
        length = [len(s) + 1 for s in text]  # +1 for [s] at end of sentence.
        # batch_max_length = max(length)  # this is not allowed for multi-gpu setting
        batch_max_length += 1
        # additional +1 for [GO] at first step. batch_text is padded with [GO] token after [s] token.
        batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(0)
        for i, t in enumerate(text):
            # Fix: the original rebound the loop iterable `text` here, which
            # only worked because enumerate() had already captured it.
            chars = list(t)

            if len(chars) > batch_max_length:
                chars = chars[:(batch_max_length-1)]

            chars.append('[s]')
            indices = [self.dict[char] if char in self.dict else self.dict['[UNK]'] for char in chars]

            batch_text[i][1:1 + len(indices)] = torch.LongTensor(indices)  # batch_text[:, 0] = [GO] token
        return (batch_text.to(self.device), torch.IntTensor(length).to(self.device))

    def decode(self, text_index, token_level='word'):
        """ convert text-index into text-label; 'word' joins tokens with spaces. """
        texts = []
        batch_size = text_index.shape[0]
        for index in range(batch_size):
            if token_level == 'word':
                text = ' '.join([self.character[i] for i in text_index[index, :]])
            else:
                text = ''.join([self.character[i] for i in text_index[index, :]])
            texts.append(text)
        return texts

    def detokenize(self, token_ids):
        """Convert token ids to lists of tokens, stopping at the [s] token."""
        b_toks = []
        for tok in token_ids:
            toks = []
            for i in tok:
                if self.character[i] == '[s]':
                    break
                toks.append(self.character[i])
            b_toks.append(toks)

        return b_toks
|
HybridViT/module/converter/builder.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .attn_converter import AttnLabelConverter
|
| 2 |
+
|
| 3 |
+
def create_converter(config, device):
    """Instantiate the label converter matching the configured prediction head.

    Attn-style heads get an ``AttnLabelConverter``; any other head falls back
    to ``TFMLabelConverter``, mirroring the dispatch done in
    ``recog_flow.MathRecognition``. Fix: the original implicitly returned
    ``None`` for non-Attn configurations.
    """
    if 'Attn' in config['Prediction']['name']:
        return AttnLabelConverter(config['character'], device)
    # Imported locally so the Attn-only path is unaffected by this module.
    from .tfm_converter import TFMLabelConverter
    return TFMLabelConverter(config['character'], device)
|
HybridViT/module/converter/tfm_converter.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class TFMLabelConverter(object):
    """ Convert between text-label and text-index for transformer decoders.

    Index layout: [PAD]=0, [GO]=1, [s]=2, [UNK]=3, then the characters of
    ``character`` in order. Encoded rows look like
    ``[GO] c1 ... cn [s] [PAD] [PAD] ...`` (padded with [PAD]).
    """

    list_token = ['[PAD]', '[GO]', '[s]', '[UNK]']
    def __init__(self, character, device):
        # character: vocabulary symbols without the special tokens
        list_character = character
        self.character = TFMLabelConverter.list_token + list_character

        self.device = device
        self.dict = {}
        for i, char in enumerate(self.character):
            self.dict[char] = i
        self.ignore_idx = self.dict['[PAD]']  # padding/ignore index for the loss

    @staticmethod
    def START() -> int:
        """Index of the [GO] (start) token."""
        return TFMLabelConverter.list_token.index('[GO]')

    @staticmethod
    def END() -> int:
        """Index of the [s] (end-of-sequence) token."""
        return TFMLabelConverter.list_token.index('[s]')

    @staticmethod
    def UNK() -> int:
        """Index of the [UNK] token."""
        return TFMLabelConverter.list_token.index('[UNK]')

    @staticmethod
    def PAD() -> int:
        """Index of the [PAD] token."""
        return TFMLabelConverter.list_token.index('[PAD]')

    def encode(self, text, batch_max_length=25):
        """Encode a batch of strings into a (batch, batch_max_length+2) index
        tensor (PAD-filled, [GO]-prefixed) and a length tensor (chars + 1)."""
        length = [len(s) + 1 for s in text]  # +1 for [s] at end of sentence.
        batch_max_length += 1
        batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(self.ignore_idx)
        for i, t in enumerate(text):
            # Fix: the original rebound the loop iterable `text` here, which
            # only worked because enumerate() had already captured it.
            chars = list(t)

            if len(chars) > batch_max_length:
                chars = chars[:(batch_max_length-1)]

            chars.append('[s]')
            indices = [self.dict[char] if char in self.dict else self.dict['[UNK]'] for char in chars]
            batch_text[i][0] = torch.LongTensor([self.dict['[GO]']])
            batch_text[i][1:1 + len(indices)] = torch.LongTensor(indices)  # batch_text[:, 0] = [GO] token
        return (batch_text.to(self.device), torch.IntTensor(length).to(self.device))

    def decode(self, text_index, token_level='word'):
        """ convert text-index into text-label; 'word' joins tokens with spaces. """
        texts = []
        batch_size = text_index.shape[0]
        for index in range(batch_size):
            if token_level == 'word':
                text = ' '.join([self.character[i] for i in text_index[index, :]])
            else:
                text = ''.join([self.character[i] for i in text_index[index, :]])
            texts.append(text)
        return texts

    def detokenize(self, token_ids):
        """Convert token ids to lists of tokens, stopping at the [s] token."""
        b_toks = []
        for tok in token_ids:
            toks = []
            for i in tok:
                if self.character[i] == '[s]':
                    break
                toks.append(self.character[i])
            b_toks.append(toks)

        return b_toks
|
| 75 |
+
|
| 76 |
+
if __name__ == '__main__':
    # Manual round-trip smoke test: encode a few Vietnamese address strings
    # and decode them back at character level.
    vocab = ['S', 'ố', ' ', '2', '5', '3', 'đ', 'ư', 'ờ', 'n', 'g', 'T', 'r', 'ầ', 'P', 'h', 'ú', ',', 'ị', 't', 'ấ', 'N', 'a', 'm', 'á', 'c', 'H', 'u', 'y', 'ệ', 'ả', 'i', 'D', 'ơ', '8', '9', 'Đ', 'B', 'ộ', 'L', 'ĩ', '6', 'Q', 'ậ', 'ì', 'ạ', 'ồ', 'C', 'í', 'M', '4', 'E', '/', 'K', 'p', '1', 'A', 'x', 'ặ', 'ễ', '0', 'â', 'à', 'ế', 'ừ', 'ê', '-', '7', 'o', 'V', 'ô', 'ã', 'G', 'ớ', 'Y', 'I', 'ề', 'ò', 'l', 'R', 'ỹ', 'ủ', 'X', "'", 'e', 'ắ', 'ổ', 'ằ', 'k', 's', '.', 'ợ', 'ù', 'ứ', 'ă', 'ỳ', 'ẵ', 'ý', 'ó', 'ẩ', 'ọ', 'J', 'ũ', 'ữ', 'ự', 'õ', 'ỉ', 'ỏ', 'v', 'd', 'Â', 'W', 'U', 'O', 'é', 'ở', 'ỷ', '(', ')', 'ử', 'è', 'ể', 'ụ', 'ỗ', 'F', 'q', 'ẻ', 'ỡ', 'b', 'ỵ', 'Ứ', '#', 'ẽ', 'Ô', 'Ê', 'Ơ', '+', 'z', 'Ấ', 'w', 'Z', '&', 'Á', '~', 'f', 'Ạ', 'Ắ', 'j', ':', 'Ă', '<', '>', 'ẹ', '_', 'À', 'Ị', 'Ư', 'Ễ']
    text = [
        "190B Trần Quang Khải, Phường Tân Định, Quận 1, TP Hồ Chí Minh",
        "164/2B, Quốc lộ 1A, Phường Lê Bình, Quận Cái Răng, Cần Thơ",
        "Cẩm Huy, Huyện Cẩm Xuyên, Hà Tĩnh"
    ]
    tfm_convert = TFMLabelConverter(vocab, 'cpu')
    texts, lengths = tfm_convert.encode(text, 70)
    print(texts)
    for text in texts:
        print('Encode', text)
        # decode expects a batch dimension, so lift each row to (1, L)
        text = text.unsqueeze(0)
        decode_text = tfm_convert.decode(text, 'char')
        print('Decode', decode_text)
|
HybridViT/recog_flow.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from typing import Any
|
| 3 |
+
from collections import OrderedDict
|
| 4 |
+
import re
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from torch.nn import functional as F
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from timm.models.resnetv2 import ResNetV2
|
| 11 |
+
|
| 12 |
+
from .recognizers.build_model import Model
|
| 13 |
+
from .module.converter import AttnLabelConverter, TFMLabelConverter
|
| 14 |
+
from .helper import resize
|
| 15 |
+
|
| 16 |
+
class MathRecognition(object):
    """End-to-end math-expression recognizer.

    Pipeline: preprocess PIL image -> recognition model -> decode indices to
    a string -> post-process LaTeX spacing.
    """

    def __init__(self, config, resizer):
        self.args = config
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.args["device"] = device
        self.device = device
        self._prepare_vocab()
        self.model = self._get_model()
        self.resizer = resizer

    def _mapping_ckpt(self, state_dict):
        """Rename legacy checkpoint keys to the current module layout.

        Transformation.* weights are dropped; FeatureExtraction /
        SequenceModeling / Prediction prefixes are moved under
        featextractor / seqmodeler / predicter respectively.
        """
        new_state_dict = OrderedDict()

        for name, param in state_dict.items():
            if name.startswith('Transformation'):
                continue
            elif name.startswith('FeatureExtraction'):
                new_name = name.replace('FeatureExtraction', 'featextractor.FeatureExtraction')
                new_state_dict[new_name] = param
            elif name.startswith('SequenceModeling'):
                new_name = name.replace('SequenceModeling', 'seqmodeler.SequenceModeling')
                new_state_dict[new_name] = param
            elif name.startswith('Prediction'):
                new_name = name.replace('Prediction', 'predicter.Prediction')
                new_state_dict[new_name] = param
            else:
                new_state_dict[name] = param

        return new_state_dict

    def _get_model(self):
        """Build the model, load remapped weights, move it to the device
        (DataParallel when more than one GPU is visible)."""
        model = Model(self.args)
        state_dict = torch.load(self.args["weight_path"], map_location='cpu')
        new_state_dict = self._mapping_ckpt(state_dict)
        model.load_state_dict(new_state_dict)
        model = model.eval()

        if self.device == 'cuda':
            num_gpu = torch.cuda.device_count()
            if num_gpu > 1:
                model = nn.DataParallel(model).to(self.device)
            else:
                model.to(self.device)

        return model

    def _prepare_vocab(self):
        """Extend the configured character list from the vocab file and build
        the matching label converter."""
        # Fix: force utf-8 — the vocab contains non-ASCII symbols and the
        # platform default encoding is not reliable. The redundant f.close()
        # inside the with-block was also removed.
        with open(self.args["vocab"], 'rt', encoding='utf-8') as f:
            for line in f:
                self.args["character"] += [line.rstrip()]

        if 'Attn' in self.args['Prediction']['name']:
            self.converter = AttnLabelConverter(self.args["character"], self.device)
        else:
            self.converter = TFMLabelConverter(self.args["character"], self.device)

        self.args["num_class"] = len(self.converter.character)

    def _preprocess(self, image: "Image.Image"):
        """Resize/normalize a PIL image into the model's input tensor."""
        img_tensor = resize(self.resizer, image, self.args)
        return img_tensor

    def _postprocess(self, s: str):
        """Normalize spacing in decoder output: collapse spaces inside
        \\operatorname/\\mathrm-style groups and inside h/vspace arguments."""
        # Fix: removed the unused `letter`/`noletter` locals from the original.
        text_reg = r'(\\(operatorname|mathrm|mathbf|mathsf|mathit|mathfrak|mathnormal)\s?\*? {.*?})'
        names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
        s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
        news = s

        for space in ["hspace", "vspace"]:
            match = re.finditer(space + " {(.*?)}", news)
            if match:
                new_l = ""
                last = 0
                for m in match:
                    new_l = new_l + news[last:m.start(1)] + m.group(1).replace(" ", "")
                    last = m.end(1)
                new_l = new_l + news[last:]
                news = new_l

        return news

    def __call__(self, image: "Image.Image", name=None, *arg: "Any", **kwargs):
        """Recognize one RGB PIL image and return the post-processed string."""
        assert image.mode == 'RGB', 'input image must be RGB image'
        with torch.no_grad():
            img_tensor = self._preprocess(image).to(self.device)
            text_for_pred = torch.LongTensor(1, self.args["batch_max_length"] + 1).fill_(0).to(self.device)
            preds_index, _, _ = self.model(img_tensor, text_for_pred, is_train=False, is_test=True)
            pred_str = self.converter.decode(preds_index, self.args.get('token_level', 'word'))[0]

        # cut everything after the end-of-sequence marker
        pred_EOS = pred_str.find('[s]')
        pred_str = pred_str[:pred_EOS]

        process_str = self._postprocess(pred_str)

        return process_str
|
HybridViT/recognizers/__init__.py
ADDED
|
File without changes
|
HybridViT/recognizers/build_feat.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from ..module.component.feature_extractor.clova_impl import VGG_FeatureExtractor, ResNet_FeatureExtractor
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class FeatExtractorBuilder(nn.Module):
    """Builds the visual feature-extraction stage from config.

    ``flow['Feat']`` selects 'VGG', 'ResNet' or 'None'; 'None' is only legal
    when the sequence modeler is a ViT/ViG variant that embeds raw pixels
    itself (the stage then degenerates to an identity mapping).
    """
    def __init__(self, flow: dict, config):
        super().__init__()
        self.config = config
        self.flow = flow
        self.feat_name = flow['Feat']

        if self.feat_name != 'None':
            # mean_height=True collapses the feature height by average pooling;
            # otherwise height slices are concatenated and linearly projected.
            mean_height = config['FeatureExtraction']['params'].pop('mean_height', True)

            if self.feat_name == 'VGG':
                self.FeatureExtraction = VGG_FeatureExtractor(**config['FeatureExtraction']['params'])
                self.FeatureExtraction_output = config['FeatureExtraction']['params']['output_channel']
            elif self.feat_name == 'ResNet':
                self.FeatureExtraction = ResNet_FeatureExtractor(**config['FeatureExtraction']['params'])
                self.FeatureExtraction_output = config['FeatureExtraction']['params']['output_channel']

            if mean_height:
                self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1
            else:
                # NOTE(review): the *3 factor assumes the extractor emits a
                # feature map of height 3 — confirm against the backbones.
                self.proj_feat = nn.Linear(self.FeatureExtraction_output*3, self.FeatureExtraction_output)
        else:
            if flow['Seq'] not in ['ViT', 'MS_ViT', 'MS_ViTV2', 'MS_ViTV3', 'ViG']:
                raise Exception('No FeatureExtraction module specified')
            else:
                self.FeatureExtraction = nn.Identity()

    def forward(self, input):
        """Extract visual features; reshape to a sequence only for BiLSTM flows."""
        visual_feature = self.FeatureExtraction(input)

        if self.flow['Seq'] in ['BiLSTM', 'BiLSTM_3L']:
            if hasattr(self, 'AdaptiveAvgPool'):
                visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2))  # [b, c, h, w] -> [b, w, c, 1]
                visual_feature = visual_feature.squeeze(3)
            else:
                visual_feature = visual_feature.permute(0, 3, 1, 2)
                visual_feature = visual_feature.flatten(start_dim=-2)  # [b, c, h, w] -> [b, w, c*h]
                visual_feature = self.proj_feat(visual_feature)

        return visual_feature
|
HybridViT/recognizers/build_model.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
from .build_feat import FeatExtractorBuilder
|
| 3 |
+
from .build_seq import SeqModelingBuilder
|
| 4 |
+
from .build_pred import PredictBuilder
|
| 5 |
+
|
| 6 |
+
class Model(nn.Module):
    """Three-stage recognizer: feature extraction -> sequence modeling ->
    prediction, each stage built from its config section."""

    def __init__(self, opt):
        super(Model, self).__init__()
        self.opt = opt

        stages = {
            'Feat': opt['FeatureExtraction']['name'],
            'Seq': opt['SequenceModeling']['name'],
            'Pred': opt['Prediction']['name'],
        }
        self.stages = stages
        # ViT-style sequence modelers embed patches themselves, so a separate
        # CNN feature extractor must be disabled.
        # Fix: idiomatic `in` instead of the explicit __contains__ call.
        if 'Vi' in stages['Seq']:
            assert stages['Feat'] == 'None'

        """ FeatureExtraction """
        self.featextractor = FeatExtractorBuilder(stages, opt)
        FeatureExtraction_output = getattr(self.featextractor, 'FeatureExtraction_output', None)

        """ Sequence modeling"""
        self.seqmodeler = SeqModelingBuilder(stages, opt, FeatureExtraction_output)
        SequenceModeling_output = getattr(self.seqmodeler, 'SequenceModeling_output', None)

        """ Prediction """
        self.predicter = PredictBuilder(stages, opt, SequenceModeling_output)

    def forward_encoder(self, input, *args, **kwargs):
        """Run feature extraction then sequence modeling.

        Returns (contextual_feature, output_shape, feat_pad).
        """
        visual_feature = self.featextractor(input)
        contextual_feature, output_shape, feat_pad = self.seqmodeler(visual_feature, *args, **kwargs)
        return contextual_feature, output_shape, feat_pad

    def forward_decoder(self, contextual_feature, text, is_train=True,
                        is_test=False, rtl_text=None
                        ):
        """Run the prediction head on the encoder output."""
        prediction, logits, decoder_attn, addition_outputs = self.predicter(contextual_feature, text, is_train, is_test, rtl_text)

        return prediction, logits, decoder_attn, addition_outputs

    def forward(self, input, text, is_train=True, is_test=False, rtl_text=None):
        """Full pass; also reshapes decoder attention back to the feature grid
        and records feature-shape/padding metadata in addition_outputs."""
        contextual_feature, output_shape, feat_pad = self.forward_encoder(input)
        prediction, logits, decoder_attn, addition_outputs = self.forward_decoder(
            contextual_feature, text=text, is_train=is_train, is_test=is_test, rtl_text=rtl_text
        )

        if decoder_attn is not None and output_shape is not None:
            # Attn + ViT attention carries a leading cls-token column; drop it
            # before reshaping to the 2-D grid.
            if self.stages['Pred'] == 'Attn' and self.stages['Seq'] == 'ViT':
                decoder_attn = decoder_attn[:, 1:]
                decoder_attn = decoder_attn.reshape(-1, output_shape[0], output_shape[1])

            addition_outputs.update(
                {
                    'decoder_attn': decoder_attn,
                    'feat_width': output_shape[0],
                    'feat_height': output_shape[1],
                    'feat_pad': feat_pad,
                }
            )

        return prediction, logits, addition_outputs
|
| 66 |
+
|
| 67 |
+
# Quick manual smoke test: build a Model from a training config and push a
# dummy grayscale image through it.
if __name__ == '__main__':
    import yaml
    import torch

    # NOTE(review): hard-coded developer-machine path; this block only runs on
    # the original author's environment.
    with open('/media/huynhtruc0309/DATA/Math_Expression/my_source/Math_Recognition/config/train/paper_experiments/report_paper/best_model_noaugment/experiment_1806.yaml', 'r') as f:
        config = yaml.safe_load(f)

    # Inject the runtime-only settings the builders expect.
    config['num_class'] = 499
    config['device'] = 'cpu'
    model = Model(config)

    # Dummy input: batch of one 32x224 single-channel image.
    a = torch.rand(1, 1, 32, 224)

    # NOTE(review): Model.forward also requires a `text` argument; calling it
    # with only the image raises TypeError — confirm intended usage.
    output = model(a)

    print('pred', output[0].shape)
|
HybridViT/recognizers/build_pred.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from ..module.component.prediction_head import (Attention,
|
| 4 |
+
AttentionV2,
|
| 5 |
+
TransformerPrediction
|
| 6 |
+
)
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class PredictBuilder(nn.Module):
    """Builds and dispatches the prediction head selected by ``flow['Pred']``.

    Supported heads: 'CTC' (a linear projection), 'Attn' (Attention),
    'Attnv2' / 'Multistage_Attn' (AttentionV2) and 'TFM'
    (TransformerPrediction).
    """

    def __init__(self, flow, config, SequenceModeling_output):
        """
        :param flow: stage-name dict with at least a 'Pred' key
        :param config: full model configuration; ``config['Prediction']['params']``
                       is mutated in place to carry num_classes / device
        :param SequenceModeling_output: channel width of the encoder output
        :raises ValueError: when ``flow['Pred']`` names an unknown head
        """
        super().__init__()
        self.flow = flow
        self.config = config
        pred_name = flow['Pred']

        if pred_name == 'CTC':
            self.Prediction = nn.Linear(SequenceModeling_output, config['num_class'])

        elif pred_name == 'Attn':
            config['Prediction']['params']['num_classes'] = config['num_class']
            config['Prediction']['params']['device'] = config['device']
            self.Prediction = Attention(
                **config['Prediction']['params']
            )
        # 'Attnv2' and 'Multistage_Attn' share the same AttentionV2 head.
        elif pred_name in ('Attnv2', 'Multistage_Attn'):
            config['Prediction']['params']['num_classes'] = config['num_class']
            config['Prediction']['params']['device'] = config['device']
            self.Prediction = AttentionV2(
                **config['Prediction']['params']
            )
        elif pred_name == 'TFM':
            config['Prediction']['params']['num_classes'] = config['num_class']
            config['Prediction']['params']['device'] = config['device']
            self.Prediction = TransformerPrediction(
                **config['Prediction']['params']
            )
        else:
            raise ValueError('Prediction name is not supported')

    def forward(self, contextual_feature, text, is_train=True, is_test=False, rtl_text=None):
        """Run the configured head.

        Returns (prediction, logits, decoder_attn, addition_outputs);
        ``logits`` / ``decoder_attn`` are None for heads that do not produce
        them (e.g. CTC).
        """
        beam_size = self.config.get('beam_size', 1)

        addition_outputs = {}
        decoder_attn = None
        # Fixed: logits was previously unbound on the CTC path, so returning it
        # raised NameError.
        logits = None

        if self.flow['Pred'] == 'CTC':
            prediction = self.Prediction(contextual_feature.contiguous())

        # Fixed: 'Multistage_Attn' was built in __init__ but never dispatched
        # here, leaving `prediction` unbound.
        elif self.flow['Pred'] in ['Attn', 'Attnv2', 'Multistage_Attn']:
            prediction, logits, decoder_attn = self.Prediction(beam_size, contextual_feature.contiguous(), text, is_train=is_train,
                                                               is_test=is_test, batch_max_length=self.config['batch_max_length'])

        elif self.flow['Pred'] == 'TFM':
            prediction, logits = self.Prediction(beam_size, contextual_feature.contiguous(), text, is_test)
            self.Prediction.reset_beam()

        return prediction, logits, decoder_attn, addition_outputs
|
HybridViT/recognizers/build_seq.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
from ..module.component.seq_modeling import BidirectionalLSTM, create_vit_modeling
|
| 3 |
+
from ..module.component.seq_modeling.bilstm import BiLSTM_Seq_Modeling
|
| 4 |
+
from ..module.component.common import GatedSum
|
| 5 |
+
from ..module.component.common import PositionalEncoding2D, PositionalEncoding1D
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class SeqModelingBuilder(nn.Module):
    """Builds the sequence-modeling stage (BiLSTM, 3-layer BiLSTM, ViT, or none).

    Exposes ``SequenceModeling_output`` (the channel width handed to the
    prediction head) when it is known at construction time.
    """

    def __init__(self, flow: dict, config, FeatureExtraction_output):
        """
        :param flow: stage-name dict with at least 'Seq' and 'Pred' keys
        :param config: full model configuration
        :param FeatureExtraction_output: channel width of the visual features
        """
        super().__init__()
        self.config = config
        self.flow = flow

        if flow['Seq'] == 'BiLSTM':
            hidden_size = config['SequenceModeling']['params']['hidden_size']
            use_pos_enc = config['SequenceModeling']['params'].get('pos_enc', False)

            if use_pos_enc:
                # Optional positional signal merged with the BiLSTM output via
                # a gated sum (see forward()).
                self.image_positional_encoder = PositionalEncoding1D(hidden_size)
                self.gated = GatedSum(hidden_size)

            self.SequenceModeling = nn.Sequential(
                BidirectionalLSTM(FeatureExtraction_output, hidden_size, hidden_size),
                BidirectionalLSTM(hidden_size, hidden_size, hidden_size))

            self.SequenceModeling_output = hidden_size

        elif flow['Seq'] == 'BiLSTM_3L':
            hidden_size = config['SequenceModeling']['params']['hidden_size']
            self.SequenceModeling = BiLSTM_Seq_Modeling(3, FeatureExtraction_output, hidden_size, hidden_size)

            self.SequenceModeling_output = hidden_size

        elif flow['Seq'] == 'ViT':
            assert config['max_dimension'] is not None, "ViT encoder require exact height or max height and max width"
            self.SequenceModeling = create_vit_modeling(config)
        else:
            print('No SequenceModeling module specified')
            if flow['Pred'] == 'TFM':
                # NOTE(review): this encoder is built but never applied in
                # forward(); confirm whether the TFM head applies it itself.
                self.image_positional_encoder = PositionalEncoding2D(FeatureExtraction_output)

            self.SequenceModeling_output = FeatureExtraction_output

    def forward(self, visual_feature, *args, **kwargs):
        """Apply the configured sequence modeler.

        Returns (contextual_feature, output_shape, pad_info); output_shape is
        always None here, and pad_info is only produced by the ViT branch.
        """
        output_shape = None
        pad_info = None

        if self.flow['Seq'] in ['BiLSTM', 'BiLSTM_3L']:
            contextual_feature = self.SequenceModeling(visual_feature)

            if hasattr(self, 'image_positional_encoder'):
                assert len(contextual_feature.shape) == 3
                # Encoder works in (seq, batch, dim); permute in and out.
                contextual_feature_1 = self.image_positional_encoder(visual_feature.permute(1, 0, 2))
                contextual_feature_1 = contextual_feature_1.permute(1, 0, 2)
                contextual_feature = self.gated(contextual_feature_1, contextual_feature)

        elif self.flow['Seq'] == 'ViT':
            contextual_feature, pad_info, _ = self.SequenceModeling(visual_feature)

        else:
            # Fixed: when no sequence modeler is configured ('None' stage),
            # `contextual_feature` was never assigned and the return below
            # raised NameError; pass the visual features through unchanged.
            contextual_feature = visual_feature

        return contextual_feature, output_shape, pad_info
|
HybridViT/resizer.py
ADDED
|
File without changes
|
ScanSSD/IOU_lib/BoundingBox.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .iou_utils import *
|
| 2 |
+
|
| 3 |
+
class BoundingBox:
    """A single ground-truth or detected bounding box attached to an image/page."""

    def __init__(self,
                 imageName,
                 classId,
                 x,
                 y,
                 w,
                 h,
                 typeCoordinates=CoordinatesType.Absolute,
                 imgSize=None,
                 bbType=BBType.GroundTruth,
                 classConfidence=None,
                 format=BBFormat.XYWH):
        """Constructor.
        Args:
            imageName: String representing the image name.
            classId: String value representing class id.
            x: Float value representing the X upper-left coordinate of the bounding box.
            y: Float value representing the Y upper-left coordinate of the bounding box.
            w: Float value representing the width bounding box.
            h: Float value representing the height bounding box.
            typeCoordinates: (optional) Enum (Relative or Absolute) represents if the bounding box
                coordinates (x,y,w,h) are absolute or relative to size of the image. Default:'Absolute'.
            imgSize: (optional) 2D vector (width, height)=>(int, int) represents the size of the
                image of the bounding box. If typeCoordinates is 'Relative', imgSize is required.
            bbType: (optional) Enum (Groundtruth or Detection) identifies if the bounding box
                represents a ground truth or a detection. If it is a detection, the classConfidence has
                to be informed.
            classConfidence: (optional) Float value representing the confidence of the detected
                class. If detectionType is Detection, classConfidence needs to be informed.
            format: (optional) Enum (BBFormat.XYWH or BBFormat.XYX2Y2) indicating the format of the
                coordinates of the bounding boxes. BBFormat.XYWH: <left> <top> <width> <height>
                BBFormat.XYX2Y2: <left> <top> <right> <bottom>.
        """
        self._imageName = imageName
        self._typeCoordinates = typeCoordinates
        if typeCoordinates == CoordinatesType.Relative and imgSize is None:
            raise IOError(
                'Parameter \'imgSize\' is required. It is necessary to inform the image size.')
        if bbType == BBType.Detected and classConfidence is None:
            raise IOError(
                'For bbType=\'Detection\', it is necessary to inform the classConfidence value.')

        self._classConfidence = classConfidence
        self._bbType = bbType
        self._classId = classId
        self._format = format

        # If relative coordinates, convert to absolute values
        # For relative coords: (x,y,w,h)=(X_center/img_width , Y_center/img_height)
        if (typeCoordinates == CoordinatesType.Relative):
            (self._x, self._y, self._w, self._h) = convertToAbsoluteValues(imgSize, (x, y, w, h))
            self._width_img = imgSize[0]
            self._height_img = imgSize[1]
            if format == BBFormat.XYWH:
                # convertToAbsoluteValues yields corner coordinates; recover
                # width/height from them.
                self._x2 = self._w
                self._y2 = self._h
                self._w = self._x2 - self._x
                self._h = self._y2 - self._y
            else:
                raise IOError(
                    'For relative coordinates, the format must be XYWH (x,y,width,height)')
        # For absolute coords: (x,y,w,h)=real bb coords
        else:
            self._x = x
            self._y = y
            if format == BBFormat.XYWH:
                self._w = w
                self._h = h
                self._x2 = self._x + self._w
                self._y2 = self._y + self._h
            else:  # format == BBFormat.XYX2Y2: <left> <top> <right> <bottom>.
                self._x2 = w
                self._y2 = h
                self._w = self._x2 - self._x
                self._h = self._y2 - self._y
        if imgSize is None:
            self._width_img = None
            self._height_img = None
        else:
            self._width_img = imgSize[0]
            self._height_img = imgSize[1]

    def __str__(self):
        return "{} {} {} {} {}".format(self.getImageName(), self._x, self._y, self._w, self._h)

    def getAbsoluteBoundingBox(self, format=BBFormat.XYWH):
        """Return the box in absolute coordinates, in the requested format."""
        if format == BBFormat.XYWH:
            return (self._x, self._y, self._w, self._h)
        elif format == BBFormat.XYX2Y2:
            return (self._x, self._y, self._x2, self._y2)

    def getRelativeBoundingBox(self, imgSize=None):
        """Return the box in relative coordinates.

        Uses ``imgSize`` when supplied, otherwise the image size stored at
        construction time.
        :raises IOError: when no image size is available from either source.
        """
        if imgSize is None and self._width_img is None and self._height_img is None:
            raise IOError(
                'Parameter \'imgSize\' is required. It is necessary to inform the image size.')
        # Fixed: the original branched on `imgSize is None` and then indexed
        # the None value; prefer the explicitly supplied size and fall back to
        # the stored one.
        if imgSize is not None:
            return convertToRelativeValues((imgSize[0], imgSize[1]),
                                           (self._x, self._y, self._w, self._h))
        else:
            return convertToRelativeValues((self._width_img, self._height_img),
                                           (self._x, self._y, self._w, self._h))

    def getImageName(self):
        return self._imageName

    def getConfidence(self):
        return self._classConfidence

    def getFormat(self):
        return self._format

    def getClassId(self):
        return self._classId

    def getImageSize(self):
        return (self._width_img, self._height_img)

    def getCoordinatesType(self):
        return self._typeCoordinates

    def getBBType(self):
        return self._bbType

    @staticmethod
    def compare(det1, det2):
        """Return True when both boxes agree on class, confidence, coordinates and image size."""
        det1BB = det1.getAbsoluteBoundingBox()
        det1ImgSize = det1.getImageSize()
        det2BB = det2.getAbsoluteBoundingBox()
        det2ImgSize = det2.getImageSize()

        # Fixed: compare confidences via the accessor (the original called a
        # non-existent `classConfidenc()` attribute) and compare the image
        # sizes of the TWO boxes (the original compared each against itself,
        # which was always true).
        if det1.getClassId() == det2.getClassId() and \
           det1.getConfidence() == det2.getConfidence() and \
           det1BB[0] == det2BB[0] and \
           det1BB[1] == det2BB[1] and \
           det1BB[2] == det2BB[2] and \
           det1BB[3] == det2BB[3] and \
           det1ImgSize[0] == det2ImgSize[0] and \
           det1ImgSize[1] == det2ImgSize[1]:
            return True
        return False

    @staticmethod
    def clone(boundingBox):
        """Return a new BoundingBox with the same state as ``boundingBox``."""
        absBB = boundingBox.getAbsoluteBoundingBox(format=BBFormat.XYWH)
        newBoundingBox = BoundingBox(
            boundingBox.getImageName(),
            boundingBox.getClassId(),
            absBB[0],
            absBB[1],
            absBB[2],
            absBB[3],
            typeCoordinates=boundingBox.getCoordinatesType(),
            imgSize=boundingBox.getImageSize(),
            bbType=boundingBox.getBBType(),
            classConfidence=boundingBox.getConfidence(),
            format=BBFormat.XYWH)
        return newBoundingBox
|
ScanSSD/IOU_lib/Evaluator.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
###########################################################################################
|
| 2 |
+
# #
|
| 3 |
+
# Evaluator class: Implements the most popular metrics for object detection #
|
| 4 |
+
# #
|
| 5 |
+
# Developed by: Rafael Padilla (rafael.padilla@smt.ufrj.br) #
|
| 6 |
+
# SMT - Signal Multimedia and Telecommunications Lab #
|
| 7 |
+
# COPPE - Universidade Federal do Rio de Janeiro #
|
| 8 |
+
# Last modification: Oct 9th 2018 #
|
| 9 |
+
###########################################################################################
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
import sys
|
| 13 |
+
from collections import Counter
|
| 14 |
+
|
| 15 |
+
import matplotlib.pyplot as plt
|
| 16 |
+
import numpy as np
|
| 17 |
+
|
| 18 |
+
from .BoundingBox import *
|
| 19 |
+
from .iou_utils import *
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class Evaluator:
    """Pairwise IoU computations between ground-truth and detected boxes."""

    @staticmethod
    def _getAllIOUs(reference, detections):
        """IoU of `reference` against every detection, sorted by IoU descending.

        Each element of the returned list is (iou, reference, detection).
        """
        ref_coords = reference.getAbsoluteBoundingBox(BBFormat.XYX2Y2)
        scored = [
            (Evaluator.iou(ref_coords, det.getAbsoluteBoundingBox(BBFormat.XYX2Y2)), reference, det)
            for det in detections
        ]
        # highest IoU first
        return sorted(scored, key=lambda item: item[0], reverse=True)

    @staticmethod
    def iou(boxA, boxB):
        """IoU of two (x1, y1, x2, y2) boxes; 0 when they do not overlap."""
        if not Evaluator._boxesIntersect(boxA, boxB):
            return 0
        overlap = Evaluator._getIntersectionArea(boxA, boxB)
        combined = Evaluator._getUnionAreas(boxA, boxB, interArea=overlap)
        ratio = overlap / combined
        assert ratio >= 0
        return ratio

    # boxA = (Ax1,Ay1,Ax2,Ay2)
    # boxB = (Bx1,By1,Bx2,By2)
    @staticmethod
    def _boxesIntersect(boxA, boxB):
        """True when the two (x1, y1, x2, y2) boxes overlap at all."""
        # Separating-axis test: any single condition rules out overlap.
        disjoint = (
            boxA[0] > boxB[2]      # A entirely right of B
            or boxB[0] > boxA[2]   # A entirely left of B
            or boxA[3] < boxB[1]   # A entirely above B
            or boxA[1] > boxB[3]   # A entirely below B
        )
        return not disjoint

    @staticmethod
    def _getIntersectionArea(boxA, boxB):
        """Area of the overlap region (inclusive pixel coordinates, hence +1)."""
        left = max(boxA[0], boxB[0])
        top = max(boxA[1], boxB[1])
        right = min(boxA[2], boxB[2])
        bottom = min(boxA[3], boxB[3])
        return (right - left + 1) * (bottom - top + 1)

    @staticmethod
    def _getUnionAreas(boxA, boxB, interArea=None):
        """Area of the union of the two boxes, as a float."""
        if interArea is None:
            interArea = Evaluator._getIntersectionArea(boxA, boxB)
        return float(Evaluator._getArea(boxA) + Evaluator._getArea(boxB) - interArea)

    @staticmethod
    def _getArea(box):
        """Inclusive-coordinate area of an (x1, y1, x2, y2) box."""
        return (box[2] - box[0] + 1) * (box[3] - box[1] + 1)
|
ScanSSD/IOU_lib/IOUevaluater.py
ADDED
|
@@ -0,0 +1,433 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from zipfile import ZipFile
|
| 2 |
+
import os
|
| 3 |
+
from .Evaluator import *
|
| 4 |
+
from utils import *
|
| 5 |
+
import copy
|
| 6 |
+
import argparse
|
| 7 |
+
import sys
|
| 8 |
+
import ntpath
|
| 9 |
+
#import cStringIO
|
| 10 |
+
from io import BytesIO
|
| 11 |
+
import shutil
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def read_file(filename, bboxes, flag):
    '''
    Parses the input .csv file into map where key as page number and value as a list of bounding box objects
    corresponding to each math region in the file.
    :param filename: .csv file containing math regions
    :param bboxes: Map<page_num, List<bboxes>>
    :param flag: prefix used to build a unique per-page name for every box
    :return:
    '''
    prev_page = -1
    counter = 1
    with open(filename, "r") as infile:
        for raw_line in infile:
            line = raw_line.replace("\n", "")
            if line.replace(' ', '') == '':
                continue  # ignore blank lines
            fields = line.split(",")
            page_id = float(fields[0])
            # restart the per-page box counter whenever the page changes
            if prev_page == -1:
                prev_page = page_id
            elif page_id != prev_page:
                counter = 1
                prev_page = page_id
            left = float(fields[1])
            top = float(fields[2])
            right = float(fields[3])
            bottom = float(fields[4])
            box = BoundingBox(
                flag + "_" + str(counter),
                1,
                left,
                top,
                right,
                bottom,
                CoordinatesType.Absolute, (200, 200),
                BBType.GroundTruth,
                format=BBFormat.XYX2Y2)
            counter += 1
            if page_id not in bboxes:
                bboxes[page_id] = []
            bboxes[page_id].append(box)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def extract_zipfile(zip_filename, target_dir):
    '''
    Extract zip file into the target directory
    :param zip_filename: full-file-path of the zip-file
    :param target_dir: target-dir to extract contents of zip-file
    :return:
    '''
    # `archive` instead of `zip` to avoid shadowing the builtin
    with ZipFile(zip_filename, 'r') as archive:
        # extracting all the files
        print('Extracting all the files now...')
        archive.extractall(target_dir)
        print('Done!')
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def create_doc_bboxes_map(dir_path, flag):
    '''
    Reads all files in directory path and returns a map containing bboxes for each page in each math
    file in directory.
    :param dir_path: full directory path containing math files
    :param flag: name prefix forwarded to read_file for labelling boxes
    :return: Map<PDF_name, Map<Page_number, List<BBoxes>>>
    '''
    pdf_bboxes_map = {}

    for filename in os.listdir(dir_path):
        full_filepath = os.path.join(dir_path, filename)
        filename_key = os.path.splitext(os.path.basename(full_filepath))[0]
        # Fixed: the hidden-file check must look at the file NAME, not the
        # joined path (the original tested `full_filepath.startswith(".")`,
        # which skipped every file whenever dir_path itself began with a dot).
        if filename.startswith(".") or not (full_filepath.endswith(".csv") or full_filepath.endswith(".math")):
            continue
        bboxes_map = {}

        if os.path.isdir(full_filepath):
            continue

        try:
            read_file(full_filepath, bboxes_map, flag)
        except Exception as e:
            # best-effort: report the unreadable file and carry on
            print('exception occurred in reading file', full_filepath, str(e))

        pdf_bboxes_map[filename_key] = copy.deepcopy(bboxes_map)

    return pdf_bboxes_map
|
| 104 |
+
|
| 105 |
+
def unique_values(input_dict):
    """Find ground-truth keys whose best-matching detection is shared.

    Each dict value is a tuple whose second element is the label of the
    best-matching detection.  Whenever that label has already been seen, the
    full group of ground-truth keys mapped to it is recorded (so a detection
    claimed by N ground truths contributes N-1 copies of the same group).
    :return: list of lists of conflicting ground-truth keys
    """
    seen_dets = []
    conflict_groups = []
    for entry in input_dict.values():
        best_det = entry[1]
        if best_det in seen_dets:
            matching_gts = [gt for gt, vals in input_dict.items() if vals[1] == best_det]
            conflict_groups.append(matching_gts)
        seen_dets.append(best_det)

    return conflict_groups
|
| 117 |
+
|
| 118 |
+
def generate_validpairs(pairs):
    """Break conflict groups down into unique 2-element pairs.

    Groups longer than two are decomposed into consecutive (tuple) pairs;
    groups of two or fewer are kept as-is (in their original container type).
    Duplicates are dropped while preserving first-seen order.
    """
    result = []
    for group in pairs:
        if len(group) > 2:
            for left, right in zip(group, group[1:]):
                candidate = (left, right)
                if candidate not in result:
                    result.append(candidate)
        elif group not in result:
            result.append(group)
    return result
|
| 128 |
+
|
| 129 |
+
def fix_preds(input_dict,keyPairs,thre):
    """Reassign or drop ground-truth entries that share a best-match detection.

    Each value of ``input_dict`` is assumed to be a 4-tuple
    ``(best_iou, best_det_label, sorted_ious, sorted_det_labels)`` as built in
    ``IoU_page_bboxes``.  For every conflicting pair of ground-truth keys, the
    entry with the lower best IoU is demoted to its second-best detection
    (``values[2][1]`` / ``values[3][1]``); if that fallback IoU falls below
    ``thre`` the entry is deleted instead.  The dict is mutated in place and
    also returned.

    :param input_dict: Map<gt_name, (iou, det_name, iou_list, det_list)>
    :param keyPairs: groups of gt keys that share the same best detection
    :param thre: minimum IoU for a match to count as correct
    :return: the (mutated) input_dict
    """

    # Normalise arbitrarily sized conflict groups into 2-element pairs.
    validPairs=generate_validpairs(keyPairs)

    for pair in validPairs:
        # Skip pairs whose members were already deleted by an earlier pair.
        if pair[0] not in list(input_dict.keys()) or pair[1] not in list(input_dict.keys()):
            continue
        val0=input_dict[pair[0]][0]
        val1=input_dict[pair[1]][0]
        if val0>=val1: # first GT keeps the detection: demote the second GT
            values=input_dict[pair[1]]
            newprob=values[2][1]  # second-best IoU for this GT box
            if newprob<thre:
                del input_dict[pair[1]]
                continue
            # Promote the second-best detection into the "best match" slot and
            # drop the head of both sorted lists.
            input_dict[pair[1]]=newprob,values[3][1],values[2][1:],values[3][1:]

        if val1>val0: # second GT keeps the detection: demote the first GT
            values=input_dict[pair[0]]
            newprob=values[2][1]  # second-best IoU for this GT box
            if newprob<thre:
                del input_dict[pair[0]]
                continue
            # Promote the second-best detection into the "best match" slot and
            # drop the head of both sorted lists.
            input_dict[pair[0]]=newprob,values[3][1],values[2][1:],values[3][1:]

    return input_dict
|
| 158 |
+
|
| 159 |
+
def find_uni_pred(input_dict, thre):
    """Iterate conflict resolution until every detection is claimed at most once.

    Repeatedly applies ``fix_preds`` (which mutates the dict in place) until
    ``unique_values`` reports no more ground truths sharing a detection.
    :return: the conflict-free prediction dict
    """
    conflicts = unique_values(input_dict)
    if not conflicts:
        # already one-to-one
        return input_dict

    resolved = input_dict
    while conflicts:
        resolved = fix_preds(resolved, conflicts, thre)
        conflicts = unique_values(resolved)

    return resolved
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def count_true_box(pred_dict, thre):
    """Count correct predictions at IoU threshold ``thre``.

    Removes entries whose best IoU is below the threshold, then enforces a
    one-to-one GT/detection mapping via ``find_uni_pred``.
    :return: (number of correct predictions, the filtered dict)
    """
    # iterate over a snapshot since we delete while filtering
    for gt_key in list(pred_dict):
        if pred_dict[gt_key][0] < thre:
            del pred_dict[gt_key]

    # enforce 1:1 GT-to-detection mapping
    final_dict = find_uni_pred(pred_dict, thre)

    return len(final_dict), final_dict
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def IoU_page_bboxes(gt_page_bboxes_map, det_page_bboxes_map, pdf_name, outdir=None):
    '''
    Takes two maps containing page level bounding boxes for ground truth and detections for same PDF filename and
    computes IoU for each BBox in a page in GT against all BBoxes in the same page in detections and returns them in
    decreasing value of IoU. In this way it computes IoU for all pages in map.

    :param gt_page_bboxes_map: Map<pageNum, List<bboxes>> for ground truth bboxes
    :param det_page_bboxes_map: Map<pageNum, List<bboxes>> for detection bboxes
    :param pdf_name: name of the PDF being evaluated (used in log/report filenames)
    :param outdir: optional directory; when given, a per-page "*_eval.txt" report is written
    :return: (correct at IoU>0.5, correct at IoU>0.75, total GT boxes, total detected boxes,
              Map<pageNum, detected-GT keys at 0.5>, Map<pageNum, detected-GT keys at 0.75>)
    '''
    evaluator = Evaluator()

    # running totals across all pages of this PDF
    correct_pred_coarse=0
    correct_pred_fine=0

    pdf_gt_boxes=0
    pdf_det_boxes=0

    coarse_keys = {}
    fine_keys = {}

    for page_num in gt_page_bboxes_map:
        if page_num not in det_page_bboxes_map:
            print('Detections not found for page', str(page_num + 1), ' in', pdf_name)
            continue
        gt_boxes = gt_page_bboxes_map[page_num]
        det_boxes = det_page_bboxes_map[page_num]

        pdf_gt_boxes+=len(gt_boxes)
        pdf_det_boxes+=len(det_boxes)

        # Build Map<gt_name, (best_iou, best_det, sorted_ious, sorted_dets)>
        pred_dict={}
        for gt_box in gt_boxes:
            ious = evaluator._getAllIOUs(gt_box, det_boxes)
            preds=[]
            labels=[]
            for i in range(len(ious)):
                preds.append(round(ious[i][0],2))
                labels.append(ious[i][2].getImageName())

            # _getAllIOUs sorts descending, so index 0 is the best match
            pred_dict[gt_box.getImageName()]=preds[0],labels[0],preds,labels

        # count_true_box mutates its argument, so each threshold gets a copy
        coarse,coarse_dict=count_true_box(copy.deepcopy(pred_dict),0.5)
        fine,fine_dict=count_true_box(copy.deepcopy(pred_dict),0.75)

        coarse_keys[page_num] = coarse_dict.keys()
        fine_keys[page_num] = fine_dict.keys()

        #count correct preds for coarse 0.5 and fine 0.75 in one page
        correct_pred_coarse= correct_pred_coarse+coarse
        correct_pred_fine= correct_pred_fine+fine
        # Optionally write a human-readable per-page evaluation report.
        if outdir:
            out_file = open(os.path.join(outdir,pdf_name.split(".csv")[0]+"_"+str(page_num)+"_eval.txt"), "w")
            out_file.write('#page num '+str(page_num)+", gt_box:"+str(len(gt_boxes))+
                           ", pred_box:"+str(len(det_boxes))+"\n")
            out_file.write('\n')
            out_file.write('#COARSE DETECTION (iou>0.5):\n#number of correct prediction:'+ str(coarse)+ '\n#correctly detected:'+
                           str(list(coarse_dict.keys()))+'\n')
            out_file.write('\n')
            out_file.write('#FINE DETECTION (iou>0.75):\n#number of correct prediction:'+ str(fine)+ '\n#correctly detected:'+
                           str(list(fine_dict.keys()))+'\n')
            out_file.write('\n')
            out_file.write('#Sorted IOU scores for each GT box:\n')
            for gt_box in gt_boxes:
                ious = evaluator._getAllIOUs(gt_box, det_boxes)
                out_file.write(gt_box.getImageName()+",")
                for i in range(len(ious)-1):
                    out_file.write("("+str(round(ious[i][0],2))+" "+ str(ious[i][2].getImageName())+"),")
                out_file.write( "("+str(round(ious[-1][0],2))+" "+ str(ious[-1][2].getImageName())+")\n" )
            out_file.close()

    return correct_pred_coarse, correct_pred_fine, pdf_gt_boxes, pdf_det_boxes, coarse_keys, fine_keys
|
| 259 |
+
|
| 260 |
+
def count_box(input_dict):
    """Return the total number of bounding boxes in a pdf -> page -> boxes map.

    :param input_dict: Map<pdf_name, Map<page_num, List<bboxes>>>
    :return: total count of boxes across every page of every pdf
    """
    return sum(
        len(page_boxes)
        for page_map in input_dict.values()
        for page_boxes in page_map.values()
    )
|
| 267 |
+
|
| 268 |
+
# Zip every uploading files
|
| 269 |
+
# Zip every uploading files
def archive_iou_txt(username, task_id, sub_id, userpath):
    """Zip the per-page IOU evaluation txt files for one submission.

    :param username: submitting user (kept for interface compatibility; unused here)
    :param task_id: task identifier, used in the archive file name
    :param sub_id: submission identifier, used in the archive file name
    :param userpath: user directory containing the 'iouEval_stats' folder
    """
    inputdir = os.path.join(userpath, 'iouEval_stats')

    if not os.path.exists(inputdir):
        # Bug fix: the original used `pass` here and fell through, so
        # shutil.make_archive then raised on the missing input directory.
        print('No txt file is generated for IOU evaluation')
        return

    dest_uploader = 'IOU_stats_archive'
    dest_uploader = os.path.join(userpath, dest_uploader)

    if not os.path.exists(dest_uploader):
        os.makedirs(dest_uploader)

    zip_file_name = '/' + task_id + '_' + sub_id
    shutil.make_archive(dest_uploader + zip_file_name, 'zip', inputdir)

    # return '/media/' + dest_uploader
|
| 287 |
+
|
| 288 |
+
def write_html(gtFile, resultsFile, info, scores, destFile):
    """Render the CROHME Task 3 formula-detection results as HTML.

    :param gtFile: ground-truth file path (only its basename is shown)
    :param resultsFile: detections file path (only its basename is shown)
    :param info: counts dict ('allGTbox', 'allDet', 'correctDet_c', 'correctDet_f')
    :param scores: score dict ('coarse_f/pre/rec', 'fine_f/pre/rec')
    :param destFile: writable file-like object receiving the HTML
    """
    emit = destFile.write

    # Page prologue and submitted-file summary.
    for fragment in (
        '<html>',
        '<head><link rel="stylesheet" href="//maxcdn.bootstrapcdn.com/font-awesome/4.3.0/css/font-awesome.min.css"><link href="/static/css/bootstrap.min.css" rel="stylesheet"></head>',
        '<body>',
        "<blockquote><b>CROHME 2019</b> <h1> Formula Detection Results ( TASK 3 )</h1><hr>",
        "<b>Submitted Files</b><ul><li><b>Output:</b> " + ntpath.basename(resultsFile) + "</li>",
        "<li><b>Ground-truth:</b> " + ntpath.basename(gtFile) + "</li></ul>",
    ):
        emit(fragment)

    if info['allGTbox'] == 0:
        sys.stderr.write("Error : no sample in this GT list !\n")
        exit(-1)

    # Box counts, then the coarse (IOU>0.5) and fine (IOU>0.75) sections.
    for fragment in (
        "<p><b> Number of ground truth bounding boxes: </b>" + str(info['allGTbox']) + "<br /><b> Number of detected bounding boxes: </b>" + str(info['allDet']),
        "<hr>",
        "<p><b> **** Coarse Detection Results (IOU>0.5) ****</b><br />",
        "<ul><li><b>" + str(scores['coarse_f']) + "</b> F-score</li>",
        "<li>" + str(scores['coarse_pre']) + " Precision</li>",
        "<li>" + str(scores['coarse_rec']) + " Recall</li></ul>",
        "<b>" + str(info['correctDet_c']) + "</b> Number of correctly detected bounding boxes</p>",
        "<hr>",
        "<p><b> **** Fine Detection Results (IOU>0.75) ****</b><br />",
        "<ul><li><b>" + str(scores['fine_f']) + "</b> F-score</li>",
        "<li>" + str(scores['fine_pre']) + " Precision</li>",
        "<li>" + str(scores['fine_rec']) + " Recall</li></ul>",
        "<b>" + str(info['correctDet_f']) + "</b> Number of correctly detected bounding boxes</p>",
        "<hr>",
        '</body>',
        '</html>',
    ):
        emit(fragment)
|
| 319 |
+
|
| 320 |
+
def pre_rec_calculate(count):
    """Compute precision/recall/F-score at the coarse (IOU>0.5) and fine
    (IOU>0.75) thresholds and print them.

    :param count: dict with keys 'allGTbox', 'allDet', 'correctDet_c', 'correctDet_f'
    :return: dict with keys 'fine_f', 'fine_pre', 'fine_rec',
             'coarse_f', 'coarse_pre', 'coarse_rec' (rounded to 4 decimals)
    """
    if count['allDet'] == 0:
        print('No detection boxes found')
        # Bug fix: return every score key, all zero. The original returned
        # only {'fine_f', 'coarse_f'}, which made downstream consumers
        # (write_html, per-pdf reporting) fail with KeyError.
        return {'fine_f': 0, 'fine_pre': 0, 'fine_rec': 0,
                'coarse_f': 0, 'coarse_pre': 0, 'coarse_rec': 0}

    # Fine threshold (IOU > 0.75).
    pre_f = count['correctDet_f'] / float(count['allDet'])
    recall_f = count['correctDet_f'] / float(count['allGTbox'])
    if pre_f == 0 and recall_f == 0:
        f_f = 0
    else:
        f_f = 2 * (pre_f * recall_f) / float(pre_f + recall_f)

    # Coarse threshold (IOU > 0.5).
    pre_c = count['correctDet_c'] / float(count['allDet'])
    recall_c = count['correctDet_c'] / float(count['allGTbox'])
    if pre_c == 0 and recall_c == 0:
        f_c = 0
    else:
        f_c = 2 * (pre_c * recall_c) / float(pre_c + recall_c)

    print('')
    print('**** coarse result : threshold: 0.5 *****')
    print(' f =', f_c, ' precision =', pre_c, ' recall =', recall_c)
    print('')
    print('**** fine result : threshold: 0.75 *****')
    print(' f =', f_f, ' precision =', pre_f, ' recall =', recall_f)

    return {'fine_f': round(f_f, 4), 'fine_pre': round(pre_f, 4), 'fine_rec': round(recall_f, 4),
            'coarse_f': round(f_c, 4), 'coarse_pre': round(pre_c, 4), 'coarse_rec': round(recall_c, 4)}
|
| 349 |
+
|
| 350 |
+
def IOUeval(ground_truth, detections, outdir=None): #,
    """Run the IOU-based detection evaluation over two box-listing files.

    Builds per-pdf, per-page bounding-box maps for ground truth and
    detections, accumulates correct detections at the coarse (IOU>0.5) and
    fine (IOU>0.75) thresholds, and prints per-pdf and global scores.

    :param ground_truth: path to the ground-truth file (passed to
        create_doc_bboxes_map with mode 'gt')
    :param detections: path to the detections file (mode 'det')
    :param outdir: optional directory for per-page IOU report files; it is
        wiped and recreated when given
    :return: (coarse F-score, fine F-score, detailed_detections) where
        detailed_detections maps pdf_name -> [coarse_keys, fine_keys]
    """

    # Global counters, keyed to match what pre_rec_calculate() expects.
    keys=['allGTbox','correctDet_c','correctDet_f','allDet']
    info=dict.fromkeys(keys,0)

    gt_file_name = ground_truth
    det_file_name = detections

    #TODO : Mahshad change it to user directory
    if outdir:
        #outdir='IOU_eval_stats'
        # Start from a clean output directory on every run.
        if os.path.exists(outdir):
            shutil.rmtree(outdir)
        os.makedirs(outdir)

    # Map<pdf_name, Map<page_num, List<bboxes>>> for both sides.
    gt_pdfs_bboxes_map = create_doc_bboxes_map(gt_file_name,'gt')
    det_pdfs_bboxes_map = create_doc_bboxes_map(det_file_name,'det')
    #count boxes
    all_gtbox=count_box(gt_pdfs_bboxes_map)
    all_detbox=count_box(det_pdfs_bboxes_map)


    info['allGTbox']=all_gtbox
    info['allDet']=all_detbox

    # NOTE(review): pdf_gt_bbs / pdf_dt_bbs are never updated or read below.
    pdf_gt_bbs = 0
    pdf_dt_bbs = 0
    # Per-pdf counters (dict reused across iterations) and per-pdf scores.
    pdf_info = {}
    pdf_calcs = {}

    # pdf_name -> [coarse_keys, fine_keys] of correctly detected GT boxes.
    detailed_detections = {}

    for pdf_name in gt_pdfs_bboxes_map:
        # Skip pdfs with no detections at all; their GT boxes still count
        # in info['allGTbox'], so they lower recall.
        if pdf_name not in det_pdfs_bboxes_map:
            print('Detections not found for ',pdf_name)
            continue

        det_page_bboxes_map = det_pdfs_bboxes_map[pdf_name]
        gt_page_bboxes_map = gt_pdfs_bboxes_map[pdf_name]

        # Per-page IOU matching; also writes per-page report files if outdir.
        coarse_true_det,fine_true_det,pdf_gt_boxes,pdf_det_boxes,coarse_keys,fine_keys=\
        IoU_page_bboxes(gt_page_bboxes_map, det_page_bboxes_map, pdf_name,outdir)
        info['correctDet_c']=info['correctDet_c']+coarse_true_det
        info['correctDet_f']=info['correctDet_f']+fine_true_det

        pdf_info['correctDet_c']=coarse_true_det
        pdf_info['correctDet_f']=fine_true_det
        pdf_info['allGTbox']=pdf_gt_boxes
        pdf_info['allDet']=pdf_det_boxes

        print('For pdf: ', pdf_name)
        pdf_calcs[pdf_name]=pre_rec_calculate(pdf_info)
        detailed_detections[pdf_name] = [coarse_keys, fine_keys]
        #print('Pdf score:',pdf_name, " --> ", pre_rec_calculate(pdf_info))

    print('\n')
    print(info)
    # Global scores across every pdf/page.
    scores=pre_rec_calculate(info)

    print('\n PDF Level \n')
    #print(pdf_calcs)

    #{'fine_f': 0.7843, 'fine_pre': 0.7774, 'fine_rec': 0.7914, 'coarse_f': 0.902, 'coarse_pre': 0.894, 'coarse_rec': 0.9101}
    # Tab-separated per-pdf coarse/fine F-scores.
    for pdf_name in pdf_calcs:
        print(pdf_name,'\t', pdf_calcs[pdf_name]['coarse_f'],'\t',pdf_calcs[pdf_name]['fine_f'])

    #return corase and fine F-scores
    return scores['coarse_f'],scores['fine_f'], detailed_detections
|
| 418 |
+
|
| 419 |
+
if __name__=='__main__':

    # Command-line entry point: evaluate a detections file against a
    # ground-truth file and write per-page IOU reports.
    parser = argparse.ArgumentParser()
    parser.add_argument("--detections", type=str, required=True, help="detections file path")
    parser.add_argument("--ground_truth", type=str, required=True, help="ground_truth file path")
    args = parser.parse_args()

    gt_file_name = args.ground_truth
    det_file_name = args.detections

    # Bug fix: IOUeval returns THREE values (coarse F, fine F, detailed
    # detections map); the original two-target unpacking raised ValueError.
    c_f, f_f, detailed_detections = IOUeval(gt_file_name, det_file_name, outdir='IOU_scores_pages/')
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
|
ScanSSD/IOU_lib/__init__.py
ADDED
|
File without changes
|
ScanSSD/IOU_lib/iou_utils.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
|
| 3 |
+
import cv2
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class CoordinatesType(Enum):
    """
    Whether bounding-box coordinates are relative to the image size
    (see convertToRelativeValues) or absolute pixel values.

    Developed by: Rafael Padilla
    Last modification: Apr 28 2018
    """
    Relative = 1
    Absolute = 2
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class BBType(Enum):
    """
    Whether a bounding box comes from the ground truth or from a detector.

    Developed by: Rafael Padilla
    Last modification: May 24 2018
    """
    GroundTruth = 1
    Detected = 2
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class BBFormat(Enum):
    """
    Storage format of a bounding box: top-left corner plus size
    (X, Y, width, height) => XYWH, or corner pair (X1, Y1, X2, Y2) => XYX2Y2.

    Developed by: Rafael Padilla
    Last modification: May 24 2018
    """
    XYWH = 1
    XYX2Y2 = 2
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# size => (width, height) of the image
|
| 43 |
+
# box => (X1, X2, Y1, Y2) of the bounding box
|
| 44 |
+
def convertToRelativeValues(size, box):
    """Convert an absolute (X1, X2, Y1, Y2) box to relative
    (center_x, center_y, width, height) normalised by the image size.

    :param size: (image_width, image_height)
    :param box: (X1, X2, Y1, Y2) in absolute coordinates
    :return: (x, y, w, h) where x,y are the normalised box center and
             w,h the normalised box width/height
    """
    img_w, img_h = size
    x1, x2, y1, y2 = box
    # Multiply by the reciprocals (rather than dividing) so the floating
    # point results match the reference implementation bit-for-bit.
    inv_w = 1. / img_w
    inv_h = 1. / img_h
    center_x = (x2 + x1) / 2.0
    center_y = (y2 + y1) / 2.0
    return (center_x * inv_w,
            center_y * inv_h,
            (x2 - x1) * inv_w,
            (y2 - y1) * inv_h)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# size => (width, height) of the image
|
| 62 |
+
# box => (centerX, centerY, w, h) of the bounding box relative to the image
|
| 63 |
+
def convertToAbsoluteValues(size, box):
    """Convert a relative (center_x, center_y, w, h) box back to absolute
    (X1, Y1, X2, Y2) coordinates, clamped to the image bounds.

    :param size: (image_width, image_height)
    :param box: (center_x, center_y, w, h) relative to the image
    :return: (xIn, yIn, xEnd, yEnd) integer corner coordinates
    """
    img_w, img_h = size
    cx, cy, rel_w, rel_h = (float(v) for v in box)

    # Top-left corner; may be negative before clamping.
    x_min = round((2 * cx - rel_w) * img_w / 2)
    y_min = round((2 * cy - rel_h) * img_h / 2)
    # Bottom-right corner is deliberately derived from the UNclamped
    # top-left, so a box sticking out on the left/top keeps its size here.
    x_max = x_min + round(rel_w * img_w)
    y_max = y_min + round(rel_h * img_h)

    # Clamp every corner into the image.
    if x_min < 0:
        x_min = 0
    if y_min < 0:
        y_min = 0
    if x_max >= img_w:
        x_max = img_w - 1
    if y_max >= img_h:
        y_max = img_h - 1
    return (x_min, y_min, x_max, y_max)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def add_bb_into_image(image, bb, color=(255, 0, 0), thickness=2, label=None):
    """Draw a bounding box (and an optional label) onto an image in place.

    :param image: OpenCV image (array) to draw on
    :param bb: bounding box object exposing getAbsoluteBoundingBox()
    :param color: (R, G, B) box colour
    :param thickness: rectangle line thickness
    :param label: optional text rendered near the top-left of the box
    :return: the annotated image
    """
    red = int(color[0])
    green = int(color[1])
    blue = int(color[2])

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    font_thickness = 1

    left, top, right, bottom = map(int, bb.getAbsoluteBoundingBox(BBFormat.XYX2Y2))
    # cv2 expects BGR channel order.
    cv2.rectangle(image, (left, top), (right, bottom), (blue, green, red), thickness)

    if label is not None:
        # Size of the rendered label text.
        (text_w, text_h) = cv2.getTextSize(label, font, font_scale, font_thickness)[0]
        # Tentative top-left coordinate of the text.
        (text_x, text_y) = (left + thickness, top - text_h + int(12.5 * font_scale))
        # If the text would land above the image, drop it inside the box.
        if text_y - text_h <= 0:
            text_y = top + text_h
        bg_x = left - int(thickness / 2)
        bg_y = top - text_h - int(thickness / 2)
        # Filled backdrop rectangle so the label stays legible.
        cv2.rectangle(image, (bg_x, bg_y - thickness),
                      (bg_x + text_w + thickness * 3, bg_y + text_h + int(12.5 * font_scale)),
                      (blue, green, red),
                      -1)
        cv2.putText(image, label, (text_x, text_y), font, font_scale, (0, 0, 0), font_thickness,
                    cv2.LINE_AA)
    return image
|
ScanSSD/README.md
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ScanSSD: Scanning Single Shot Detector for Math in Document Images
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
A [PyTorch](http://pytorch.org/) implementation of ScanSSD [Scanning Single Shot MultiBox Detector](https://paragmali.me/scanning-single-shot-detector-for-math-in-document-images/) by [**Parag Mali**](https://github.com/MaliParag/). It was developed using SSD implementation by [**Max deGroot**](https://github.com/amdegroot).
|
| 5 |
+
|
| 6 |
+
All credit goes to the authors of the paper and the original implementation.
|
| 7 |
+
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
I have made some changes to the original implementation to make it work with the latest version of PyTorch and Python.
|
| 11 |
+
I have also removed some unnecessary files, in particular the ones related to the dataset.
|