File size: 11,380 Bytes
1cdf7ce
 
 
0dd123c
018212f
1cdf7ce
e90baee
1cdf7ce
 
e90baee
c480089
7e8e363
cf53d83
c480089
 
9b65193
c480089
 
a6dcdbd
0dd123c
c7a13bc
 
 
 
0dd123c
7e8e363
0dd123c
a6dcdbd
 
 
0dd123c
7e8e363
60ccc8b
0dd123c
 
60ccc8b
7e8e363
 
9b65193
e90baee
 
 
 
 
 
 
 
 
 
 
6efc023
9b65193
6efc023
9b65193
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c480089
 
1cdf7ce
 
e90baee
 
 
 
 
0dd123c
018212f
 
 
 
 
 
 
e90baee
 
 
c480089
 
 
 
0dd123c
e90baee
c480089
 
 
e90baee
c7a13bc
4a3ef8c
0dd123c
c480089
0dd123c
c480089
 
1cdf7ce
 
 
 
 
7e8e363
c480089
 
 
 
1cdf7ce
0dd123c
c480089
 
 
 
 
 
 
 
 
 
1cdf7ce
 
c480089
 
 
1cdf7ce
c480089
9b65193
c480089
 
1cdf7ce
c480089
 
 
0dd123c
c480089
cf53d83
0dd123c
cf53d83
c480089
69da21c
c480089
 
018212f
c480089
0dd123c
6f2771b
 
9b65193
018212f
 
0dd123c
018212f
cf53d83
 
018212f
 
6f2771b
 
 
9b65193
c480089
9b65193
cf53d83
c480089
018212f
6d7baf2
018212f
 
cf53d83
6efc023
60ccc8b
cf53d83
 
 
 
c480089
 
cf53d83
 
9b65193
 
c480089
6f2771b
 
 
cf53d83
69da21c
4a3ef8c
69da21c
56ce19c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a3ef8c
1cdf7ce
e90baee
cf53d83
c7a13bc
cf53d83
 
 
c480089
1cdf7ce
 
c480089
1cdf7ce
 
018212f
cf53d83
 
c7a13bc
cf53d83
60ccc8b
cf53d83
 
 
0dd123c
c480089
1cdf7ce
c480089
 
cf53d83
c480089
 
1cdf7ce
 
c480089
018212f
c480089
1cdf7ce
 
c480089
1cdf7ce
018212f
1cdf7ce
6d7baf2
 
a6dcdbd
50014ba
c480089
 
7c71cf4
1cdf7ce
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
import gradio as gr
import numpy as np
import supervision as sv
import torch
import cv2
from PIL import Image
import lightly_train

# --- CONFIGURATION ---

# Markdown rendered at the top of the Gradio page.
MARKDOWN_HEADER = """
# LightlyTrain Detection & Segmentation Demo 🚀
[GitHub](https://github.com/lightly-ai/lightly-train) | [Documentation](https://docs.lightly.ai/train)

This demo showcases **LightlyTrain**, a powerful library for self-supervised learning and fine-tuning.
Uses **DINOv3** backbones to detect objects or segment scenes (**COCO Classes**).
"""

# 1. MODEL CHECKPOINTS
# Any checkpoint not listed in SEGMENTATION_MODELS is handled by the detection
# path in run_prediction, so these names need not be listed there explicitly.
DETECTION_MODELS = [
    "dinov3/vitt16-ltdetr-coco",
    "dinov3/convnext-base-ltdetr-coco",
    "dinov3/convnext-small-ltdetr-coco",
    "dinov3/convnext-tiny-ltdetr-coco",
]

# Membership in this list is what routes a checkpoint to the segmentation path.
SEGMENTATION_MODELS = [
    "dinov3/vitb16-eomt-coco",
    "dinov3/vitl16-eomt-coco",
    "dinov3/vits16-eomt-coco",
]

# Dropdown choices: all detection checkpoints first, then the segmentation ones.
ALL_MODELS = [*DETECTION_MODELS, *SEGMENTATION_MODELS]
# First detection checkpoint doubles as the UI default.
DEFAULT_MODEL = DETECTION_MODELS[0]

# 2. CLASS LISTS

# COCO detection: the 80 "thing" categories. run_detection looks names up by
# position, so list order must match the detector's label ids.
COCO_DETECTION_CLASSES = [
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
    "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
    "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
    "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
    "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
    "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
    "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard",
    "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase",
    "scissors", "teddy bear", "hair drier", "toothbrush",
]

