Spaces:
Sleeping
Sleeping
switch to base and large segm models
Browse files
app.py
CHANGED
|
@@ -9,7 +9,7 @@ import lightly_train
|
|
| 9 |
# --- CONFIGURATION ---
|
| 10 |
|
| 11 |
MARKDOWN_HEADER = """
|
| 12 |
-
# LightlyTrain
|
| 13 |
[GitHub](https://github.com/lightly-ai/lightly-train) | [Documentation](https://docs.lightly.ai/train)
|
| 14 |
|
| 15 |
This demo showcases **LightlyTrain**, a powerful library for self-supervised learning and fine-tuning.
|
|
@@ -23,14 +23,20 @@ DETECTION_MODELS = [
|
|
| 23 |
"dinov3/convnext-small-ltdetr-coco",
|
| 24 |
"dinov3/convnext-tiny-ltdetr-coco"
|
| 25 |
]
|
|
|
|
|
|
|
| 26 |
SEGMENTATION_MODELS = [
|
| 27 |
-
"dinov3/vits16-eomt-coco"
|
|
|
|
|
|
|
| 28 |
]
|
|
|
|
| 29 |
ALL_MODELS = DETECTION_MODELS + SEGMENTATION_MODELS
|
| 30 |
DEFAULT_MODEL = DETECTION_MODELS[0]
|
| 31 |
|
| 32 |
# 2. CLASS LISTS
|
| 33 |
-
|
|
|
|
| 34 |
COCO_DETECTION_CLASSES = [
|
| 35 |
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
|
| 36 |
"fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
|
|
@@ -43,8 +49,7 @@ COCO_DETECTION_CLASSES = [
|
|
| 43 |
"scissors", "teddy bear", "hair drier", "toothbrush"
|
| 44 |
]
|
| 45 |
|
| 46 |
-
# COCO-Stuff (171 Classes)
|
| 47 |
-
# Includes the 80 "things" above + 91 "stuff" classes (sky, road, etc.)
|
| 48 |
COCO_STUFF_CLASSES = [
|
| 49 |
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
|
| 50 |
"fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
|
|
@@ -106,6 +111,7 @@ def run_detection(model, image_input, original_image, confidence_threshold):
|
|
| 106 |
labels = results['labels'].cpu().numpy()
|
| 107 |
scores = results['scores'].cpu().numpy()
|
| 108 |
|
|
|
|
| 109 |
valid = scores > confidence_threshold
|
| 110 |
boxes = boxes[valid]
|
| 111 |
labels = labels[valid]
|
|
@@ -147,7 +153,6 @@ def run_segmentation(model, image_input, original_image):
|
|
| 147 |
mask_np = mask_tensor.cpu().numpy().astype(np.uint8)
|
| 148 |
mask_np = cv2.resize(mask_np, original_image.size, interpolation=cv2.INTER_NEAREST)
|
| 149 |
|
| 150 |
-
# Use COCO-Stuff classes
|
| 151 |
current_classes = COCO_STUFF_CLASSES
|
| 152 |
|
| 153 |
h, w = mask_np.shape
|
|
@@ -158,10 +163,10 @@ def run_segmentation(model, image_input, original_image):
|
|
| 158 |
labels_to_draw = []
|
| 159 |
|
| 160 |
for cls_id in unique_classes:
|
| 161 |
-
# Safety check: skip 'background' class
|
| 162 |
if cls_id == 255 or cls_id == -1: continue
|
| 163 |
|
| 164 |
-
#
|
| 165 |
if cls_id < 0 or cls_id >= len(current_classes): continue
|
| 166 |
|
| 167 |
class_name = current_classes[cls_id]
|
|
@@ -184,7 +189,7 @@ def run_segmentation(model, image_input, original_image):
|
|
| 184 |
cv2.putText(blended, text, (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 3, cv2.LINE_AA)
|
| 185 |
cv2.putText(blended, text, (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1, cv2.LINE_AA)
|
| 186 |
|
| 187 |
-
analytics_text = f"Scene Contains (
|
| 188 |
|
| 189 |
return Image.fromarray(blended), analytics_text, {"classes_found": list(found_classes)}
|
| 190 |
|
|
@@ -231,7 +236,7 @@ with gr.Blocks(theme=theme) as demo:
|
|
| 231 |
examples=[
|
| 232 |
["http://farm3.staticflickr.com/2547/3933456087_6a4dfb4736_z.jpg", 0.4, 640, DEFAULT_MODEL],
|
| 233 |
["https://farm3.staticflickr.com/2294/2193565429_aed7c9ff98_z.jpg", 0.4, 640, DEFAULT_MODEL],
|
| 234 |
-
["
|
| 235 |
],
|
| 236 |
outputs=[output_img, output_text, output_json],
|
| 237 |
fn=run_prediction,
|
|
|
|
| 9 |
# --- CONFIGURATION ---
|
| 10 |
|
| 11 |
MARKDOWN_HEADER = """
|
| 12 |
+
# LightlyTrain Detection & Segmentation Demo 🚀
|
| 13 |
[GitHub](https://github.com/lightly-ai/lightly-train) | [Documentation](https://docs.lightly.ai/train)
|
| 14 |
|
| 15 |
This demo showcases **LightlyTrain**, a powerful library for self-supervised learning and fine-tuning.
|
|
|
|
| 23 |
"dinov3/convnext-small-ltdetr-coco",
|
| 24 |
"dinov3/convnext-tiny-ltdetr-coco"
|
| 25 |
]
|
| 26 |
+
|
| 27 |
+
# UPDATED: Added Base (vitb16) and Large (vitl16) for better accuracy
|
| 28 |
SEGMENTATION_MODELS = [
|
| 29 |
+
"dinov3/vitb16-eomt-coco", # Base (Recommended Balance)
|
| 30 |
+
"dinov3/vitl16-eomt-coco", # Large (Best Accuracy, Slower)
|
| 31 |
+
"dinov3/vits16-eomt-coco" # Small (Fastest)
|
| 32 |
]
|
| 33 |
+
|
| 34 |
ALL_MODELS = DETECTION_MODELS + SEGMENTATION_MODELS
|
| 35 |
DEFAULT_MODEL = DETECTION_MODELS[0]
|
| 36 |
|
| 37 |
# 2. CLASS LISTS
|
| 38 |
+
|
| 39 |
+
# COCO Detection (80 Classes)
|
| 40 |
COCO_DETECTION_CLASSES = [
|
| 41 |
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
|
| 42 |
"fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
|
|
|
|
| 49 |
"scissors", "teddy bear", "hair drier", "toothbrush"
|
| 50 |
]
|
| 51 |
|
| 52 |
+
# COCO-Stuff (171 Classes) - Standard Mapping
|
|
|
|
| 53 |
COCO_STUFF_CLASSES = [
|
| 54 |
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
|
| 55 |
"fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
|
|
|
|
| 111 |
labels = results['labels'].cpu().numpy()
|
| 112 |
scores = results['scores'].cpu().numpy()
|
| 113 |
|
| 114 |
+
# Filter
|
| 115 |
valid = scores > confidence_threshold
|
| 116 |
boxes = boxes[valid]
|
| 117 |
labels = labels[valid]
|
|
|
|
| 153 |
mask_np = mask_tensor.cpu().numpy().astype(np.uint8)
|
| 154 |
mask_np = cv2.resize(mask_np, original_image.size, interpolation=cv2.INTER_NEAREST)
|
| 155 |
|
|
|
|
| 156 |
current_classes = COCO_STUFF_CLASSES
|
| 157 |
|
| 158 |
h, w = mask_np.shape
|
|
|
|
| 163 |
labels_to_draw = []
|
| 164 |
|
| 165 |
for cls_id in unique_classes:
|
| 166 |
+
# Safety check: skip 'background' class
|
| 167 |
if cls_id == 255 or cls_id == -1: continue
|
| 168 |
|
| 169 |
+
# COCO-Stuff mapping
|
| 170 |
if cls_id < 0 or cls_id >= len(current_classes): continue
|
| 171 |
|
| 172 |
class_name = current_classes[cls_id]
|
|
|
|
| 189 |
cv2.putText(blended, text, (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 3, cv2.LINE_AA)
|
| 190 |
cv2.putText(blended, text, (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1, cv2.LINE_AA)
|
| 191 |
|
| 192 |
+
analytics_text = f"Scene Contains (Segmentation):\n" + (", ".join(sorted(list(found_classes))) if found_classes else "None")
|
| 193 |
|
| 194 |
return Image.fromarray(blended), analytics_text, {"classes_found": list(found_classes)}
|
| 195 |
|
|
|
|
| 236 |
examples=[
|
| 237 |
["http://farm3.staticflickr.com/2547/3933456087_6a4dfb4736_z.jpg", 0.4, 640, DEFAULT_MODEL],
|
| 238 |
["https://farm3.staticflickr.com/2294/2193565429_aed7c9ff98_z.jpg", 0.4, 640, DEFAULT_MODEL],
|
| 239 |
+
["http://farm9.staticflickr.com/8092/8400332884_102a62b6c6_z.jpg", 0.6, 640, "dinov3/vits16-eomt-ade20k"],
|
| 240 |
],
|
| 241 |
outputs=[output_img, output_text, output_json],
|
| 242 |
fn=run_prediction,
|