ttoosi commited on
Commit
741020d
·
1 Parent(s): ae3f158
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -10,7 +10,7 @@ checkpoint_path = hf_hub_download(repo_id="ttoosi/resnet50_robust_face", filenam
10
 
11
  # Initialize the model
12
  model = models.resnet50()
13
- model.load_state_dict(torch.load(checkpoint_path))
14
  model.eval()
15
 
16
  # Image preprocessing
@@ -25,7 +25,7 @@ preprocess = transforms.Compose([
25
  def predict(image):
26
  image = preprocess(image).unsqueeze(0) # Add batch dimension
27
  with torch.no_grad():
28
- output = model(image)
29
  _, predicted_class = output.max(1)
30
  return f"Predicted class: {predicted_class.item()}"
31
 
 
10
 
11
  # Initialize the model
12
  model = models.resnet50()
13
+ model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu'))) # Force model to load on CPU
14
  model.eval()
15
 
16
  # Image preprocessing
 
25
  def predict(image):
26
  image = preprocess(image).unsqueeze(0) # Add batch dimension
27
  with torch.no_grad():
28
+ output = model(image) # Perform inference on CPU
29
  _, predicted_class = output.max(1)
30
  return f"Predicted class: {predicted_class.item()}"
31