# COCO-Stuff: 171 named classes plus a background entry, 172 items in total.
#   index 0        -> "unlabeled" (background)
#   indices 1..80  -> the 80 thing classes above, in the same order
#   indices 81..171 -> the 91 stuff classes
COCO_STUFF_CLASSES = [
    "unlabeled",
    *COCO_DETECTION_CLASSES,
    "banner", "blanket", "branch", "bridge", "building-other",
    "bush", "cabinet", "cage", "cardboard", "carpet", "ceiling-other", "ceiling-tile", "cloth", "clothes", "clouds",
    "counter", "cupboard", "curtain", "desk-stuff", "dirt", "door-stuff", "fence", "floor-marble", "floor-other",
    "floor-stone", "floor-tile", "floor-wood", "flower", "fog", "food-other", "fruit", "furniture-other", "grass",
    "gravel", "ground-other", "hill", "house", "leaves", "light", "mat", "metal", "mirror-stuff", "moss", "mountain",
    "mud", "napkin", "net", "paper", "pavement", "pillow", "plant-other", "plastic", "platform", "playingfield",
    "railing", "railroad", "river", "road", "rock", "roof", "rug", "salad", "sand", "sea", "shelf", "sky-other",
    "skyscraper", "snow", "solid-other", "stairs", "stone", "straw", "structural-other", "table", "tent", "textile-other",
    "towel", "tree", "vegetable", "wall-brick", "wall-concrete", "wall-other", "wall-panel", "wall-stone", "wall-tile",
    "wall-wood", "water-other", "waterdrops", "window-blind", "window-other", "wood",
]

# --- HELPER FUNCTIONS ---

# Successfully loaded models, keyed by checkpoint name.
loaded_models = {}

def get_model(model_name):
    """Return the LightlyTrain model for ``model_name``, loading it on first use.

    Successful loads are memoised in ``loaded_models``. If loading raises, the
    error is printed and ``None`` is returned; nothing is cached in that case,
    so a later call retries the load.
    """
    if model_name not in loaded_models:
        print(f"Loading model: {model_name}...")
        try:
            fresh = lightly_train.load_model(model_name)
        except Exception as err:
            # Best-effort: the UI reports "Error loading model" instead of crashing.
            print(f"Error loading model: {err}")
            return None
        loaded_models[model_name] = fresh
    return loaded_models[model_name]

# Warm the cache with the default checkpoint at import time so the first
# request does not pay the load cost. If this fails, get_model only prints the
# error and caches nothing, so the load is retried on the next request.
get_model(DEFAULT_MODEL)

# --- INFERENCE LOGIC ---

def run_prediction(image, confidence_threshold, resolution, model_name):
    """Run the selected checkpoint on ``image`` and return the three UI outputs.

    Args:
        image: PIL image from the Gradio input, or ``None`` if nothing was uploaded.
        confidence_threshold: minimum detection score (unused for segmentation).
        resolution: side length of the square image fed to the model.
        model_name: checkpoint name; members of SEGMENTATION_MODELS go through
            segmentation, everything else through detection.

    Returns:
        ``(annotated_image, analytics_text, json_payload)``. ``(None, None, None)``
        when no image was given; ``(image, "Error loading model", {})`` when the
        checkpoint could not be loaded.
    """
    if image is None:
        return None, None, None

    model = get_model(model_name)
    if model is None:
        return image, "Error loading model", {}

    # Downstream drawing blends against a 3-channel overlay; RGBA or greyscale
    # uploads would otherwise fail there.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # PIL's resize requires integer sizes; the slider value may arrive as a float.
    side = int(resolution)
    image_input = image.resize((side, side))

    if model_name in SEGMENTATION_MODELS:
        return run_segmentation(model, image_input, image)
    return run_detection(model, image_input, image, confidence_threshold)

