masakljun commited on
Commit
7e8e363
·
1 Parent(s): 9f1fe9a

switch to base and large segm models

Browse files
Files changed (1) hide show
  1. app.py +15 -10
app.py CHANGED
@@ -9,7 +9,7 @@ import lightly_train
9
  # --- CONFIGURATION ---
10
 
11
  MARKDOWN_HEADER = """
12
- # LightlyTrain Object Detection Demo 🚀
13
  [GitHub](https://github.com/lightly-ai/lightly-train) | [Documentation](https://docs.lightly.ai/train)
14
 
15
  This demo showcases **LightlyTrain**, a powerful library for self-supervised learning and fine-tuning.
@@ -23,14 +23,20 @@ DETECTION_MODELS = [
23
  "dinov3/convnext-small-ltdetr-coco",
24
  "dinov3/convnext-tiny-ltdetr-coco"
25
  ]
 
 
26
  SEGMENTATION_MODELS = [
27
- "dinov3/vits16-eomt-coco" # COCO-Stuff (171 Classes)
 
 
28
  ]
 
29
  ALL_MODELS = DETECTION_MODELS + SEGMENTATION_MODELS
30
  DEFAULT_MODEL = DETECTION_MODELS[0]
31
 
32
  # 2. CLASS LISTS
33
- # Standard COCO "Things" (80 classes) for Detection
 
34
  COCO_DETECTION_CLASSES = [
35
  "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
36
  "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
@@ -43,8 +49,7 @@ COCO_DETECTION_CLASSES = [
43
  "scissors", "teddy bear", "hair drier", "toothbrush"
44
  ]
45
 
46
- # COCO-Stuff (171 classes) for Segmentation
47
- # Includes the 80 "things" above + 91 "stuff" classes (sky, road, etc.)
48
  COCO_STUFF_CLASSES = [
49
  "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
50
  "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
@@ -106,6 +111,7 @@ def run_detection(model, image_input, original_image, confidence_threshold):
106
  labels = results['labels'].cpu().numpy()
107
  scores = results['scores'].cpu().numpy()
108
 
 
109
  valid = scores > confidence_threshold
110
  boxes = boxes[valid]
111
  labels = labels[valid]
@@ -147,7 +153,6 @@ def run_segmentation(model, image_input, original_image):
147
  mask_np = mask_tensor.cpu().numpy().astype(np.uint8)
148
  mask_np = cv2.resize(mask_np, original_image.size, interpolation=cv2.INTER_NEAREST)
149
 
150
- # Use COCO-Stuff classes
151
  current_classes = COCO_STUFF_CLASSES
152
 
153
  h, w = mask_np.shape
@@ -158,10 +163,10 @@ def run_segmentation(model, image_input, original_image):
158
  labels_to_draw = []
159
 
160
  for cls_id in unique_classes:
161
- # Safety check: skip 'background' class (often 255 or -1)
162
  if cls_id == 255 or cls_id == -1: continue
163
 
164
- # Standard COCO-Stuff mapping: 0-170
165
  if cls_id < 0 or cls_id >= len(current_classes): continue
166
 
167
  class_name = current_classes[cls_id]
@@ -184,7 +189,7 @@ def run_segmentation(model, image_input, original_image):
184
  cv2.putText(blended, text, (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 3, cv2.LINE_AA)
185
  cv2.putText(blended, text, (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1, cv2.LINE_AA)
186
 
187
- analytics_text = f"Scene Contains (COCO):\n" + (", ".join(sorted(list(found_classes))) if found_classes else "None")
188
 
189
  return Image.fromarray(blended), analytics_text, {"classes_found": list(found_classes)}
190
 
@@ -231,7 +236,7 @@ with gr.Blocks(theme=theme) as demo:
231
  examples=[
232
  ["http://farm3.staticflickr.com/2547/3933456087_6a4dfb4736_z.jpg", 0.4, 640, DEFAULT_MODEL],
233
  ["https://farm3.staticflickr.com/2294/2193565429_aed7c9ff98_z.jpg", 0.4, 640, DEFAULT_MODEL],
234
- ["https://farm3.staticflickr.com/2294/2193565429_aed7c9ff98_z.jpg", 0.4, 512, "dinov3/vits16-eomt-coco"],
235
  ],
236
  outputs=[output_img, output_text, output_json],
237
  fn=run_prediction,
 
9
  # --- CONFIGURATION ---
10
 
11
  MARKDOWN_HEADER = """
