faceRecogModel / model.py
bytchew's picture
Change back to number of classes
0b4aeed verified
raw
history blame
481 Bytes
import torch
import torch.nn as nn
from torchvision import models
def load_model(pretrained_weights_path, num_classes=100):
    """Build a ResNet-18 classifier and load fine-tuned weights into it.

    The network is created without any torchvision-provided pretrained
    weights. Its final fully connected layer (``fc``) has ``num_classes``
    outputs so that it matches the shape of the saved checkpoint.

    Args:
        pretrained_weights_path: Path (or file-like object) accepted by
            ``torch.load``; its contents are passed to ``load_state_dict``,
            so it is expected to hold a state_dict for this architecture.
        num_classes: Number of output units in the final layer. Defaults to
            100, the value this loader previously hard-coded. It must match
            the checkpoint.

    Returns:
        The ResNet-18 ``torch.nn.Module`` with the weights loaded, switched
        to evaluation mode.

    Raises:
        RuntimeError: If the checkpoint's keys or tensor shapes do not match
            the model (``load_state_dict`` runs in its default strict mode).
    """
    # Pass neither ``pretrained=`` (deprecated since torchvision 0.13) nor
    # ``weights=`` (absent before 0.13). Both default to "no pretrained
    # weights", so this call works on old and new torchvision alike.
    # ``num_classes`` sizes ``fc`` directly, so no layer swap is needed and
    # the state_dict keys (``fc.weight`` / ``fc.bias``) are unchanged.
    net = models.resnet18(num_classes=num_classes)

    # Map every tensor to CPU so checkpoints saved on a GPU load anywhere.
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. On torch >= 1.13 consider passing weights_only=True
    # (the default from torch 2.6) to restrict unpickling to tensors.
    state_dict = torch.load(
        pretrained_weights_path, map_location=torch.device("cpu")
    )
    net.load_state_dict(state_dict)

    # Switch layers such as BatchNorm to their inference behaviour.
    net.eval()
    return net