Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from models.backbone import Backbone, Joiner | |
| from models.position_encoding import PositionEmbeddingSine | |
| from models.transformer import Transformer | |
| from models.reltr import RelTR | |
def Reltr_model(ckpt_path="checkpoint0149.pth", *, map_location=None):
    """Build a RelTR model and load pretrained weights into it.

    The network is assembled from the project's ``Backbone``/``Joiner``,
    ``PositionEmbeddingSine``, ``Transformer`` and ``RelTR`` classes using
    fixed hyper-parameters. These must match the ones the checkpoint was
    trained with; otherwise ``load_state_dict`` (strict by default) fails.

    Args:
        ckpt_path: Path to the checkpoint file. The default,
            ``"checkpoint0149.pth"``, is resolved relative to the current
            working directory, as in the original hard-coded call.
        map_location: Forwarded to ``torch.load``. ``None`` (the default)
            means ``torch.device("cpu")``.

    Returns:
        The ``RelTR`` model with the checkpoint's ``'model'`` state dict
        loaded. ``model.eval()`` is NOT called here. Callers running
        inference should switch the model to eval mode themselves so that
        dropout (``dropout=0.1`` below) is disabled.

    Raises:
        KeyError: If the loaded checkpoint has no ``'model'`` entry.
        RuntimeError: Raised by ``load_state_dict`` if parameter names or
            shapes in the checkpoint do not match the constructed model.
    """
    if map_location is None:
        map_location = torch.device("cpu")

    position_embedding = PositionEmbeddingSine(128, normalize=True)
    backbone = Backbone('resnet50', False, False, False)
    backbone = Joiner(backbone, position_embedding)
    # 2048 is the channel count of ResNet-50's last stage. Presumably RelTR
    # reads `num_channels` to size its input projection. TODO confirm in
    # models/reltr.py.
    backbone.num_channels = 2048

    transformer = Transformer(
        d_model=256,
        dropout=0.1,
        nhead=8,
        dim_feedforward=2048,
        num_encoder_layers=6,
        num_decoder_layers=6,
        normalize_before=False,
        return_intermediate_dec=True,
    )
    # Class counts are presumably the Visual Genome entity/predicate label
    # sets, including a background/no-relation slot. TODO confirm against
    # the dataset config.
    model = RelTR(
        backbone,
        transformer,
        num_classes=151,
        num_rel_classes=51,
        num_entities=100,
        num_triplets=200,
    )

    # The checkpoint is pretrained on Visual Genome.
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary
    # objects, so only load checkpoints from a trusted source. It is kept
    # because the checkpoint presumably also stores non-tensor entries
    # (e.g. training args) that weights_only=True would reject. Verify
    # before tightening.
    ckpt = torch.load(ckpt_path, map_location=map_location, weights_only=False)
    model.load_state_dict(ckpt['model'])
    return model