12
+ # LightlyTrain Detection & Segmentation Demo 🚀
13
  [GitHub](https://github.com/lightly-ai/lightly-train) | [Documentation](https://docs.lightly.ai/train)
14
 
15
  This demo showcases **LightlyTrain**, a powerful library for self-supervised learning and fine-tuning.
 
23
  "dinov3/convnext-small-ltdetr-coco",
24
  "dinov3/convnext-tiny-ltdetr-coco"
25
  ]
26
+
27
+ # UPDATED: Added Base (vitb16) and Large (vitl16) for better accuracy
28
  SEGMENTATION_MODELS = [
29
+ "dinov3/vitb16-eomt-coco", # Base (Recommended Balance)
30
+ "dinov3/vitl16-eomt-coco", # Large (Best Accuracy, Slower)
31
+ "dinov3/vits16-eomt-coco" # Small (Fastest)
32
  ]
33
+
34
  ALL_MODELS = DETECTION_MODELS + SEGMENTATION_MODELS
35
  DEFAULT_MODEL = DETECTION_MODELS[0]
36
 
37
  # 2. CLASS LISTS
38
+
39
+ # COCO Detection (80 Classes)
40
  COCO_DETECTION_CLASSES = [
41
  "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
42
  "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
 
49
  "scissors", "teddy bear", "hair drier", "toothbrush"
50
  ]
51
 
52
+ # COCO-Stuff (171 Classes) - Standard Mapping
 
53
  COCO_STUFF_CLASSES = [
54
  "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
55
  "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
 
111
  labels = results['labels'].cpu().numpy()
112
  scores = results['scores'].cpu().numpy()
113
 
114
+ # Filter
115
  valid = scores > confidence_threshold
116
  boxes = boxes[valid]
117
  labels = labels[valid]
 
153
  mask_np = mask_tensor.cpu().numpy().astype(np.uint8)
154
  mask_np = cv2.resize(mask_np, original_image.size, interpolation=cv2.INTER_NEAREST)
155
 
 
156
  current_classes = COCO_STUFF_CLASSES
157
 
158
  h, w = mask_np.shape
 
163
  labels_to_draw = []
164
 
165
  for cls_id in unique_classes:
166
+ # Safety check: skip 'background' class
167
  if cls_id == 255 or cls_id == -1: continue
168
 
169
+ # COCO-Stuff mapping
170
  if cls_id < 0 or cls_id >= len(current_classes): continue
171
 
172
  class_name = current_classes[cls_id]
 
189
  cv2.putText(blended, text, (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 3, cv2.LINE_AA)
190
  cv2.putText(blended, text, (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1, cv2.LINE_AA)
191
 
192
+ analytics_text = f"Scene Contains (Segmentation):\n" + (", ".join(sorted(list(found_classes))) if found_classes else "None")
193
 
194
  return Image.fromarray(blended), analytics_text, {"classes_found": list(found_classes)}
195
 
 
236
  examples=[
237
  ["http://farm3.staticflickr.com/2547/3933456087_6a4dfb4736_z.jpg", 0.4, 640, DEFAULT_MODEL],
238
  ["https://farm3.staticflickr.com/2294/2193565429_aed7c9ff98_z.jpg", 0.4, 640, DEFAULT_MODEL],
239
+ ["http://farm9.staticflickr.com/8092/8400332884_102a62b6c6_z.jpg", 0.6, 640, "dinov3/vits16-eomt-ade20k"],
240
  ],
241
  outputs=[output_img, output_text, output_json],
242
  fn=run_prediction,