Spaces:
Sleeping
Sleeping
File size: 1,639 Bytes
df9c255 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 | import torch
import timm
import torch.nn as nn
from torchvision.models import alexnet
def classifying_head(
    in_features: int,
    num_labels: int,
    *,
    hidden_features: int = 128,
    dropout: float = 0.2,
) -> nn.Sequential:
    """Build a small MLP classification head.

    The layers are, in order: ``Dropout`` -> ``Linear`` -> ``ReLU`` ->
    ``BatchNorm1d`` -> ``Linear``.

    Args:
        in_features: Width of the feature vector fed into the head.
        num_labels: Number of outputs produced by the final linear layer.
        hidden_features: Width of the hidden layer. The default (128) is
            the value that was previously hard-coded.
        dropout: Dropout probability applied to the incoming features. The
            default (0.2) is the value that was previously hard-coded.

    Returns:
        An ``nn.Sequential`` containing the five layers listed above.
    """
    return nn.Sequential(
        nn.Dropout(p=dropout),
        nn.Linear(in_features=in_features, out_features=hidden_features),
        nn.ReLU(),
        nn.BatchNorm1d(num_features=hidden_features),
        nn.Linear(hidden_features, num_labels),
    )
def load_model(ckpt_path, num_labels: int, model_str: str):
    """Create a timm backbone and load fine-tuned weights from a checkpoint.

    Args:
        ckpt_path: Location of the checkpoint, passed straight to
            ``torch.load``. The loaded object is indexed with
            ``'state_dict'``.
        num_labels: Number of output classes for the created model.
        model_str: Which architecture to build. One of ``"densenet121"``,
            ``"swin_simim"``, ``"swin_in22k"`` or ``"vit_in1k"``.

    Returns:
        A tuple ``(model, normalization, img_size)`` where ``normalization``
        is ``"chestx-ray"`` for ``"swin_simim"`` and ``"imagenet"`` for every
        other model, and ``img_size`` is always 224.

    Raises:
        ValueError: If ``model_str`` is not one of the supported names.
    """
    supported = ("densenet121", "swin_simim", "swin_in22k", "vit_in1k")
    # Reject unknown names up front; otherwise ``model`` would never be
    # bound and the failure would surface later as an UnboundLocalError.
    if model_str not in supported:
        raise ValueError(
            f"Unsupported model_str {model_str!r}; "
            f"expected one of {', '.join(supported)}"
        )

    # SECURITY NOTE: torch.load unpickles the file, which can execute
    # arbitrary code. Only pass checkpoints from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
    state_dict = checkpoint['state_dict']

    if model_str == "densenet121":
        model = timm.create_model(
            'densenet121', num_classes=num_labels, pretrained=True)
        # Swap the stock classifier for the custom MLP head (1024 inputs).
        model.classifier = classifying_head(1024, num_labels)
    elif model_str in ("swin_simim", "swin_in22k"):
        model = timm.create_model(
            'swin_base_patch4_window7_224_in22k',
            num_classes=num_labels, pretrained=True)
    else:  # "vit_in1k" -- the only remaining supported name
        model = timm.create_model(
            'vit_base_patch16_224', num_classes=num_labels, pretrained=True)

    normalization = "chestx-ray" if model_str == "swin_simim" else "imagenet"
    img_size = 224

    # strict=False tolerates key mismatches; the returned report of
    # missing/unexpected keys is printed so they do not go unnoticed.
    msg = model.load_state_dict(state_dict, strict=False)
    print(f'Loaded {model_str} with msg: {msg}')

    return model, normalization, img_size
|