import torch
import torch.nn as nn
from torchvision import models


def load_model(pretrained_weights_path, *, num_classes=100, map_location="cpu"):
    """Build a ResNet-18 classifier and load fine-tuned weights into it.

    The backbone is created without any torchvision-provided pretrained
    weights; every parameter comes from the state dict stored at
    ``pretrained_weights_path``. The final fully connected layer is replaced
    so its output size matches the fine-tuned checkpoint.

    Args:
        pretrained_weights_path: Path (or file-like object) accepted by
            ``torch.load``, pointing at a saved ``state_dict`` for this
            architecture.
        num_classes: Output size of the replacement ``fc`` layer. Must match
            the checkpoint. Defaults to 100, the previously hard-coded value.
        map_location: Device the loaded tensors are mapped to. Defaults to
            CPU, as before.

    Returns:
        The ResNet-18 module with the checkpoint loaded, in evaluation mode.

    Raises:
        RuntimeError: Raised by ``load_state_dict`` when the checkpoint's keys
            or tensor shapes do not match the model (e.g. wrong
            ``num_classes``).
    """
    # No arguments means "no pretrained weights" on both the legacy
    # (``pretrained=``) and the current (``weights=``) torchvision APIs, and
    # it avoids the deprecation warning that ``pretrained=False`` triggers.
    net = models.resnet18()
    net.fc = nn.Linear(net.fc.in_features, num_classes)

    state_dict = _load_state_dict(pretrained_weights_path, map_location)
    net.load_state_dict(state_dict)
    net.eval()  # inference mode: BatchNorm uses its running statistics
    return net


def _load_state_dict(path, map_location):
    """Load a checkpoint, restricting unpickling to tensors when torch supports it."""
    try:
        # weights_only=True limits unpickling to tensors and plain containers,
        # so a tampered checkpoint file cannot execute code while loading.
        return torch.load(path, map_location=map_location, weights_only=True)
    except TypeError:
        # torch < 1.13 has no ``weights_only`` argument. This fallback
        # unpickles arbitrary objects: only load checkpoints you trust.
        return torch.load(path, map_location=map_location)