Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
| from torchvision import models as models | |
class DenseNet121(nn.Module):
    """DenseNet121 backbone with a replaced linear classifier head.

    The architecture matches torchvision's standard DenseNet121 except that
    the final classifier is swapped for a single ``nn.Linear`` layer that
    produces ``out_size`` raw scores (logits). No sigmoid is applied in
    ``forward``; apply ``torch.sigmoid`` to the output when per-class
    probabilities are needed.

    Args:
        out_size: Number of output classes.
        pretrained: If True (the default, matching the previous behavior),
            initialise the backbone with torchvision's
            ``DenseNet121_Weights.DEFAULT`` weights, which torchvision may
            download. Pass False to skip that when every weight will be
            overwritten from a checkpoint anyway.
    """

    def __init__(self, out_size: int, pretrained: bool = True):
        super().__init__()
        weights = models.DenseNet121_Weights.DEFAULT if pretrained else None
        self.densenet121 = models.densenet121(weights=weights)
        num_ftrs = self.densenet121.classifier.in_features
        # Kept inside nn.Sequential so the state-dict keys remain
        # "densenet121.classifier.0.*", matching checkpoints saved earlier.
        self.densenet121.classifier = nn.Sequential(
            nn.Linear(num_ftrs, out_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw classifier logits for the input batch ``x``."""
        return self.densenet121(x)
def load_model(ckpt_path, n_classes=14):
    """Build a DenseNet121 on CPU and load its weights from a checkpoint file.

    Args:
        ckpt_path: Path to a file written by ``torch.save``. It may contain a
            plain state dict, or a dict wrapping one under ``"state_dict"``.
        n_classes: Number of classifier outputs; must match the checkpoint.

    Returns:
        The model with the checkpoint weights loaded, switched to eval mode.

    Raises:
        RuntimeError: Raised by ``load_state_dict`` when the checkpoint's
            keys or tensor shapes do not match the model.
    """
    model = DenseNet121(n_classes).cpu()
    print("=> loading checkpoint")
    # weights_only=True restricts unpickling to tensors, primitive types and
    # plain containers, so loading does not run arbitrary pickled code.
    checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'), weights_only=True)
    # Some training scripts save {"state_dict": ..., ...}; unwrap that form.
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    prefix = "module."
    # Strip only a *leading* "module." (typically present when the model was
    # saved while wrapped in nn.DataParallel); str.replace would also remove
    # the substring from the middle of a key.
    new_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint.items()
    }
    model.load_state_dict(new_state_dict)
    print("=> loaded checkpoint")
    model.eval()
    return model