def run_detection(model, image_input, original_image, confidence_threshold):
    """Detect objects on the resized image and draw them on the original.

    Args:
        model: loaded detection model. ``model.predict`` must return a mapping
            with ``'bboxes'`` (xyxy; assumed to be in ``image_input`` pixels --
            TODO confirm), ``'labels'`` and ``'scores'`` tensors.
        image_input: the square, resized PIL image given to the model.
        original_image: the user's PIL image; boxes are rescaled onto it.
        confidence_threshold: detections scoring at or below this are dropped.

    Returns:
        ``(annotated_pil_image, analytics_text, {"count": n, "objects": {name: count}})``.
    """
    results = model.predict(image_input)

    boxes = results['bboxes'].cpu().numpy()
    labels = results['labels'].cpu().numpy()
    scores = results['scores'].cpu().numpy()

    # Keep only detections strictly above the threshold.
    keep = scores > confidence_threshold
    boxes, labels, scores = boxes[keep], labels[keep], scores[keep]

    # Map boxes from the model's input resolution back onto the original image.
    w_input, h_input = image_input.size
    w_orig, h_orig = original_image.size
    scale_x, scale_y = w_orig / w_input, h_orig / h_input
    boxes = boxes * np.array([scale_x, scale_y, scale_x, scale_y])

    detections = sv.Detections(xyxy=boxes, confidence=scores, class_id=labels)

    labels_text = []
    class_counts = {}
    for cid, conf in zip(detections.class_id, detections.confidence):
        cid = int(cid)
        # Guard both ends: a negative id would otherwise index from the list's end.
        if 0 <= cid < len(COCO_DETECTION_CLASSES):
            name = COCO_DETECTION_CLASSES[cid]
        else:
            name = f"Class {cid}"
        labels_text.append(f"{name} {conf:.2f}")
        class_counts[name] = class_counts.get(name, 0) + 1

    # convert() returns a new image (the caller's object is never drawn on) and
    # guarantees 3-channel RGB for the annotators, even for RGBA/greyscale input.
    annotated = original_image.convert("RGB")
    annotated = sv.BoxAnnotator().annotate(scene=annotated, detections=detections)
    annotated = sv.LabelAnnotator().annotate(
        scene=annotated, detections=detections, labels=labels_text
    )

    summary_list = [f"{k}: {v}" for k, v in class_counts.items()]
    analytics_text = "Objects Found (Detection):\n" + (", ".join(summary_list) if summary_list else "None")

    return annotated, analytics_text, {"count": len(boxes), "objects": class_counts}

