import logging

import timm
import torch

logger = logging.getLogger(__name__)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# timm architecture names for each supported ``model_type``.
_ARCHITECTURES = {
    "clip": "vit_base_patch16_clip_224.openai",
    "vit": "vit_base_patch16_224",
}
# Prefix that nn.DataParallel / DistributedDataParallel prepend to every
# parameter name when a wrapped model's state dict is saved.
_WRAPPER_PREFIX = "module."


def _extract_state_dict(checkpoint):
    """Return the parameter mapping held by a loaded checkpoint object.

    Accepts a raw state dict, a training-checkpoint dict that nests it under
    ``"model_state_dict"`` or ``"state_dict"`` (checked in that order), or a
    pickled ``nn.Module``. Anything else is returned unchanged.
    """
    if isinstance(checkpoint, torch.nn.Module):
        return checkpoint.state_dict()
    if isinstance(checkpoint, dict):
        for key in ("model_state_dict", "state_dict"):
            if key in checkpoint:
                return checkpoint[key]
    return checkpoint


def _strip_wrapper_prefix(state_dict):
    """Drop a ``module.`` prefix when every key carries it.

    Without this, loading with ``strict=False`` matches no keys at all and
    silently leaves the freshly created model at random initialisation.
    """
    if not isinstance(state_dict, dict) or not state_dict:
        return state_dict
    if all(isinstance(k, str) and k.startswith(_WRAPPER_PREFIX) for k in state_dict):
        return {k[len(_WRAPPER_PREFIX):]: v for k, v in state_dict.items()}
    return state_dict


def load_model(model_path, model_type="vit", *, strict=False, device=None):
    """Build a single-output timm ViT and load its weights from *model_path*.

    Args:
        model_path: Checkpoint location, passed straight to ``torch.load``.
        model_type: ``"clip"`` selects ``vit_base_patch16_clip_224.openai``;
            any other value selects ``vit_base_patch16_224``. Values other
            than ``"clip"``/``"vit"`` log a warning before falling back.
        strict: Forwarded to ``load_state_dict``. The default ``False``
            tolerates key mismatches, but any missing/unexpected keys are
            logged as warnings.
        device: Target device; defaults to the module-level ``DEVICE``.

    Returns:
        The model moved to *device* and switched to eval mode.

    Raises:
        RuntimeError: From ``load_state_dict`` when ``strict=True`` and the
            checkpoint keys do not match the model.
    """
    arch = _ARCHITECTURES.get(model_type)
    if arch is None:
        # Preserve the historical fallback to plain ViT, but make it visible.
        logger.warning("Unknown model_type %r; falling back to 'vit'.", model_type)
        arch = _ARCHITECTURES["vit"]
    # pretrained=False: every weight must come from the checkpoint below, so
    # keys that fail to load stay randomly initialised.
    model = timm.create_model(arch, pretrained=False, num_classes=1)

    # NOTE(security): torch.load unpickles arbitrary objects; only load trusted
    # checkpoints. Consider weights_only=True once all supported torch versions
    # accept it and the checkpoints hold only tensors/plain containers.
    checkpoint = torch.load(model_path, map_location="cpu")
    state_dict = _strip_wrapper_prefix(_extract_state_dict(checkpoint))

    incompatible = model.load_state_dict(state_dict, strict=strict)
    for label, keys in (
        ("missing from checkpoint (left at random init)", incompatible.missing_keys),
        ("in checkpoint but unused by model", incompatible.unexpected_keys),
    ):
        if keys:
            logger.warning(
                "load_model(%s): %d keys %s, e.g. %s",
                model_path,
                len(keys),
                label,
                list(keys)[:5],
            )

    model.to(DEVICE if device is None else device)
    model.eval()
    return model