npv2k1 commited on
Commit
7d5fa6b
·
verified ·
1 Parent(s): c5a5751
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -18,7 +18,7 @@ def classify_drawing(drawing_image):
18
  num_classes = 3 # Set the number of classes
19
  # Initialize your model class
20
  model = ShapeClassifier(num_classes=num_classes)
21
- model.load_state_dict(torch.load('model.pth'))
22
  model.eval() # Set the model to evaluation mode
23
 
24
  # Convert the drawing to a grayscale image
 
18
  num_classes = 3 # Set the number of classes
19
  # Initialize your model class
20
  model = ShapeClassifier(num_classes=num_classes)
21
+ model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
22
  model.eval() # Set the model to evaluation mode
23
 
24
  # Convert the drawing to a grayscale image