# Segmentation path (option 1): colour overlay plus centroid labels at the original resolution.
def run_segmentation(model, image_input, original_image):
    """Segment ``image_input`` and overlay the per-class mask on ``original_image``.

    Args:
        model: loaded segmentation model. ``model.predict`` is assumed to return
            a 2-D tensor of per-pixel class ids (the code relies on
            ``.cpu().numpy()`` yielding an (H, W) array -- TODO confirm).
        image_input: the square, resized PIL image given to the model.
        original_image: the user's PIL image; the mask is resized onto it.

    Returns:
        ``(blended_pil_image, analytics_text, {"classes_found": sorted_names})``.
    """
    mask_tensor = model.predict(image_input)
    # The uint8 cast wraps negative ids (e.g. -1) to 255, which is skipped below.
    mask_np = mask_tensor.cpu().numpy().astype(np.uint8)
    # cv2's dsize is (width, height), the same ordering as PIL's Image.size.
    mask_np = cv2.resize(mask_np, original_image.size, interpolation=cv2.INTER_NEAREST)

    # NOTE(review): id 0 is treated as background and ids are looked up in
    # "unlabeled" + the 80 thing classes. If these checkpoints emit COCO-Stuff
    # ids instead, COCO_STUFF_CLASSES is the right table -- confirm against the
    # model's class metadata.
    current_classes = ["unlabeled"] + COCO_DETECTION_CLASSES

    h, w = mask_np.shape
    colored_mask = np.zeros((h, w, 3), dtype=np.uint8)

    found_classes = set()
    labels_to_draw = []

    for cls_id in np.unique(mask_np):
        # Skip background (0), the wrapped ignore value (255) and ids without a name.
        if cls_id == 0 or cls_id == 255 or cls_id >= len(current_classes):
            continue

        class_name = current_classes[cls_id]
        found_classes.add(class_name)

        region = mask_np == cls_id
        # A private RandomState seeded with the class id yields the same stable
        # per-class colour as np.random.seed(id) + np.random.randint(...), but
        # without reseeding numpy's global generator on every request.
        color = np.random.RandomState(int(cls_id)).randint(50, 255, size=3)
        colored_mask[region] = color

        y_indices, x_indices = np.nonzero(region)
        # Only label regions with more than 200 pixels, to avoid tagging noise.
        if len(y_indices) > 200:
            centroid_y = int(np.mean(y_indices))
            centroid_x = int(np.mean(x_indices))
            labels_to_draw.append((centroid_x, centroid_y, class_name))

    # convert("RGB") keeps the blend 3-channel for RGBA/greyscale uploads, matching colored_mask.
    original_np = np.array(original_image.convert("RGB"))
    blended = cv2.addWeighted(original_np, 0.6, colored_mask, 0.4, 0)

    for cx, cy, text in labels_to_draw:
        # Thick black stroke underneath, thin white text on top, for legibility on any colour.
        cv2.putText(blended, text, (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 3, cv2.LINE_AA)
        cv2.putText(blended, text, (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1, cv2.LINE_AA)

    # Sorted so both the text summary and the JSON payload are deterministic.
    class_names = sorted(found_classes)
    analytics_text = "Scene Contains (COCO Objects):\n" + (", ".join(class_names) if class_names else "None")

    return Image.fromarray(blended), analytics_text, {"classes_found": class_names}

# NOTE(review): the triple-quoted string below is an unused alternative
# segmentation implementation ("Segm code opt 2"). It is a bare string
# expression, so it is never executed and does NOT redefine run_segmentation.
# Its signature (model, image) also differs from the three-argument call made
# in run_prediction. Consider deleting it and relying on version control history.
'''
# Segm code opt 2
def run_segmentation(model, image):
    """
    Handles Segmentation: Returns Tensor of shape (H, W) with class IDs.
    """
    mask_tensor = model.predict(image)
    mask_np = mask_tensor.cpu().numpy().astype(np.uint8)

    h, w = mask_np.shape
    colored_mask = np.zeros((h, w, 3), dtype=np.uint8)
    
    unique_classes = np.unique(mask_np)
    
    for cls_id in unique_classes:
        if cls_id == -1: continue 
        
        np.random.seed(int(cls_id)) 
        color = np.random.randint(50, 255, size=3) 
        colored_mask[mask_np == cls_id] = color

    image_np = np.array(image)
    if image_np.shape[:2] != colored_mask.shape[:2]:
        colored_mask = cv2.resize(colored_mask, (image_np.shape[1], image_np.shape[0]), interpolation=cv2.INTER_NEAREST)

    blended = cv2.addWeighted(image_np, 0.6, colored_mask, 0.4, 0)
    return Image.fromarray(blended)

'''
# --- GRADIO UI ---

# Soft theme using the Inter web font, falling back to Arial / sans-serif.
theme = gr.themes.Soft(
    font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"]
)

# Layout: input image + settings on the left, results on the right.
# Components render in creation order, so statement order here is the layout.
with gr.Blocks(theme=theme) as demo:
    gr.Markdown(MARKDOWN_HEADER)
    
    with gr.Row():
        with gr.Column(scale=1):
            # type="pil" hands run_prediction a PIL.Image (or None when empty).
            input_img = gr.Image(type="pil", label="Input Image")
            
            with gr.Accordion("Settings", open=True):
                # Only the detection path uses the threshold; run_prediction does not
                # pass it to run_segmentation.
                conf_slider = gr.Slider(0.0, 1.0, value=0.4, step=0.05, label="Confidence (Detection Only)")
                # Side length of the square resize applied before inference.
                res_slider = gr.Slider(384, 1024, value=640, step=32, label="Inference Resolution")
                
                model_selector = gr.Dropdown(
                    choices=ALL_MODELS,
                    value=DEFAULT_MODEL,
                    label="Model Checkpoint"
                )

            run_btn = gr.Button("Analyze Image", variant="primary")
        
        with gr.Column(scale=1):
            output_img = gr.Image(label="Annotated Result")
            output_text = gr.Textbox(label="Analytics Summary", interactive=False, lines=6)
            with gr.Accordion("Raw Data (JSON)", open=False):
                # Labelled "Detection Data", but segmentation results
                # ({"classes_found": [...]}) are shown here as well.
                output_json = gr.JSON(label="Detection Data")

    # Inputs map positionally onto run_prediction(image, confidence_threshold,
    # resolution, model_name); its 3-tuple return fills the three outputs in order.
    run_btn.click(
        fn=run_prediction, 
        inputs=[input_img, conf_slider, res_slider, model_selector], 
        outputs=[output_img, output_text, output_json]
    )

    gr.Markdown("### 💡 Try an Example")
    # Each example row follows the `inputs` order above. The first two use the
    # default detection checkpoint; the third uses a segmentation checkpoint so
    # both paths can be demoed. cache_examples=False: nothing is precomputed at startup.
    gr.Examples(
        inputs=[input_img, conf_slider, res_slider, model_selector],
        examples=[
            ["http://farm3.staticflickr.com/2547/3933456087_6a4dfb4736_z.jpg", 0.4, 640, DEFAULT_MODEL],
            ["https://farm3.staticflickr.com/2294/2193565429_aed7c9ff98_z.jpg", 0.4, 640, DEFAULT_MODEL],
            ["http://farm9.staticflickr.com/8092/8400332884_102a62b6c6_z.jpg", 0.4, 640, "dinov3/vitb16-eomt-coco"],
           ],
        outputs=[output_img, output_text, output_json],
        fn=run_prediction,
        cache_examples=False,
    )

# Start the Gradio server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()