Spaces:
Sleeping
Sleeping
| import torch | |
| import timm | |
| import torch.nn as nn | |
| from torchvision.models import alexnet | |
def classifying_head(
    in_features: int,
    num_labels: int,
    *,
    hidden_features: int = 128,
    dropout: float = 0.2,
) -> nn.Sequential:
    """Build a small MLP classification head.

    Layout: Dropout -> Linear(in_features, hidden_features) -> ReLU ->
    BatchNorm1d(hidden_features) -> Linear(hidden_features, num_labels).

    Args:
        in_features: Width of the feature vector fed into the head.
        num_labels: Number of output logits.
        hidden_features: Width of the hidden layer (default 128).
        dropout: Dropout probability applied to the input features
            (default 0.2).

    Returns:
        The head as an ``nn.Sequential``.
    """
    # Keep the layer order fixed: nn.Sequential names its children "0".."4",
    # so saved state_dict keys (e.g. "1.weight", "3.running_mean") depend on it.
    return nn.Sequential(
        nn.Dropout(p=dropout),
        nn.Linear(in_features=in_features, out_features=hidden_features),
        nn.ReLU(),
        nn.BatchNorm1d(num_features=hidden_features),
        nn.Linear(hidden_features, num_labels),
    )
def _create_backbone(model_str, num_labels):
    """Instantiate the timm network named by ``model_str``; raise ValueError if unknown."""
    if model_str == "densenet121":
        model = timm.create_model(
            'densenet121', num_classes=num_labels, pretrained=True)
        # Replace timm's linear classifier with the custom MLP head built on
        # 1024-d features; presumably the checkpoint was trained with this
        # head, so its "classifier.*" keys line up -- confirm with training code.
        model.classifier = classifying_head(1024, num_labels)
        return model
    if model_str in ("swin_simim", "swin_in22k"):
        return timm.create_model(
            'swin_base_patch4_window7_224_in22k',
            num_classes=num_labels, pretrained=True)
    if model_str == "vit_in1k":
        return timm.create_model(
            'vit_base_patch16_224', num_classes=num_labels, pretrained=True)
    raise ValueError(
        f"Unsupported model_str: {model_str!r}; expected one of "
        "'densenet121', 'swin_simim', 'swin_in22k', 'vit_in1k'")


def load_model(ckpt_path, num_labels, model_str):
    """Build a timm classifier and load weights from a checkpoint.

    Args:
        ckpt_path: Path to a file readable by ``torch.load`` containing a
            ``'state_dict'`` entry with the model weights.
        num_labels: Number of output classes.
        model_str: One of ``"densenet121"``, ``"swin_simim"``,
            ``"swin_in22k"`` or ``"vit_in1k"``.

    Returns:
        Tuple ``(model, normalization, img_size)``. ``normalization`` is
        ``"chestx-ray"`` for ``"swin_simim"`` and ``"imagenet"`` otherwise.
        ``img_size`` is always 224.

    Raises:
        ValueError: If ``model_str`` is not a supported architecture. This is
            checked before the checkpoint is read.
        KeyError: If the checkpoint has no ``'state_dict'`` entry.
    """
    # Validate the architecture first so a bad name fails fast, before the
    # (possibly large) checkpoint is read.
    model = _create_backbone(model_str, num_labels)
    normalization = "chestx-ray" if model_str == "swin_simim" else "imagenet"
    img_size = 224

    # NOTE(review): torch.load unpickles arbitrary Python objects; only load
    # trusted checkpoints. weights_only=True would be safer, but confirm the
    # checkpoint contains only tensors/primitive containers before enabling.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    state_dict = checkpoint['state_dict']
    # strict=False tolerates missing/unexpected keys; the pretrained timm
    # weights remain for any key the checkpoint lacks. The returned key report
    # is printed so mismatches are visible.
    msg = model.load_state_dict(state_dict, strict=False)
    print(f'Loaded {model_str} with msg: {msg}')
    return model, normalization, img_size