StoneSeller commited on
Commit
970edfa
·
verified ·
1 Parent(s): d06517d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -46
app.py CHANGED
@@ -6,36 +6,33 @@ import os
6
  def install(package):
7
  subprocess.check_call([sys.executable, "-m", "pip", "install", "--force-reinstall", package])
8
 
9
- # Ensure NumPy and pandas are compatible
10
  try:
11
  import numpy as np
12
- import pandas as pd
13
- if not (np.__version__.startswith("1.24")):
14
- print(f"Detected incompatible versions. Reinstalling NumPy...")
15
  install("numpy==1.24.3")
16
  except ImportError:
17
- print("NumPy or pandas not found. Installing compatible versions...")
18
  install("numpy==1.24.3")
19
 
20
- # Ensure other dependencies are installed with specific versions
21
- try:
22
- import torch
23
- import torchvision
24
- except ImportError:
25
- install("torch==2.0.1")
26
- install("torchvision==0.15.2")
27
 
28
- try:
29
- from PIL import Image
30
- except ImportError:
31
- install("Pillow==9.5.0")
32
-
33
- try:
34
- import gradio as gr
35
- except ImportError:
36
- install("gradio==3.50.2")
37
 
38
- # Import libraries after ensuring installations
 
39
  import torch
40
  import torch.nn as nn
41
  import torch.nn.functional as F
@@ -52,7 +49,7 @@ class ModifiedLargeNet(nn.Module):
52
  self.pool = nn.MaxPool2d(2, 2)
53
  self.conv2 = nn.Conv2d(5, 10, 5)
54
  self.fc1 = nn.Linear(10 * 29 * 29, 32)
55
- self.fc2 = nn.Linear(32, 3) # classify into "Rope"/"Hammer"/"others"
56
 
57
  def forward(self, x):
58
  x = self.pool(F.relu(self.conv1(x)))
@@ -74,42 +71,50 @@ transform = transforms.Compose([
74
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
75
  ])
76
 
77
- # Prediction function
78
- def predict(image):
79
  if image is None:
80
- raise ValueError("Please provide an image")
81
-
82
- # Convert to PIL Image if necessary
83
- if not isinstance(image, Image.Image):
84
- try:
85
- image = Image.fromarray(image)
86
- except Exception as e:
87
- raise ValueError(f"Failed to convert input to PIL Image: {str(e)}")
88
 
89
- # Transform and predict
 
 
 
90
  try:
91
- image = transform(image).unsqueeze(0) # Add batch dimension
 
 
 
 
 
 
 
 
92
  with torch.no_grad():
93
- outputs = model(image)
94
- probabilities = torch.softmax(outputs, dim=1).numpy()[0]
95
- classes = ["Rope", "Hammer", "Other"]
96
- return {cls: float(prob) for cls, prob in zip(classes, probabilities)}
 
 
 
97
  except Exception as e:
98
- raise ValueError(f"Error during prediction: {str(e)}")
 
99
 
100
  # Gradio interface
101
  interface = gr.Interface(
102
  fn=predict,
103
- inputs=gr.Image(), # Remove type="pil" constraint
104
  outputs=gr.Label(num_top_classes=3),
105
  title="Mechanical Tools Classifier",
106
  description="Upload an image of a tool to classify it as 'Rope', 'Hammer', or 'Other'.",
107
- examples=[
108
- ["example_rope.jpg"],
109
- ["example_hammer.jpg"],
110
- ] if os.path.exists("example_rope.jpg") else None # Optional examples
111
  )
112
 
113
  # Launch the interface
114
  if __name__ == "__main__":
115
- interface.launch() # Removed share=True for Hugging Face Spaces
 
6
  def install(package):
7
  subprocess.check_call([sys.executable, "-m", "pip", "install", "--force-reinstall", package])
8
 
9
+ # First, ensure NumPy is installed with the correct version
10
  try:
11
  import numpy as np
12
+ if not np.__version__.startswith("1.24"):
13
+ print("Installing compatible NumPy version...")
 
14
  install("numpy==1.24.3")
15
  except ImportError:
16
+ print("NumPy not found. Installing...")
17
  install("numpy==1.24.3")
18
 
19
+ # Then install other dependencies
20
+ packages = {
21
+ "torch": "2.0.1",
22
+ "torchvision": "0.15.2",
23
+ "Pillow": "9.5.0",
24
+ "gradio": "3.50.2"
25
+ }
26
 
27
+ for package, version in packages.items():
28
+ try:
29
+ __import__(package.lower())
30
+ except ImportError:
31
+ print(f"Installing {package}...")
32
+ install(f"{package}=={version}")
 
 
 
33
 
34
+ # Import all required libraries
35
+ import numpy as np
36
  import torch
37
  import torch.nn as nn
38
  import torch.nn.functional as F
 
49
  self.pool = nn.MaxPool2d(2, 2)
50
  self.conv2 = nn.Conv2d(5, 10, 5)
51
  self.fc1 = nn.Linear(10 * 29 * 29, 32)
52
+ self.fc2 = nn.Linear(32, 3)
53
 
54
  def forward(self, x):
55
  x = self.pool(F.relu(self.conv1(x)))
 
71
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
72
  ])
73
 
74
+ def process_image(image):
 
75
  if image is None:
76
+ return None
77
+
78
+ # Convert to RGB if necessary
79
+ if image.mode != 'RGB':
80
+ image = image.convert('RGB')
81
+ return image
 
 
82
 
83
+ def predict(image):
84
+ if image is None:
85
+ return {cls: 0.0 for cls in ["Rope", "Hammer", "Other"]}
86
+
87
  try:
88
+ # Process the image
89
+ processed_image = process_image(Image.fromarray(image))
90
+ if processed_image is None:
91
+ return {cls: 0.0 for cls in ["Rope", "Hammer", "Other"]}
92
+
93
+ # Transform for model
94
+ tensor_image = transform(processed_image).unsqueeze(0)
95
+
96
+ # Make prediction
97
  with torch.no_grad():
98
+ outputs = model(tensor_image)
99
+ probabilities = F.softmax(outputs, dim=1)[0].cpu().numpy()
100
+
101
+ # Return results
102
+ classes = ["Rope", "Hammer", "Other"]
103
+ return {cls: float(prob) for cls, prob in zip(classes, probabilities)}
104
+
105
  except Exception as e:
106
+ print(f"Prediction error: {str(e)}")
107
+ return {cls: 0.0 for cls in ["Rope", "Hammer", "Other"]}
108
 
109
  # Gradio interface
110
  interface = gr.Interface(
111
  fn=predict,
112
+ inputs=gr.Image(type="numpy"),
113
  outputs=gr.Label(num_top_classes=3),
114
  title="Mechanical Tools Classifier",
115
  description="Upload an image of a tool to classify it as 'Rope', 'Hammer', or 'Other'.",
 
 
 
 
116
  )
117
 
118
  # Launch the interface
119
  if __name__ == "__main__":
120
+ interface.launch()