Spaces:
Runtime error
Runtime error
Commit
·
2d29ff9
1
Parent(s):
1b8bd99
Changed the labels to english
Browse files
app.py
CHANGED
|
@@ -2,7 +2,7 @@ import gradio as gr
|
|
| 2 |
import torch
|
| 3 |
from Model import LeNet
|
| 4 |
|
| 5 |
-
labels = ['Zero','
|
| 6 |
|
| 7 |
# Locate device
|
| 8 |
if torch.cuda.is_available():
|
|
@@ -12,8 +12,6 @@ else:
|
|
| 12 |
device = torch.device("cpu")
|
| 13 |
print("CPU")
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
| 17 |
# Loading model
|
| 18 |
model = LeNet().to(device)
|
| 19 |
model.load_state_dict(torch.load("model_mnist.pth", map_location=torch.device('cpu')))
|
|
@@ -28,6 +26,6 @@ def predict(input):
|
|
| 28 |
confidences = {labels[i]: float(prediction[i]) for i in range(10)}
|
| 29 |
return confidences
|
| 30 |
|
| 31 |
-
gr.Interface(title='
|
| 32 |
inputs="sketchpad",
|
| 33 |
outputs=gr.Label(num_top_classes=3)).launch(share=False, debug=True)
|
|
|
|
| 2 |
import torch
|
| 3 |
from Model import LeNet
|
| 4 |
|
| 5 |
+
labels = ['Zero','One','Two','Three','Four','Five','Six','Seven','Eight', 'Nine']
|
| 6 |
|
| 7 |
# Locate device
|
| 8 |
if torch.cuda.is_available():
|
|
|
|
| 12 |
device = torch.device("cpu")
|
| 13 |
print("CPU")
|
| 14 |
|
|
|
|
|
|
|
| 15 |
# Loading model
|
| 16 |
model = LeNet().to(device)
|
| 17 |
model.load_state_dict(torch.load("model_mnist.pth", map_location=torch.device('cpu')))
|
|
|
|
| 26 |
confidences = {labels[i]: float(prediction[i]) for i in range(10)}
|
| 27 |
return confidences
|
| 28 |
|
| 29 |
+
gr.Interface(title='Digit classifier', fn=predict,
|
| 30 |
inputs="sketchpad",
|
| 31 |
outputs=gr.Label(num_top_classes=3)).launch(share=False, debug=True)
|