bytchew commited on
Commit
1c9d3b0
·
verified ·
1 Parent(s): 95af56c

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +18 -10
model.py CHANGED
@@ -1,16 +1,24 @@
1
  import torch
2
  import torch.nn as nn
3
  from torchvision import models
 
 
 
4
 
 
 
5
 
6
- def load_model(pretrained_weights_path):
7
- # Initialize Face-Rego
8
- net = models.resnet18(pretrained=False)
9
- num_ftrs = net.fc.in_features
10
- net.fc = nn.Linear(num_ftrs, 128) # Match your fine-tuned setup
11
 
12
- # Load weights
13
- state_dict = torch.load(pretrained_weights_path, map_location=torch.device('cpu'))
14
- net.load_state_dict(state_dict)
15
- net.eval() # Set to evaluation mode
16
- return net
 
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  from torchvision import models
4
+ import huggingface_hub
5
+ from huggingface_hub import hf_hub_download
6
+ import torchvision
7
 
8
+ model_repo_id = "CSSE416-final-project/faceRecogModel"
9
+ weight_file_id = "modelWeights100.bin"
10
 
11
+ def load_model(repo_id):
12
+ # Download the model weights from the repo
13
+ weights_path = hf_hub_download(repo_id=model_repo_id, filename=weight_file_id)
 
 
14
 
15
+ # Initialize the ResNet-18 architecture
16
+ model = torchvision.models.resnet18(pretrained=True)
17
+ num_ftrs = model.fc.in_features
18
+ model.fc = nn.Linear(num_ftrs, 100) # Adjust for your task (e.g., 128 classes)
19
+
20
+ # Load the model weights
21
+ state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
22
+ model.load_state_dict(state_dict)
23
+ model.eval() # Set the model to evaluation mode
24
+ return model