StoneSeller commited on
Commit
4348a2e
·
verified ·
1 Parent(s): f55c28e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -16
app.py CHANGED
@@ -1,29 +1,40 @@
1
- import os
2
  import subprocess
3
  import sys
4
 
 
 
 
 
 
5
  try:
6
  import torch
7
  except ImportError:
8
- subprocess.check_call([sys.executable, "-m", "pip", "install",
9
- "torch==2.0.1+cpu",
10
- "torchvision==0.15.2+cpu",
11
- "-f", "https://download.pytorch.org/whl/torch_stable.html"])
12
 
 
13
  try:
14
  import numpy as np
15
  except ImportError:
16
- subprocess.check_call([sys.executable, "-m", "pip", "install", "numpy<2"])
 
 
 
 
 
 
17
 
 
18
  import torch
19
- import numpy
20
  import torch.nn as nn
 
21
  import torchvision.transforms as transforms
22
  from PIL import Image
23
  import gradio as gr
24
 
25
 
26
-
27
  class ModifiedLargeNet(nn.Module):
28
  def __init__(self):
29
  super(ModifiedLargeNet, self).__init__()
@@ -32,7 +43,7 @@ class ModifiedLargeNet(nn.Module):
32
  self.pool = nn.MaxPool2d(2, 2)
33
  self.conv2 = nn.Conv2d(5, 10, 5)
34
  self.fc1 = nn.Linear(10 * 29 * 29, 32)
35
- self.fc2 = nn.Linear(32, 3) # classify into "Rope"/"Hammer"/"others"
36
 
37
  def forward(self, x):
38
  x = self.pool(F.relu(self.conv1(x)))
@@ -40,38 +51,44 @@ class ModifiedLargeNet(nn.Module):
40
  x = x.view(-1, 10 * 29 * 29)
41
  x = F.relu(self.fc1(x))
42
  x = self.fc2(x)
43
- x = x.squeeze(1) # Flatten to [batch_size]
44
  return x
45
 
46
 
 
47
  model = ModifiedLargeNet()
48
  model.load_state_dict(torch.load("modified_large_net.pt", map_location=torch.device("cpu")))
49
  model.eval()
50
 
51
-
52
  transform = transforms.Compose([
53
  transforms.Resize((128, 128)),
54
  transforms.ToTensor(),
55
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
56
  ])
57
 
58
-
59
  def predict(image):
 
 
 
60
 
61
- image = transform(image).unsqueeze(0)
 
62
  with torch.no_grad():
63
  outputs = model(image)
64
  probabilities = torch.softmax(outputs, dim=1).numpy()[0]
65
  classes = ["Rope", "Hammer", "Other"]
66
  return {cls: float(prob) for cls, prob in zip(classes, probabilities)}
67
 
 
68
  interface = gr.Interface(
69
  fn=predict,
70
- inputs=gr.Image(type="pil"),
71
- outputs=gr.Label(num_top_classes=3),
72
  title="Mechanical Tools Classifier",
73
  description="Upload an image of a tool to classify it as 'Rope', 'Hammer', or 'Other'.",
74
  )
75
 
 
76
  if __name__ == "__main__":
77
- interface.launch()
 
 
1
  import subprocess
2
  import sys
3
 
4
+ # Ensure required libraries are installed
5
+ def install(package):
6
+ subprocess.check_call([sys.executable, "-m", "pip", "install", package])
7
+
8
+ # Install torch and torchvision
9
  try:
10
  import torch
11
  except ImportError:
12
+ install("torch==2.0.1+cpu")
13
+ install("torchvision==0.15.2+cpu")
14
+ install("-f https://download.pytorch.org/whl/torch_stable.html")
 
15
 
16
+ # Install numpy
17
  try:
18
  import numpy as np
19
  except ImportError:
20
+ install("numpy<2")
21
+
22
+ # Install Pillow
23
+ try:
24
+ from PIL import Image
25
+ except ImportError:
26
+ install("Pillow==9.5.0")
27
 
28
+ # Imports
29
  import torch
 
30
  import torch.nn as nn
31
+ import torch.nn.functional as F
32
  import torchvision.transforms as transforms
33
  from PIL import Image
34
  import gradio as gr
35
 
36
 
37
+ # Define the model
38
  class ModifiedLargeNet(nn.Module):
39
  def __init__(self):
40
  super(ModifiedLargeNet, self).__init__()
 
43
  self.pool = nn.MaxPool2d(2, 2)
44
  self.conv2 = nn.Conv2d(5, 10, 5)
45
  self.fc1 = nn.Linear(10 * 29 * 29, 32)
46
+ self.fc2 = nn.Linear(32, 3) # classify into "Rope"/"Hammer"/"others"
47
 
48
  def forward(self, x):
49
  x = self.pool(F.relu(self.conv1(x)))
 
51
  x = x.view(-1, 10 * 29 * 29)
52
  x = F.relu(self.fc1(x))
53
  x = self.fc2(x)
 
54
  return x
55
 
56
 
57
+ # Load the trained model
58
  model = ModifiedLargeNet()
59
  model.load_state_dict(torch.load("modified_large_net.pt", map_location=torch.device("cpu")))
60
  model.eval()
61
 
62
+ # Define image transformation pipeline
63
  transform = transforms.Compose([
64
  transforms.Resize((128, 128)),
65
  transforms.ToTensor(),
66
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
67
  ])
68
 
69
+ # Prediction function
70
  def predict(image):
71
+ # Verify input image is a PIL image
72
+ if not isinstance(image, Image.Image):
73
+ raise ValueError("Input must be a PIL Image.")
74
 
75
+ # Transform and predict
76
+ image = transform(image).unsqueeze(0) # Add batch dimension
77
  with torch.no_grad():
78
  outputs = model(image)
79
  probabilities = torch.softmax(outputs, dim=1).numpy()[0]
80
  classes = ["Rope", "Hammer", "Other"]
81
  return {cls: float(prob) for cls, prob in zip(classes, probabilities)}
82
 
83
+ # Gradio interface
84
  interface = gr.Interface(
85
  fn=predict,
86
+ inputs=gr.Image(type="pil"), # Ensure input is a PIL image
87
+ outputs=gr.Label(num_top_classes=3), # Display top 3 class probabilities
88
  title="Mechanical Tools Classifier",
89
  description="Upload an image of a tool to classify it as 'Rope', 'Hammer', or 'Other'.",
90
  )
91
 
92
+ # Launch the interface
93
  if __name__ == "__main__":
94
+ interface.launch(share=True) # Add `share=True` for a public link