import torch
from torchvision.models import resnet18, ResNet18_Weights
from torch import nn

# The checkpoint stores the classifier head's parameters under ``fc.1.*``
# (presumably it was saved from a model whose ``fc`` was a container, e.g. an
# ``nn.Sequential`` with the linear layer at index 1 — confirm against the
# training code), while the model built here uses a plain ``nn.Linear`` at
# ``fc``, whose parameters are ``fc.weight`` / ``fc.bias``.
_SAVED_HEAD_PREFIX = 'fc.1.'
_MODEL_HEAD_PREFIX = 'fc.'


def _remap_state_dict_keys(state_dict):
    """Return a copy of *state_dict* with ``fc.1.*`` keys renamed to ``fc.*``.

    Only keys that *start* with the saved head prefix are renamed; every other
    key/value pair is copied through unchanged, in the original order.
    """
    return {
        (_MODEL_HEAD_PREFIX + key[len(_SAVED_HEAD_PREFIX):]
         if key.startswith(_SAVED_HEAD_PREFIX) else key): value
        for key, value in state_dict.items()
    }


def load_face_classifier_model(model_path: str = 'model_2.pth', num_classes: int = 5) -> nn.Module:
    """
    Build a ResNet18 with a ``num_classes``-way head, load saved weights into it,
    and return it in evaluation mode.

    Args:
        model_path (str): Path to the saved model state dictionary.
        num_classes (int): Number of classes for the final linear layer.

    Returns:
        torch.nn.Module: The loaded model in evaluation mode, on the CPU.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        RuntimeError: If the checkpoint's keys or tensor shapes do not match the
            model (e.g. it was trained with a different ``num_classes``).
    """
    # No pretrained ImageNet weights are requested: load_state_dict below runs
    # with strict=True (its default), so every parameter and buffer is replaced
    # by the checkpoint anyway. Skipping them avoids a needless network download.
    model = resnet18(weights=None)

    # Replace the final fully connected layer to output `num_classes` logits.
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # from trusted sources. map_location keeps everything on the CPU.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))

    model.load_state_dict(_remap_state_dict_keys(state_dict))

    # Evaluation mode: BatchNorm uses its running statistics.
    model.eval()
    return model


if __name__ == '__main__':
    # Example usage (for testing)
    loaded_model = load_face_classifier_model()
    print("Model loaded successfully:")
    print(loaded_model)