Vocal-Eyes-Fast_api / RelTR_build.py
ABDRauf's picture
changed device
203e1df verified
import torch
import torch.nn as nn
from models.backbone import Backbone, Joiner
from models.position_encoding import PositionEmbeddingSine
from models.transformer import Transformer
from models.reltr import RelTR
def Reltr_model(checkpoint_path: str = "checkpoint0149.pth") -> nn.Module:
    """Build a RelTR model and load pretrained weights into it.

    Assembles a ResNet-50 backbone joined with a sine position embedding,
    a 6-layer encoder / 6-layer decoder transformer, and the RelTR head,
    then restores the parameters stored under the ``'model'`` key of the
    checkpoint file.

    Args:
        checkpoint_path: Location of the checkpoint file handed to
            ``torch.load``. Defaults to ``"checkpoint0149.pth"``, which is
            resolved relative to the current working directory.

    Returns:
        The constructed RelTR model with the checkpoint weights loaded.
        Tensors are mapped to the CPU while loading.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
        KeyError: If the checkpoint has no ``'model'`` entry.
        RuntimeError: If the stored state dict does not match the model.

    NOTE(review): this function does not call ``model.eval()``; callers
    running inference presumably need to do that themselves -- confirm.
    """
    # NOTE(review): 128 looks like d_model // 2 (256 // 2) -- confirm
    # against models.position_encoding.
    position_embedding = PositionEmbeddingSine(128, normalize=True)

    # NOTE(review): the three positional flags are presumably
    # (train_backbone, return_interm_layers, dilation) -- confirm against
    # models.backbone.Backbone.
    backbone = Backbone('resnet50', False, False, False)
    backbone = Joiner(backbone, position_embedding)
    # Joiner does not get this attribute set here automatically, so it is
    # assigned by hand before the model is built.
    backbone.num_channels = 2048

    transformer = Transformer(
        d_model=256,
        dropout=0.1,
        nhead=8,
        dim_feedforward=2048,
        num_encoder_layers=6,
        num_decoder_layers=6,
        normalize_before=False,
        return_intermediate_dec=True,
    )

    model = RelTR(
        backbone,
        transformer,
        num_classes=151,
        num_rel_classes=51,
        num_entities=100,
        num_triplets=200,
    )

    # The checkpoint is pretrained on Visual Genome.
    # SECURITY: weights_only=False performs full pickle deserialization and
    # can execute arbitrary code; only load checkpoint files from a trusted
    # source. It is kept because the checkpoint may contain non-tensor
    # objects -- TODO confirm whether weights_only=True would suffice.
    ckpt = torch.load(
        checkpoint_path,
        map_location=torch.device("cpu"),
        weights_only=False,
    )
    model.load_state_dict(ckpt['model'])
    return model