import torch #from models import vision_transformer as vit #from models import vision_transformer_multiBlocks_20221030 as vit #from methods import vision_transformer_multiBlocks_20221030 as vit from methods import ViT as vit #import vision_transformer_multiBlocks_20221030 as vit #from models.pmf_protonet import ProtoNet #from methods.pmf_protonet import ProtoNet from methods.protonet import ProtoNet #from pmf_protonet import ProtoNet #from models.cvpr2023_gnnnet_20221102 import GnnNet #from methods.cvpr2023_gnnnet_20221102 import GnnNet #from cvpr2023_gnnnet_20221102 import GnnNet def load_ViTsmall(no_pretrain=False): model = vit.__dict__['vit_small'](patch_size=16, num_classes=0) if(not no_pretrain): url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth" state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url) model.load_state_dict(state_dict, strict=True) #print('Pretrained weights found at {}'.format(url)) #print('model defined.') return model def load_ViTbase(no_pretrain=False): model = vit.__dict__['vit_base'](patch_size=16, num_classes=0) if(not no_pretrain): url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth" state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url) model.load_state_dict(state_dict, strict=True) print('Pretrained weights found at {}'.format(url)) print('model defined.') return model def load_ResNet50(no_pretrain=False): from torchvision.models.resnet import resnet50 pretrained = not no_pretrain model = resnet50(pretrained=pretrained) model.fc = torch.nn.Identity() print('model defined.') return model def load_ResNet50_dino(no_pretrain=False): from torchvision.models.resnet import resnet50 model = resnet50(pretrained=False) model.fc = torch.nn.Identity() if not no_pretrain: state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth",map_location="cpu",) model.load_state_dict(state_dict, strict=False) return model def load_ResNet50_clip(no_pretrain=False): from models import clip model, _ = clip.load('RN50', 'cpu') return model def get_model(backbone='vit_small', classifier='protonet', args=None, styleAdv=False): if(backbone=='vit_small' and classifier == 'protonet'): extractor = load_ViTsmall() if(not styleAdv): #from models.pmf_protonet import ProtoNet from methods.protonet import ProtoNet model = ProtoNet(extractor) else: #from models.pmf_protonet_metatrain_vit_protonet_20221102 import ProtoNet #from methods.pmf_protonet_metatrain_vit_protonet_20221102 import ProtoNet from methods.StyleAdv_ViT_protonet import ProtoNet model = ProtoNet(extractor) if(backbone=='resnet50' and classifier == 'protonet'): extractor = load_ResNet50_dino() model = ProtoNet(extractor) if(backbone=='vit_small' and classifier == 'gnnnet'): extractor = load_ViTsmall() model = GnnNet(extractor, backbone_flag='vit_small', n_way = 5, n_support = args.nSupport) if(backbone=='resnet50' and classifier == 'gnnnet'): extractor = load_ResNet50_dino() model = GnnNet(extractor, backbone_flag='resnet50', n_way = 5, n_support = args.nSupport) return model if __name__ == '__main__': input = torch.randn(16, 3, 224, 224) print('input:', input.size()) model = load_ViTsmall() out = model(input) print('out:', out.size())