# MViT-TR / Model / model.py
# serdaryildiz's picture
# Upload 31 files
# 9b6af3b verified
import torch.nn as nn
from Model.attention import PositionAttention
from Model.backbone import VisionTransformer
class TTR(nn.Module):
    """Recognition model: ViT backbone -> position attention -> linear classifier.

    The image is encoded by ``VisionTransformer``, the resulting features are
    passed through ``PositionAttention``, and every attended vector is
    projected to class logits by a single ``nn.Linear`` layer.
    """

    def __init__(self, args: dict, *, max_length: int = 26, num_classes: int = 43):
        """Build the backbone, the position-attention module and the classifier.

        Args:
            args: Configuration mapping. Keys read here: ``"img_size"`` and
                ``"patch_size"`` (both indexed with ``[0]`` and ``[1]``),
                ``"embed_dim"``, ``"num_heads"``, ``"mask_ratio"`` and
                ``"position_attention_hidden"``.
            max_length: Value forwarded to ``PositionAttention`` as
                ``max_length``. Defaults to 26, the previously hard-coded
                value. Presumably the maximum decoded sequence length --
                confirm against ``PositionAttention``.
            num_classes: Output width of the classification layer. Defaults to
                43, the previously hard-coded value. Presumably the size of
                the character set -- confirm against the label encoder.
        """
        super().__init__()
        self.args = args

        embed_dim = args["embed_dim"]
        img_size = args["img_size"]
        patch_size = args["patch_size"]

        self.backbone = VisionTransformer(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=3,
            embed_dim=embed_dim,
            num_heads=args["num_heads"],
            mask_ratio=args["mask_ratio"],
        )

        # Number of whole patches that fit along each image axis; index 0 is
        # passed as ``h`` and index 1 as ``w``.
        self.positionAttention = PositionAttention(
            max_length=max_length,
            in_channels=embed_dim,
            num_channels=args["position_attention_hidden"],
            h=img_size[0] // patch_size[0],
            w=img_size[1] // patch_size[1],
            mode='nearest',
        )

        self.cls = nn.Linear(embed_dim, num_classes)

    def forward(self, image):
        """Return classification logits for ``image``.

        Args:
            image: Input handed directly to the backbone (assumed to be a
                batched image tensor -- TODO confirm expected shape).

        Returns:
            The output of the linear classifier applied to the attended
            vectors produced by ``PositionAttention``.
        """
        features = self.backbone(image)
        # PositionAttention returns (vectors, scores); the scores are not
        # needed to compute the logits.
        attn_vecs, _ = self.positionAttention(features)
        return self.cls(attn_vecs)