x-ray-vision / src /load_model.py
taheera's picture
initial commit
df9c255
import torch
import timm
import torch.nn as nn
from torchvision.models import alexnet
def classifying_head(in_features: int, num_labels: int):
    """Build the small MLP head used for classification.

    The head is an ``nn.Sequential`` of, in order: dropout (p=0.2), a linear
    projection from ``in_features`` to 128 units, ReLU, 1-D batch
    normalisation over those 128 units, and a final linear layer producing
    ``num_labels`` outputs.

    Args:
        in_features: Width of the feature vector fed into the head.
        num_labels: Number of outputs produced by the last linear layer.

    Returns:
        The assembled ``nn.Sequential`` module.
    """
    hidden_width = 128
    # Layer order fixes the positional names ("1", "3", "4") that appear in
    # the module's state dict, so it must not be rearranged.
    layers = [
        nn.Dropout(p=0.2),
        nn.Linear(in_features, hidden_width),
        nn.ReLU(),
        nn.BatchNorm1d(hidden_width),
        nn.Linear(hidden_width, num_labels),
    ]
    return nn.Sequential(*layers)
def load_model(ckpt_path, num_labels: int, model_str: str):
    """Create a timm backbone and load fine-tuned weights from a checkpoint.

    Args:
        ckpt_path: Location of the checkpoint, passed straight to
            ``torch.load``. The loaded object must expose a ``'state_dict'``
            entry.
        num_labels: Number of output classes, forwarded to timm as
            ``num_classes`` (and to the custom head for DenseNet).
        model_str: Which model to build. One of ``"densenet121"``,
            ``"swin_simim"``, ``"swin_in22k"`` or ``"vit_in1k"``.

    Returns:
        A tuple ``(model, normalization, img_size)`` where ``normalization``
        is ``"chestx-ray"`` for ``"swin_simim"`` and ``"imagenet"`` for every
        other model, and ``img_size`` is always ``224``.

    Raises:
        ValueError: If ``model_str`` is not one of the supported names.
        KeyError: If the checkpoint has no ``'state_dict'`` entry (assuming
            the checkpoint is a plain dict).
    """
    # Maps the public model name to the timm architecture identifier.
    timm_names = {
        "densenet121": "densenet121",
        "swin_simim": "swin_base_patch4_window7_224_in22k",
        "swin_in22k": "swin_base_patch4_window7_224_in22k",
        "vit_in1k": "vit_base_patch16_224",
    }
    # Reject unknown names up front; previously this fell through to an
    # UnboundLocalError on `model` before the intended check could run.
    if model_str not in timm_names:
        raise ValueError(
            f"Unsupported model_str {model_str!r}; "
            f"expected one of {sorted(timm_names)}"
        )

    # Load the checkpoint once, before building the model, so a bad path
    # fails fast.
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code -- only pass checkpoints from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
    state_dict = checkpoint['state_dict']

    model = timm.create_model(
        timm_names[model_str], num_classes=num_labels, pretrained=True)
    if model_str == "densenet121":
        # DenseNet gets the custom MLP head instead of timm's classifier;
        # 1024 is presumably the DenseNet-121 feature width -- confirm
        # against timm if the backbone changes.
        model.classifier = classifying_head(1024, num_labels)

    normalization = "chestx-ray" if model_str == "swin_simim" else "imagenet"

    # strict=False: missing/unexpected keys are reported in `msg` instead of
    # raising, and parameters absent from the checkpoint keep the values set
    # at model creation.
    msg = model.load_state_dict(state_dict, strict=False)
    print(f'Loaded {model_str} with msg: {msg}')

    img_size = 224
    return model, normalization, img_size