masakljun committed on
Commit
6efc023
·
1 Parent(s): a2bf58b

fix off mapping

Browse files
Files changed (1) hide show
  1. app.py +11 -9
app.py CHANGED
@@ -24,10 +24,9 @@ DETECTION_MODELS = [
24
  "dinov3/convnext-tiny-ltdetr-coco"
25
  ]
26
 
27
- # UPDATED: Added Base (vitb16) and Large (vitl16) for better accuracy
28
  SEGMENTATION_MODELS = [
29
- "dinov3/vitb16-eomt-coco", # Base (Recommended Balance)
30
- "dinov3/vitl16-eomt-coco", # Large (Best Accuracy, Slower)
31
  "dinov3/vits16-eomt-coco" # Small (Fastest)
32
  ]
33
 
@@ -49,8 +48,10 @@ COCO_DETECTION_CLASSES = [
49
  "scissors", "teddy bear", "hair drier", "toothbrush"
50
  ]
51
 
52
- # COCO-Stuff (171 Classes) - Standard Mapping
 
53
  COCO_STUFF_CLASSES = [
 
54
  "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
55
  "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
56
  "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
@@ -163,10 +164,9 @@ def run_segmentation(model, image_input, original_image):
163
  labels_to_draw = []
164
 
165
  for cls_id in unique_classes:
166
- # Safety check: skip 'background' class
167
- if cls_id == 255 or cls_id == -1: continue
168
-
169
- # COCO-Stuff mapping
170
  if cls_id < 0 or cls_id >= len(current_classes): continue
171
 
172
  class_name = current_classes[cls_id]
@@ -177,6 +177,7 @@ def run_segmentation(model, image_input, original_image):
177
  colored_mask[mask_np == cls_id] = color
178
 
179
  y_indices, x_indices = np.where(mask_np == cls_id)
 
180
  if len(y_indices) > 200:
181
  centroid_y = int(np.mean(y_indices))
182
  centroid_x = int(np.mean(x_indices))
@@ -208,6 +209,7 @@ with gr.Blocks(theme=theme) as demo:
208
 
209
  with gr.Accordion("Settings", open=True):
210
  conf_slider = gr.Slider(0.0, 1.0, value=0.4, step=0.05, label="Confidence (Detection Only)")
 
211
  res_slider = gr.Slider(384, 1024, value=640, step=32, label="Inference Resolution")
212
 
213
  model_selector = gr.Dropdown(
@@ -236,7 +238,7 @@ with gr.Blocks(theme=theme) as demo:
236
  examples=[
237
  ["http://farm3.staticflickr.com/2547/3933456087_6a4dfb4736_z.jpg", 0.4, 640, DEFAULT_MODEL],
238
  ["https://farm3.staticflickr.com/2294/2193565429_aed7c9ff98_z.jpg", 0.4, 640, DEFAULT_MODEL],
239
- ["http://farm9.staticflickr.com/8092/8400332884_102a62b6c6_z.jpg", 0.4, 512, "dinov3/vitl16-eomt-coco"],
240
  ],
241
  outputs=[output_img, output_text, output_json],
242
  fn=run_prediction,
 
24
  "dinov3/convnext-tiny-ltdetr-coco"
25
  ]
26
 
 
27
  SEGMENTATION_MODELS = [
28
+ "dinov3/vitb16-eomt-coco", # Base (Balanced)
29
+ "dinov3/vitl16-eomt-coco", # Large (Best Accuracy)
30
  "dinov3/vits16-eomt-coco" # Small (Fastest)
31
  ]
32
 
 
48
  "scissors", "teddy bear", "hair drier", "toothbrush"
49
  ]
50
 
51
+ # COCO-Stuff (171 Classes)
52
+ # FIX: Added 'unlabeled' at index 0 so 'person' aligns with index 1
53
  COCO_STUFF_CLASSES = [
54
+ "unlabeled", # Index 0 (Background)
55
  "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
56
  "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
57
  "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
 
164
  labels_to_draw = []
165
 
166
  for cls_id in unique_classes:
167
+ # 0 is usually background/unlabeled in this list
168
+ if cls_id == 0: continue
169
+ # Safety check for range
 
170
  if cls_id < 0 or cls_id >= len(current_classes): continue
171
 
172
  class_name = current_classes[cls_id]
 
177
  colored_mask[mask_np == cls_id] = color
178
 
179
  y_indices, x_indices = np.where(mask_np == cls_id)
180
+ # Filter small noise
181
  if len(y_indices) > 200:
182
  centroid_y = int(np.mean(y_indices))
183
  centroid_x = int(np.mean(x_indices))
 
209
 
210
  with gr.Accordion("Settings", open=True):
211
  conf_slider = gr.Slider(0.0, 1.0, value=0.4, step=0.05, label="Confidence (Detection Only)")
212
+ # BUMPED DEFAULT TO 640 for sharper masks
213
  res_slider = gr.Slider(384, 1024, value=640, step=32, label="Inference Resolution")
214
 
215
  model_selector = gr.Dropdown(
 
238
  examples=[
239
  ["http://farm3.staticflickr.com/2547/3933456087_6a4dfb4736_z.jpg", 0.4, 640, DEFAULT_MODEL],
240
  ["https://farm3.staticflickr.com/2294/2193565429_aed7c9ff98_z.jpg", 0.4, 640, DEFAULT_MODEL],
241
+ ["http://farm9.staticflickr.com/8092/8400332884_102a62b6c6_z.jpg", 0.4, 640, "dinov3/vitl16-eomt-coco"],
242
  ],
243
  outputs=[output_img, output_text, output_json],
244
  fn=run_prediction,