Spaces:
Sleeping
Sleeping
| import torch | |
| from torchvision.models import resnet18, ResNet18_Weights | |
| from torch import nn | |
def _remap_checkpoint_keys(state_dict):
    """Return a copy of *state_dict* with ``fc.1.*`` keys renamed to ``fc.*``.

    The checkpoint stores the classifier head's parameters under ``fc.1.``.
    Presumably the head was an ``nn.Sequential`` whose index-1 entry was the
    ``nn.Linear`` (TODO confirm against the training code). This loader
    instead installs a bare ``nn.Linear`` as ``model.fc``. Only keys that
    *start* with that prefix are rewritten; all other keys pass through
    unchanged and in their original order.
    """
    old_prefix = 'fc.1.'
    new_prefix = 'fc.'
    return {
        (new_prefix + key[len(old_prefix):] if key.startswith(old_prefix) else key): value
        for key, value in state_dict.items()
    }


def load_face_classifier_model(model_path: str = 'model_2.pth', num_classes: int = 5):
    """
    Build a ResNet18 face classifier and load its fine-tuned weights from disk.

    The final fully connected layer is replaced with an ``nn.Linear`` that has
    ``num_classes`` outputs. The saved state dict is loaded onto the CPU, with
    ``fc.1.*`` keys remapped to ``fc.*``. The model is returned in
    evaluation mode.

    Args:
        model_path (str): Path to the saved model state dictionary.
        num_classes (int): Number of output classes for the final linear layer.
            This must match the checkpoint.

    Returns:
        torch.nn.Module: The loaded model, on the CPU, in evaluation mode.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist (raised by ``torch.load``).
        RuntimeError: If the checkpoint's keys or tensor shapes do not match
            the model (raised by ``load_state_dict`` in strict mode).
    """
    # No pretrained weights are requested. load_state_dict below runs in strict
    # mode, so every parameter and buffer is overwritten by the checkpoint.
    # Fetching the ImageNet weights would only add a download and a network
    # dependency without changing the resulting model.
    model = resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    # The checkpoint is a plain state dict of tensors. weights_only=True
    # restricts unpickling to tensors and basic containers instead of running
    # arbitrary pickled code (requires torch >= 1.13).
    state_dict = torch.load(
        model_path,
        map_location=torch.device('cpu'),
        weights_only=True,
    )
    model.load_state_dict(_remap_checkpoint_keys(state_dict))

    # Disable dropout and use BatchNorm running statistics for inference.
    model.eval()
    return model
if __name__ == '__main__':
    # Manual smoke test: build the classifier from the default checkpoint
    # and dump its module tree to stdout.
    model = load_face_classifier_model()
    print("Model loaded successfully:", model, sep="\n")