sherab65 commited on
Commit
21227aa
·
verified ·
1 Parent(s): fff351d

git commit -m "app.py"

Browse files
Files changed (1) hide show
  1. app.py +16 -10
app.py CHANGED
@@ -1,20 +1,26 @@
1
  import torch
2
  from torchvision import transforms
 
3
  from PIL import Image
4
  import gradio as gr
5
 
6
- # Allow safe loading of torchvision ResNet class
7
- from torchvision.models import resnet18
8
- torch.serialization.add_safe_globals([resnet18])
9
-
10
  # Class labels
11
  class_names = ['Nu.1', 'Nu.10', 'Nu.100', 'Nu.1000', 'Nu.20', 'Nu.5', 'Nu.50', 'Nu.500']
12
 
13
- # Force CPU (Hugging Face Spaces do not support GPU)
14
  device = torch.device('cpu')
15
 
16
- # Load full model (code-safe if trusted model)
17
- model = torch.load("currency_model.pth", map_location=device, weights_only=False)
 
 
 
 
 
 
 
 
 
18
  model.eval()
19
 
20
  # Image transform
@@ -23,7 +29,7 @@ transform = transforms.Compose([
23
  transforms.ToTensor(),
24
  ])
25
 
26
- # Prediction logic
27
  def predict(image):
28
  image = image.convert("RGB")
29
  image = transform(image).unsqueeze(0).to(device)
@@ -33,7 +39,7 @@ def predict(image):
33
  _, predicted = torch.max(outputs, 1)
34
  return class_names[predicted.item()]
35
 
36
- # Gradio Interface
37
  interface = gr.Interface(
38
  fn=predict,
39
  inputs=gr.Image(type="pil"),
@@ -42,4 +48,4 @@ interface = gr.Interface(
42
  description="Upload a currency note image to identify its value."
43
  )
44
 
45
- interface.launch()
 
1
  import torch
2
  from torchvision import transforms
3
+ from torchvision.models import resnet18
4
  from PIL import Image
5
  import gradio as gr
6
 
 
 
 
 
7
  # Class labels
8
  class_names = ['Nu.1', 'Nu.10', 'Nu.100', 'Nu.1000', 'Nu.20', 'Nu.5', 'Nu.50', 'Nu.500']
9
 
10
+ # Force CPU
11
  device = torch.device('cpu')
12
 
13
+ # Step 1: Define model architecture
14
+ model = resnet18(pretrained=False)
15
+
16
+ # Step 2: Modify final layer (assuming 8 classes)
17
+ model.fc = torch.nn.Linear(model.fc.in_features, len(class_names))
18
+
19
+ # Step 3: Load weights
20
+ model.load_state_dict(torch.load("currency_model.pth", map_location=device))
21
+
22
+ # Step 4: Set to eval mode
23
+ model.to(device)
24
  model.eval()
25
 
26
  # Image transform
 
29
  transforms.ToTensor(),
30
  ])
31
 
32
+ # Prediction function
33
  def predict(image):
34
  image = image.convert("RGB")
35
  image = transform(image).unsqueeze(0).to(device)
 
39
  _, predicted = torch.max(outputs, 1)
40
  return class_names[predicted.item()]
41
 
42
+ # Gradio interface
43
  interface = gr.Interface(
44
  fn=predict,
45
  inputs=gr.Image(type="pil"),
 
48
  description="Upload a currency note image to identify its value."
49
  )
50
 
51
+ interface.launch()