Spaces:
Sleeping
Sleeping
| import torch.nn as nn | |
| from Model.attention import PositionAttention | |
| from Model.backbone import VisionTransformer | |
class TTR(nn.Module):
    """Pipeline of a VisionTransformer backbone, a PositionAttention stage and a linear head.

    The constructor reads the following keys from ``args``:

    * ``"img_size"`` and ``"patch_size"`` -- indexable values; elements ``[0]``
      and ``[1]`` of each are integer-divided to obtain the ``h`` and ``w``
      passed to ``PositionAttention``.
    * ``"embed_dim"`` -- forwarded to the backbone, used as ``in_channels`` of
      the attention stage and as the input width of the linear head.
    * ``"num_heads"`` and ``"mask_ratio"`` -- forwarded to the backbone.
    * ``"position_attention_hidden"`` -- forwarded to the attention stage as
      ``num_channels``.

    Raises:
        KeyError: if ``args`` is a dict and any of the keys above is missing.
    """

    # Value handed to PositionAttention as ``max_length``.
    MAX_LENGTH = 26
    # Output width of the final linear layer.
    # NOTE(review): presumably the size of the label/character set -- confirm.
    NUM_CLASSES = 43

    def __init__(self, args: dict):
        super().__init__()
        self.args = args

        # Look up the values that are used more than once a single time.
        img_size = args["img_size"]
        patch_size = args["patch_size"]
        embed_dim = args["embed_dim"]

        self.backbone = VisionTransformer(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=3,
            embed_dim=embed_dim,
            num_heads=args["num_heads"],
            mask_ratio=args["mask_ratio"],
        )
        self.positionAttention = PositionAttention(
            max_length=self.MAX_LENGTH,
            in_channels=embed_dim,
            num_channels=args["position_attention_hidden"],
            # Image size divided by patch size along each of the two axes.
            h=img_size[0] // patch_size[0],
            w=img_size[1] // patch_size[1],
            mode='nearest',
        )
        self.cls = nn.Linear(embed_dim, self.NUM_CLASSES)

    def forward(self, image):
        """Run ``image`` through backbone, position attention and the linear head.

        Returns the output of the linear head; its last dimension has size
        ``NUM_CLASSES``. The second value returned by the attention stage is
        not used here.
        """
        features = self.backbone(image)
        attn_vecs, _ = self.positionAttention(features)
        return self.cls(attn_vecs)