# chextnet-raylabs / model.py
# Author: iruda21cse
# Commit: 3e745a7 ("initial commit")
import torch
import torch.nn as nn
from torchvision import models as models
class DenseNet121(nn.Module):
    """DenseNet-121 backbone with a resized linear classifier head.

    The architecture is the standard torchvision DenseNet-121 except that the
    final classifier is replaced by a single ``nn.Linear`` that produces
    ``out_size`` raw scores (logits). No sigmoid is applied in ``forward``;
    callers that need per-class probabilities must apply ``torch.sigmoid``
    to the output themselves.

    Args:
        out_size: Number of output classes.
        pretrained: If True (the default, matching the previous behavior),
            the backbone is initialized with torchvision's default ImageNet
            weights, which may require a download. Pass False to start from
            random weights, e.g. when a checkpoint will overwrite every
            parameter anyway.
    """

    def __init__(self, out_size, pretrained=True):
        super().__init__()
        weights = models.DenseNet121_Weights.DEFAULT if pretrained else None
        self.densenet121 = models.densenet121(weights=weights)
        num_ftrs = self.densenet121.classifier.in_features
        # The single-layer nn.Sequential is kept on purpose: it makes this
        # model's state_dict keys "densenet121.classifier.0.weight/bias".
        # Strict load_state_dict requires checkpoint keys to match that
        # layout exactly, so unwrapping the Linear would break loading.
        self.densenet121.classifier = nn.Sequential(
            nn.Linear(num_ftrs, out_size),
        )

    def forward(self, x):
        """Return raw class logits of shape ``(batch, out_size)`` for ``x``.

        ``x`` is passed straight to torchvision's DenseNet-121; presumably an
        image batch of shape (N, 3, H, W) -- confirm against the caller's
        preprocessing.
        """
        return self.densenet121(x)
def load_model(ckpt_path, n_classes=14, device="cpu"):
    """Build a ``DenseNet121`` and load trained weights from ``ckpt_path``.

    Args:
        ckpt_path: Path or file-like object accepted by ``torch.load``.
        n_classes: Number of output classes the checkpoint was trained with.
        device: Device used as ``map_location`` and to place the model on.
            Defaults to ``"cpu"``, which preserves the original behavior.

    Returns:
        The model with the checkpoint weights loaded, on ``device``, in eval
        mode.

    Raises:
        RuntimeError: From ``load_state_dict`` when the checkpoint's keys or
            tensor shapes do not match the model (loading is strict).
    """
    device = torch.device(device)
    model = DenseNet121(n_classes).to(device)
    print("=> loading checkpoint")
    # weights_only=True restricts unpickling to tensors and primitive
    # containers, so a malicious checkpoint cannot run arbitrary code.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
    # Accept both a bare state_dict and the common {"state_dict": ...} wrapper
    # (a bare state_dict's keys are parameter paths, never "state_dict").
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    prefix = "module."
    # Strip only a *leading* "module." (added when the model was saved while
    # wrapped in nn.DataParallel / DistributedDataParallel). str.replace would
    # also mangle any key containing "module." elsewhere in its path.
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint.items()
    }
    model.load_state_dict(state_dict)
    print("=> loaded checkpoint")
    model.eval()
    return model