Gaimundo commited on
Commit
4cc415c
·
verified ·
1 Parent(s): 02cd388

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -54
app.py DELETED
@@ -1,54 +0,0 @@
1
-
2
- import gradio as gr
3
- import torch
4
- import torch.nn as nn
5
- import torch.nn.functional as F
6
- from torchvision import transforms
7
- from PIL import Image
8
-
9
- class MNISTNet(nn.Module):
10
- def __init__(self):
11
- super().__init__()
12
- self.conv1 = nn.Conv2d(1, 32, 3, 1)
13
- self.conv2 = nn.Conv2d(32, 64, 3, 1)
14
- self.dropout1 = nn.Dropout(0.25)
15
- self.dropout2 = nn.Dropout(0.5)
16
- self.fc1 = nn.Linear(9216, 128)
17
- self.fc2 = nn.Linear(128, 10)
18
- def forward(self, x):
19
- x = F.relu(self.conv1(x))
20
- x = F.relu(self.conv2(x))
21
- x = F.max_pool2d(x, 2)
22
- x = self.dropout1(x)
23
- x = torch.flatten(x, 1)
24
- x = F.relu(self.fc1(x))
25
- x = self.dropout2(x)
26
- x = self.fc2(x)
27
- return x
28
-
29
- model = MNISTNet()
30
- model.load_state_dict(torch.hub.load_state_dict_from_url('https://huggingface.co/username/mnist-cnn-pytorch/resolve/main/mnist_cnn.pt', map_location='cpu'))
31
- model.eval()
32
-
33
- transform = transforms.Compose([
34
- transforms.Grayscale(),
35
- transforms.Resize((28,28)),
36
- transforms.ToTensor(),
37
- transforms.Normalize((0.1307,), (0.3081,))
38
- ])
39
-
40
- def predict(image):
41
- image = transform(image).unsqueeze(0)
42
- with torch.no_grad():
43
- output = model(image)
44
- probs = torch.softmax(output, dim=1)[0]
45
- return {str(i): float(probs[i]) for i in range(10)}
46
-
47
- iface = gr.Interface(
48
- fn=predict,
49
- inputs=gr.Image(type="pil", shape=(28,28), image_mode="L", label="Draw a digit"),
50
- outputs=gr.Label(num_top_classes=10),
51
- title="MNIST Digit Classifier (PyTorch)"
52
- )
53
-
54
- iface.launch()