meraj12 commited on
Commit
ffdb7d5
·
verified ·
1 Parent(s): 0e8b499

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -9,7 +9,8 @@ import tempfile
9
  import pyttsx3
10
  import os
11
  from datetime import datetime
12
-
 
13
  # ========== TTS ENGINE ==========
14
  def speak(text):
15
  engine = pyttsx3.init()
@@ -25,11 +26,13 @@ if "PKR" not in currency_type:
25
  st.stop()
26
 
27
  # ========== LOAD MODEL ==========
28
- model = models.mobilenet_v2(pretrained=False)
29
- model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, 2)
30
- model.load_state_dict(torch.load("pkr_currency_classifier.pt", map_location=torch.device('cpu')))
 
31
  model.eval()
32
 
 
33
  # ========== TRANSFORMS ==========
34
  transform = transforms.Compose([
35
  transforms.Resize((224, 224)),
 
9
  import pyttsx3
10
  import os
11
  from datetime import datetime
12
+ import torch
13
+ from torchvision.models.mobilenetv2 import MobileNetV2
14
  # ========== TTS ENGINE ==========
15
  def speak(text):
16
  engine = pyttsx3.init()
 
26
  st.stop()
27
 
28
  # ========== LOAD MODEL ==========
29
+
30
+ torch.serialization.add_safe_globals({'MobileNetV2': MobileNetV2})
31
+
32
+ model = torch.load("pkr_currency_classifier.pt", map_location='cpu', weights_only=False)
33
  model.eval()
34
 
35
+
36
  # ========== TRANSFORMS ==========
37
  transform = transforms.Compose([
38
  transforms.Resize((224, 224)),