import torch
import torch.nn as nn

from models.backbone import Backbone, Joiner
from models.position_encoding import PositionEmbeddingSine
from models.transformer import Transformer
from models.reltr import RelTR


def Reltr_model(checkpoint_path="checkpoint0149.pth", *, map_location="cpu"):
    """Build a RelTR model and load its weights from a checkpoint file.

    The architecture hyperparameters below are fixed to match the checkpoint
    this function was written for (pretrained on Visual Genome). If they
    disagree with the checkpoint, ``load_state_dict`` (strict by default)
    raises on missing/unexpected keys or shape mismatches.

    Args:
        checkpoint_path: Path to the checkpoint file. The default,
            ``"checkpoint0149.pth"``, is resolved relative to the current
            working directory, exactly as before.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so
            the checkpoint loads on machines without a GPU.

    Returns:
        The ``RelTR`` module with the checkpoint's ``'model'`` state dict
        loaded. This function does not call ``.eval()``; do so before
        inference if dropout should be disabled.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
        KeyError: If the checkpoint has no ``'model'`` entry.
    """
    # 128 presumably equals d_model // 2 (256 // 2): the sine embedding
    # builds separate x/y halves that are concatenated. TODO confirm.
    position_embedding = PositionEmbeddingSine(128, normalize=True)

    # Positional flags of Backbone are project-defined. They presumably
    # control training/return-interm-layers/dilation. Verify in
    # models/backbone.py.
    backbone = Backbone('resnet50', False, False, False)
    backbone = Joiner(backbone, position_embedding)
    # ResNet-50's last stage outputs 2048 channels. RelTR presumably reads
    # this attribute to size its input projection.
    backbone.num_channels = 2048

    transformer = Transformer(
        d_model=256,
        dropout=0.1,
        nhead=8,
        dim_feedforward=2048,
        num_encoder_layers=6,
        num_decoder_layers=6,
        normalize_before=False,
        return_intermediate_dec=True,
    )

    # 151 / 51 presumably = 150 VG object classes / 50 VG predicates,
    # each plus one background ("no object"/"no relation") slot.
    model = RelTR(
        backbone,
        transformer,
        num_classes=151,
        num_rel_classes=51,
        num_entities=100,
        num_triplets=200,
    )

    # The checkpoint is pretrained on Visual Genome.
    # SECURITY: weights_only=False makes torch.load fully unpickle the file,
    # which can execute arbitrary code. Only load checkpoints from trusted
    # sources. It is presumably needed because the checkpoint stores
    # non-tensor objects (e.g. training args) alongside the weights. TODO
    # confirm, and switch to weights_only=True if possible.
    ckpt = torch.load(
        checkpoint_path,
        map_location=torch.device(map_location),
        weights_only=False,
    )
    model.load_state_dict(ckpt['model'])
    return model