Rahuletto commited on
Commit
0e4e702
·
1 Parent(s): 57ff00c

feat: gradio test

Browse files
Files changed (10) hide show
  1. .gitignore +4 -1
  2. examples/1.png +3 -0
  3. examples/2.png +3 -0
  4. examples/3.png +3 -0
  5. examples/4.png +3 -0
  6. examples/5.png +3 -0
  7. examples/6.png +3 -0
  8. examples/7.png +3 -0
  9. main.py +76 -0
  10. requirements.txt +1 -0
.gitignore CHANGED
@@ -8,5 +8,8 @@ wheels/
8
 
9
  # Virtual environments
10
  .venv
 
11
  .DS_Store
12
- cifar/
 
 
 
8
 
9
  # Virtual environments
10
  .venv
11
+
12
  .DS_Store
13
+ cifar/
14
+
15
+ .gradio/
examples/1.png ADDED

Git LFS Details

  • SHA256: 43e89de4d62730a0b78d995b959ea1808af08e37ae70ab33bbafaf39623bf359
  • Pointer size: 131 Bytes
  • Size of remote file: 753 kB
examples/2.png ADDED

Git LFS Details

  • SHA256: 40da165ec8aad12aafef1851b2f6a13b5acbc1a35be4659293c2939e66cef679
  • Pointer size: 131 Bytes
  • Size of remote file: 372 kB
examples/3.png ADDED

Git LFS Details

  • SHA256: 89970da69441903f0b6b139c96daabc904f0955cb20b90a8262af1480abca57c
  • Pointer size: 130 Bytes
  • Size of remote file: 72.2 kB
examples/4.png ADDED

Git LFS Details

  • SHA256: 33a3ed9cfd749c3345a70b5c3dc41a339e1e7b368baea3683502347a98b07be0
  • Pointer size: 131 Bytes
  • Size of remote file: 341 kB
examples/5.png ADDED

Git LFS Details

  • SHA256: 991754fdb6188526802e3b24e8bf6fe880c45443d9cf579b3b8ac156e3ccd430
  • Pointer size: 131 Bytes
  • Size of remote file: 276 kB
examples/6.png ADDED

Git LFS Details

  • SHA256: 6ebdac0fb95095cef1e42736501efac3613894400b6ff7a40e28c1d684ec93a2
  • Pointer size: 131 Bytes
  • Size of remote file: 224 kB
examples/7.png ADDED

Git LFS Details

  • SHA256: 3e218338198762035f64258e96cd9670975e6359b7365b3f0c74fefa6d92ebe2
  • Pointer size: 131 Bytes
  • Size of remote file: 280 kB
main.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import gradio as gr
4
+ from PIL import Image
5
+ from torchvision import transforms
6
+ from cnn import CNN
7
+
8
+ device = torch.device(
9
+ "cuda"
10
+ if torch.cuda.is_available()
11
+ else "mps"
12
+ if torch.backends.mps.is_available()
13
+ else "cpu"
14
+ )
15
+
16
+ classes = [
17
+ "airplane",
18
+ "automobile",
19
+ "bird",
20
+ "cat",
21
+ "deer",
22
+ "dog",
23
+ "frog",
24
+ "horse",
25
+ "ship",
26
+ "truck",
27
+ ]
28
+
29
+ model = CNN()
30
+ model.load_state_dict(torch.load("cnn/model.pt", map_location=device))
31
+ model.to(device)
32
+ model.eval()
33
+
34
+ transform = transforms.Compose(
35
+ [
36
+ transforms.Resize((32, 32)),
37
+ transforms.ToTensor(),
38
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
39
+ ]
40
+ )
41
+
42
+
43
+ def predict(image):
44
+ if image is None:
45
+ return {}
46
+
47
+ image = Image.fromarray(image).convert("RGB")
48
+ image_tensor = transform(image)
49
+ image_tensor = (image_tensor).unsqueeze(0).to(device)
50
+
51
+ with torch.no_grad():
52
+ outputs = model(image_tensor)
53
+ probabilities = F.softmax(outputs, dim=1)[0]
54
+
55
+ return {classes[i]: float(probabilities[i]) for i in range(len(classes))}
56
+
57
+
58
+ demo = gr.Interface(
59
+ fn=predict,
60
+ inputs=gr.Image(type="numpy"),
61
+ outputs=gr.Label(num_top_classes=10),
62
+ title="CNN Classifier",
63
+ description="Upload an image to classify it into one of 10 CIFAR-10 categories: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck",
64
+ examples=[
65
+ ["examples/1.png"],
66
+ ["examples/2.png"],
67
+ ["examples/3.png"],
68
+ ["examples/4.png"],
69
+ ["examples/5.png"],
70
+ ["examples/6.png"],
71
+ ["examples/7.png"],
72
+ ],
73
+ )
74
+
75
+ if __name__ == "__main__":
76
+ demo.launch(share=True, pwa=True)
requirements.txt CHANGED
@@ -49,3 +49,4 @@ tornado==6.5.1
49
  traitlets==5.14.3
50
  typing-extensions==4.14.0
51
  wcwidth==0.2.13
 
 
49
  traitlets==5.14.3
50
  typing-extensions==4.14.0
51
  wcwidth==0.2.13
52
+ gradio