Adityak204 commited on
Commit
ea55197
·
1 Parent(s): e927cf1

modifed app

Browse files
Files changed (1) hide show
  1. app.py +52 -26
app.py CHANGED
@@ -72,49 +72,75 @@ class ModelPredictor:
72
  """
73
  Make prediction for a single image
74
  Args:
75
- image: PIL Image or path to image
76
  Returns:
77
  Dictionary of class labels and probabilities
78
  """
79
- if isinstance(image, str):
80
- image = Image.open(image).convert("RGB")
81
- else:
82
- image = Image.fromarray(image).convert("RGB")
 
 
 
 
83
 
84
- image_tensor = self.transform(image).unsqueeze(0)
85
- image_tensor = image_tensor.to(self.device)
 
86
 
87
- with torch.no_grad():
88
- outputs = self.model(image_tensor)
89
- probabilities = torch.nn.functional.softmax(outputs, dim=1)
90
 
91
- # Get top 5 predictions
92
- top_probs, top_indices = torch.topk(probabilities, 5)
 
93
 
94
- # Create results dictionary
95
- results = {}
96
- for prob, idx in zip(top_probs[0], top_indices[0]):
97
- class_name = self.class_labels[str(idx.item())]
98
- results[class_name] = float(prob)
99
 
100
- return results
 
 
 
 
 
 
 
 
 
 
101
 
102
 
103
  # Initialize the predictor
104
- predictor = ModelPredictor(
105
- model_repo="Adityak204/ResNetVision-1K", # Replace with your repo
106
- model_filename="resnet50-epoch36-acc60.3506.ckpt", # Replace with your model filename
107
- )
 
 
 
108
 
109
 
110
  def predict_image(image):
111
  """
112
  Gradio interface function
 
 
 
 
113
  """
114
- predictions = predictor.predict(image)
115
-
116
- # Format results for display
117
- return {k: f"{v:.2%}" for k, v in predictions.items()}
 
 
 
 
 
 
118
 
119
 
120
  # Create Gradio interface
 
72
  """
73
  Make prediction for a single image
74
  Args:
75
+ image: numpy array from Gradio
76
  Returns:
77
  Dictionary of class labels and probabilities
78
  """
79
+ try:
80
+ # Convert numpy array to PIL Image
81
+ if isinstance(image, np.ndarray):
82
+ # If image is from Gradio, it will be a numpy array
83
+ image = Image.fromarray(image.astype("uint8"))
84
+ elif isinstance(image, str):
85
+ # If image is a file path
86
+ image = Image.open(image)
87
 
88
+ # Ensure image is in RGB mode
89
+ if image.mode != "RGB":
90
+ image = image.convert("RGB")
91
 
92
+ # Apply transforms and predict
93
+ image_tensor = self.transform(image).unsqueeze(0)
94
+ image_tensor = image_tensor.to(self.device)
95
 
96
+ with torch.no_grad():
97
+ outputs = self.model(image_tensor)
98
+ probabilities = torch.nn.functional.softmax(outputs, dim=1)
99
 
100
+ # Get top 5 predictions
101
+ top_probs, top_indices = torch.topk(probabilities, 5)
 
 
 
102
 
103
+ # Create results dictionary
104
+ results = {}
105
+ for prob, idx in zip(top_probs[0], top_indices[0]):
106
+ class_name = self.class_labels[str(idx.item())]
107
+ results[class_name] = float(prob)
108
+
109
+ return results
110
+
111
+ except Exception as e:
112
+ print(f"Error in prediction: {str(e)}")
113
+ return {"error": 1.0}
114
 
115
 
116
  # Initialize the predictor
117
+ try:
118
+ predictor = ModelPredictor(
119
+ model_repo="Adityak204/ResNetVision-1K", # Replace with your repo
120
+ model_filename="resnet50-epoch36-acc60.3506.ckpt", # Replace with your model filename
121
+ )
122
+ except Exception as e:
123
+ print(f"Error initializing predictor: {str(e)}")
124
 
125
 
126
  def predict_image(image):
127
  """
128
  Gradio interface function
129
+ Args:
130
+ image: numpy array from Gradio's image input
131
+ Returns:
132
+ Dictionary of predictions formatted for display
133
  """
134
+ if image is None:
135
+ return {"Error: No image provided": 1.0}
136
+
137
+ try:
138
+ predictions = predictor.predict(image)
139
+ # Format results for display
140
+ return {k: f"{v:.2%}" for k, v in predictions.items()}
141
+ except Exception as e:
142
+ print(f"Error in predict_image: {str(e)}")
143
+ return {"Error: Failed to process image": 1.0}
144
 
145
 
146
  # Create Gradio interface