# russian_monument / model.py
# Pafkun333's picture
# Committing first one
# aeaf3f3
import torch
from torchvision.models import resnet18, ResNet18_Weights
from torch import nn
def load_face_classifier_model(model_path: str = 'model_2.pth', num_classes: int = 5):
    """
    Build a ResNet18 classifier, load fine-tuned weights into it, and return it
    in evaluation mode.

    The ImageNet-pretrained weights are deliberately not downloaded.
    ``load_state_dict`` runs with its default ``strict=True``, so every
    parameter and buffer is overwritten by the checkpoint anyway. Building the
    bare architecture avoids a network download and lets the loader work offline.

    Args:
        model_path (str): Path to the saved model state dictionary.
        num_classes (int): Number of classes for the final linear layer; must
            match the checkpoint.

    Returns:
        torch.nn.Module: The loaded model on CPU, in evaluation mode.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        RuntimeError: If the checkpoint's keys or tensor shapes do not match
            the model (raised by ``load_state_dict``).
    """
    # Bare ResNet18 architecture; all weights come from the checkpoint below.
    model = resnet18(weights=None)
    # Swap the ImageNet classification head for one sized to our classes.
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    # NOTE(review): torch.load unpickles the file, so only load checkpoints from
    # trusted sources. On torch >= 1.13, consider passing weights_only=True.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))  # Load to CPU

    model.load_state_dict(_normalize_state_dict_keys(state_dict))

    # Disable dropout / use running BatchNorm statistics for inference.
    model.eval()
    return model


def _normalize_state_dict_keys(state_dict):
    """Rename checkpoint keys so they match a plain ResNet18 with an ``nn.Linear`` head."""
    # Checkpoints saved from an nn.DataParallel wrapper prefix every key with "module.".
    wrapper_prefix = 'module.'
    # The checkpoint's head appears to have been an nn.Sequential with the Linear
    # layer at index 1 ("fc.1.weight", "fc.1.bias"); the model here uses a bare
    # nn.Linear ("fc.weight", "fc.bias").
    head_prefix = 'fc.1.'

    normalized = {}
    for key, value in state_dict.items():
        if key.startswith(wrapper_prefix):
            key = key[len(wrapper_prefix):]
        # Anchor on the prefix so only the head's own keys are rewritten.
        if key.startswith(head_prefix):
            key = 'fc.' + key[len(head_prefix):]
        normalized[key] = value
    return normalized
if __name__ == '__main__':
    # Manual smoke test: load the default checkpoint and print the architecture.
    model = load_face_classifier_model()
    print("Model loaded successfully:", model, sep="\n")