masakljun commited on
Commit
69da21c
·
1 Parent(s): 56ce19c

try to fix segm

Browse files
Files changed (1) hide show
  1. app.py +4 -9
app.py CHANGED
@@ -49,7 +49,6 @@ COCO_DETECTION_CLASSES = [
49
  ]
50
 
51
  # COCO-Stuff (171 Classes)
52
- # FIX: Added 'unlabeled' at index 0 so 'person' aligns with Index 1
53
  COCO_STUFF_CLASSES = [
54
  "unlabeled", # Index 0 (Background)
55
  "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
@@ -101,7 +100,7 @@ def run_prediction(image, confidence_threshold, resolution, model_name):
101
  image_input = image.resize((resolution, resolution))
102
 
103
  if model_name in SEGMENTATION_MODELS:
104
- return run_segmentation(model, image_input, image)
105
  else:
106
  return run_detection(model, image_input, image, confidence_threshold)
107
 
@@ -149,7 +148,7 @@ def run_detection(model, image_input, original_image, confidence_threshold):
149
 
150
  return annotated, analytics_text, {"count": len(boxes), "objects": class_counts}
151
 
152
-
153
  '''
154
  def run_segmentation(model, image_input, original_image):
155
  mask_tensor = model.predict(image_input)
@@ -198,17 +197,15 @@ def run_segmentation(model, image_input, original_image):
198
 
199
  return Image.fromarray(blended), analytics_text, {"classes_found": list(found_classes)}
200
  '''
 
 
201
  def run_segmentation(model, image):
202
  """
203
  Handles Segmentation: Returns Tensor of shape (H, W) with class IDs.
204
  """
205
- # 1. Run Prediction
206
  mask_tensor = model.predict(image)
207
-
208
- # 2. Convert to Numpy
209
  mask_np = mask_tensor.cpu().numpy().astype(np.uint8)
210
 
211
- # 3. Create a Colored Mask
212
  h, w = mask_np.shape
213
  colored_mask = np.zeros((h, w, 3), dtype=np.uint8)
214
 
@@ -221,7 +218,6 @@ def run_segmentation(model, image):
221
  color = np.random.randint(50, 255, size=3)
222
  colored_mask[mask_np == cls_id] = color
223
 
224
- # 4. Blend with Original Image
225
  image_np = np.array(image)
226
  if image_np.shape[:2] != colored_mask.shape[:2]:
227
  colored_mask = cv2.resize(colored_mask, (image_np.shape[1], image_np.shape[0]), interpolation=cv2.INTER_NEAREST)
@@ -245,7 +241,6 @@ with gr.Blocks(theme=theme) as demo:
245
 
246
  with gr.Accordion("Settings", open=True):
247
  conf_slider = gr.Slider(0.0, 1.0, value=0.4, step=0.05, label="Confidence (Detection Only)")
248
- # BUMPED DEFAULT TO 640 for sharper masks
249
  res_slider = gr.Slider(384, 1024, value=640, step=32, label="Inference Resolution")
250
 
251
  model_selector = gr.Dropdown(
 
49
  ]
50
 
51
  # COCO-Stuff (171 Classes)
 
52
  COCO_STUFF_CLASSES = [
53
  "unlabeled", # Index 0 (Background)
54
  "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
 
100
  image_input = image.resize((resolution, resolution))
101
 
102
  if model_name in SEGMENTATION_MODELS:
103
+ return run_segmentation(model, image_input)
104
  else:
105
  return run_detection(model, image_input, image, confidence_threshold)
106
 
 
148
 
149
  return annotated, analytics_text, {"count": len(boxes), "objects": class_counts}
150
 
151
+ # Segm code opt 1
152
  '''
153
  def run_segmentation(model, image_input, original_image):
154
  mask_tensor = model.predict(image_input)
 
197
 
198
  return Image.fromarray(blended), analytics_text, {"classes_found": list(found_classes)}
199
  '''
200
+
201
+ # Segm code opt 2
202
  def run_segmentation(model, image):
203
  """
204
  Handles Segmentation: Returns Tensor of shape (H, W) with class IDs.
205
  """
 
206
  mask_tensor = model.predict(image)
 
 
207
  mask_np = mask_tensor.cpu().numpy().astype(np.uint8)
208
 
 
209
  h, w = mask_np.shape
210
  colored_mask = np.zeros((h, w, 3), dtype=np.uint8)
211
 
 
218
  color = np.random.randint(50, 255, size=3)
219
  colored_mask[mask_np == cls_id] = color
220
 
 
221
  image_np = np.array(image)
222
  if image_np.shape[:2] != colored_mask.shape[:2]:
223
  colored_mask = cv2.resize(colored_mask, (image_np.shape[1], image_np.shape[0]), interpolation=cv2.INTER_NEAREST)
 
241
 
242
  with gr.Accordion("Settings", open=True):
243
  conf_slider = gr.Slider(0.0, 1.0, value=0.4, step=0.05, label="Confidence (Detection Only)")
 
244
  res_slider = gr.Slider(384, 1024, value=640, step=32, label="Inference Resolution")
245
 
246
  model_selector = gr.Dropdown(