Spaces:
Sleeping
Sleeping
File size: 1,417 Bytes
import torch.nn as nn
from Model.attention import PositionAttention
from Model.backbone import VisionTransformer
class TTR(nn.Module):
    """Text recognition model: ViT backbone + position attention + classifier.

    Pipeline: image -> VisionTransformer features -> PositionAttention
    (one feature vector per character position) -> per-position class logits.

    Required keys in ``args``: ``img_size`` (h, w), ``patch_size`` (h, w),
    ``embed_dim``, ``num_heads``, ``mask_ratio``, ``position_attention_hidden``.
    Optional keys: ``max_length`` (default 26, previously hard-coded) and
    ``num_classes`` (default 43, previously hard-coded) — defaults preserve
    the original behavior for existing callers.
    """

    def __init__(self, args: dict):
        super().__init__()
        self.args = args
        # ViT encoder over 3-channel images.
        self.backbone = VisionTransformer(img_size=args["img_size"],
                                          patch_size=args["patch_size"],
                                          in_channels=3,
                                          embed_dim=args["embed_dim"],
                                          num_heads=args["num_heads"],
                                          mask_ratio=args["mask_ratio"])
        # Attends over the h x w patch grid; h/w are derived from the image and
        # patch sizes, so img_size must be divisible accordingly (assumed — the
        # integer division would silently truncate otherwise; TODO confirm).
        self.positionAttention = PositionAttention(
            max_length=args.get("max_length", 26),
            in_channels=args["embed_dim"],
            num_channels=args["position_attention_hidden"],
            h=args["img_size"][0] // args["patch_size"][0],
            w=args["img_size"][1] // args["patch_size"][1],
            mode='nearest')
        # Linear head mapping each attended vector to class logits.
        self.cls = nn.Linear(args["embed_dim"], args.get("num_classes", 43))

    def forward(self, image):
        """Return per-position class logits for a batch of images.

        Attention scores from PositionAttention are computed but discarded,
        matching the original behavior.
        """
        features = self.backbone(image)
        attn_vecs, attn_scores = self.positionAttention(features)
        return self.cls(attn_vecs)
|