import timm
import os
from Backbone.GetPromptModel import build_promptmodel
from pprint import pprint


def get_PuzzleTuning_VPT_model(num_classes: int = 0, edge_size: int = 224,
                               prompt_state_dict=None, base_state_dict='timm'):
    """Build the PuzzleTuning VPT (Visual Prompt Tuning) ViT model.

    The backbone configuration is fixed here: model_idx='ViT', patch_size=16,
    Prompt_Token_num=20, VPT_type="Deep". Construction, weight loading and any
    head setup are delegated to ``build_promptmodel``.

    :param num_classes: number of classes required by your dataset; 0 (the
        default) builds a feature extractor whose output is the CLS token.
        Previously this argument was ignored and 0 was always used.
    :param edge_size: edge size (in pixels) of the square input images produced
        by the dataloader.
    :param prompt_state_dict: prompt weights forwarded unchanged to
        ``build_promptmodel``; None presumably means no pretrained prompt
        tokens are loaded (behavior defined by build_promptmodel — confirm).
    :param base_state_dict: backbone weights forwarded unchanged to
        ``build_promptmodel``; the default 'timm' presumably selects timm's
        pretrained ViT weights (behavior defined by build_promptmodel — confirm).
    :return: the model returned by ``build_promptmodel``.
    """
    model = build_promptmodel(
        # Honor the caller's request instead of hard-coding 0; the default of 0
        # keeps the feature-extractor behavior (output is the CLS token).
        num_classes=num_classes,
        edge_size=edge_size,
        model_idx='ViT',
        patch_size=16,
        Prompt_Token_num=20,
        VPT_type="Deep",
        prompt_state_dict=prompt_state_dict,
        base_state_dict=base_state_dict,
    )
    return model