import gradio as gr
import numpy as np
import supervision as sv
import torch
import cv2
from PIL import Image
import lightly_train

# --- CONFIGURATION ---
MARKDOWN_HEADER = """
# LightlyTrain Detection & Segmentation Demo 🚀
[GitHub](https://github.com/lightly-ai/lightly-train) | [Documentation](https://docs.lightly.ai/train)

This demo showcases **LightlyTrain**, a powerful library for self-supervised learning and fine-tuning.
Uses **DINOv3** backbones to detect objects or segment scenes (**COCO Classes**).
"""

# 1. MODEL REGISTRY
DETECTION_MODELS = [
    "dinov3/vitt16-ltdetr-coco",
    "dinov3/convnext-base-ltdetr-coco",
    "dinov3/convnext-small-ltdetr-coco",
    "dinov3/convnext-tiny-ltdetr-coco",
]
SEGMENTATION_MODELS = [
    "dinov3/vitb16-eomt-coco",
    "dinov3/vitl16-eomt-coco",
    "dinov3/vits16-eomt-coco",
]
ALL_MODELS = DETECTION_MODELS + SEGMENTATION_MODELS
DEFAULT_MODEL = DETECTION_MODELS[0]

# 2. CLASS LISTS
# COCO Detection (80 Classes)
COCO_DETECTION_CLASSES = [
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
    "truck", "boat", "traffic light", "fire hydrant", "stop sign",
    "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
    "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag",
    "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite",
    "baseball bat", "baseball glove", "skateboard", "surfboard",
    "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon",
    "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot",
    "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant",
    "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote",
    "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
    "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
    "hair drier", "toothbrush",
]

# COCO-Stuff (171 Classes)
# NOTE(review): currently unused — run_segmentation() maps ids with
# ["unlabeled"] + COCO_DETECTION_CLASSES instead; confirm which label map the
# eomt-coco checkpoints actually emit before switching.
COCO_STUFF_CLASSES = [
    "unlabeled",  # Index 0 (Background)
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
    "truck", "boat", "traffic light", "fire hydrant", "stop sign",
    "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
    "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag",
    "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite",
    "baseball bat", "baseball glove", "skateboard", "surfboard",
    "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon",
    "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot",
    "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant",
    "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote",
    "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
    "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
    "hair drier", "toothbrush", "banner", "blanket", "branch", "bridge",
    "building-other", "bush", "cabinet", "cage", "cardboard", "carpet",
    "ceiling-other", "ceiling-tile", "cloth", "clothes", "clouds", "counter",
    "cupboard", "curtain", "desk-stuff", "dirt", "door-stuff", "fence",
    "floor-marble", "floor-other", "floor-stone", "floor-tile", "floor-wood",
    "flower", "fog", "food-other", "fruit", "furniture-other", "grass",
    "gravel", "ground-other", "hill", "house", "leaves", "light", "mat",
    "metal", "mirror-stuff", "moss", "mountain", "mud", "napkin", "net",
    "paper", "pavement", "pillow", "plant-other", "plastic", "platform",
    "playingfield", "railing", "railroad", "river", "road", "rock", "roof",
    "rug", "salad", "sand", "sea", "shelf", "sky-other", "skyscraper", "snow",
    "solid-other", "stairs", "stone", "straw", "structural-other", "table",
    "tent", "textile-other", "towel", "tree", "vegetable", "wall-brick",
    "wall-concrete", "wall-other", "wall-panel", "wall-stone", "wall-tile",
    "wall-wood", "water-other", "waterdrops", "window-blind", "window-other",
    "wood",
]

# --- HELPER FUNCTIONS ---

# In-process cache: each checkpoint is loaded at most once per session.
loaded_models = {}


def get_model(model_name):
    """Return a cached model instance for ``model_name``.

    Loads the checkpoint on first use and memoizes it in ``loaded_models``.
    Returns ``None`` when loading fails so the UI can report the error
    instead of crashing.
    """
    if model_name in loaded_models:
        return loaded_models[model_name]
    print(f"Loading model: {model_name}...")
    try:
        model = lightly_train.load_model(model_name)
    except Exception as e:
        # Best-effort: surface the failure to the caller as None.
        print(f"Error loading model: {e}")
        return None
    loaded_models[model_name] = model
    return model


# Warm the cache so the first request does not pay the model-load latency.
get_model(DEFAULT_MODEL)


# --- INFERENCE LOGIC ---
def run_prediction(image, confidence_threshold, resolution, model_name):
    """Dispatch an uploaded image to detection or segmentation inference.

    Parameters
    ----------
    image : PIL.Image or None
        The uploaded image (``None`` when nothing was provided).
    confidence_threshold : float
        Minimum score for detections (ignored by segmentation models).
    resolution : int
        Square inference resolution; the image is resized to
        ``(resolution, resolution)`` — note this does not preserve the
        aspect ratio.
    model_name : str
        One of ``ALL_MODELS``.

    Returns
    -------
    tuple
        ``(annotated_image, analytics_text, raw_data_dict)``.
    """
    if image is None:
        return None, None, None

    model = get_model(model_name)
    if model is None:
        return image, "Error loading model", {}

    image_input = image.resize((resolution, resolution))

    if model_name in SEGMENTATION_MODELS:
        return run_segmentation(model, image_input, image)
    return run_detection(model, image_input, image, confidence_threshold)


def run_detection(model, image_input, original_image, confidence_threshold):
    """Run object detection and draw labelled boxes on the original image.

    ``image_input`` is the resized image fed to the model; boxes are scaled
    back to ``original_image`` coordinates before annotation. Returns
    ``(annotated PIL image, summary text, raw-data dict)``.
    """
    results = model.predict(image_input)
    boxes = results['bboxes'].cpu().numpy()
    labels = results['labels'].cpu().numpy()
    scores = results['scores'].cpu().numpy()

    # Keep only detections above the confidence threshold.
    valid = scores > confidence_threshold
    boxes, labels, scores = boxes[valid], labels[valid], scores[valid]

    detections = sv.Detections(xyxy=boxes, confidence=scores, class_id=labels)

    # Rescale boxes from inference resolution back to original image size.
    w_input, h_input = image_input.size
    w_orig, h_orig = original_image.size
    scale_x, scale_y = w_orig / w_input, h_orig / h_input
    detections.xyxy[:, [0, 2]] *= scale_x
    detections.xyxy[:, [1, 3]] *= scale_y

    box_annotator = sv.BoxAnnotator()
    label_annotator = sv.LabelAnnotator()

    labels_text = []
    class_counts = {}
    for cid, conf in zip(detections.class_id, detections.confidence):
        # Guard both bounds: a negative id must not index from the list end.
        if 0 <= cid < len(COCO_DETECTION_CLASSES):
            name = COCO_DETECTION_CLASSES[cid]
        else:
            name = f"Class {cid}"
        labels_text.append(f"{name} {conf:.2f}")
        class_counts[name] = class_counts.get(name, 0) + 1

    annotated = original_image.copy()
    annotated = box_annotator.annotate(scene=annotated, detections=detections)
    annotated = label_annotator.annotate(
        scene=annotated, detections=detections, labels=labels_text
    )

    summary_list = [f"{k}: {v}" for k, v in class_counts.items()]
    analytics_text = "Objects Found (Detection):\n" + (
        ", ".join(summary_list) if summary_list else "None"
    )
    return annotated, analytics_text, {"count": len(boxes), "objects": class_counts}


def run_segmentation(model, image_input, original_image):
    """Run semantic segmentation and blend a colored mask over the image.

    The model is expected to return a ``(H, W)`` tensor of class ids for
    ``image_input``; the mask is resized (nearest-neighbor, to keep ids
    intact) to the original image size. Returns
    ``(blended PIL image, summary text, raw-data dict)``.
    """
    mask_tensor = model.predict(image_input)
    mask_np = mask_tensor.cpu().numpy().astype(np.uint8)
    # INTER_NEAREST so class ids are never interpolated into invalid values.
    mask_np = cv2.resize(
        mask_np, original_image.size, interpolation=cv2.INTER_NEAREST
    )

    # NOTE(review): assumes the model emits detection-class ids shifted by a
    # background slot; COCO_STUFF_CLASSES is the likely alternative — confirm.
    current_classes = ["unlabeled"] + COCO_DETECTION_CLASSES

    h, w = mask_np.shape
    colored_mask = np.zeros((h, w, 3), dtype=np.uint8)

    found_classes = set()
    labels_to_draw = []
    for cls_id in np.unique(mask_np):
        # Skip background (0) and the common ignore value (255).
        if cls_id == 0 or cls_id == 255:
            continue
        if cls_id < 0 or cls_id >= len(current_classes):
            continue

        class_name = current_classes[cls_id]
        found_classes.add(class_name)

        # Deterministic per-class color without mutating the global RNG
        # (identical values to seeding np.random with cls_id).
        color = np.random.RandomState(int(cls_id)).randint(50, 255, size=3)
        colored_mask[mask_np == cls_id] = color

        y_indices, x_indices = np.where(mask_np == cls_id)
        # Filter small noise: only label regions with enough pixels.
        if len(y_indices) > 200:
            centroid_y = int(np.mean(y_indices))
            centroid_x = int(np.mean(x_indices))
            labels_to_draw.append((centroid_x, centroid_y, class_name))

    original_np = np.array(original_image)
    blended = cv2.addWeighted(original_np, 0.6, colored_mask, 0.4, 0)

    # Draw each label twice: thick black outline, then thin white fill.
    for (cx, cy, text) in labels_to_draw:
        cv2.putText(blended, text, (cx, cy), cv2.FONT_HERSHEY_SIMPLEX,
                    0.6, (0, 0, 0), 3, cv2.LINE_AA)
        cv2.putText(blended, text, (cx, cy), cv2.FONT_HERSHEY_SIMPLEX,
                    0.6, (255, 255, 255), 1, cv2.LINE_AA)

    analytics_text = "Scene Contains (COCO Objects):\n" + (
        ", ".join(sorted(list(found_classes))) if found_classes else "None"
    )
    return (
        Image.fromarray(blended),
        analytics_text,
        {"classes_found": list(found_classes)},
    )


# --- GRADIO UI ---
theme = gr.themes.Soft(
    font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"]
)

with gr.Blocks(theme=theme) as demo:
    gr.Markdown(MARKDOWN_HEADER)
    with gr.Row():
        with gr.Column(scale=1):
            input_img = gr.Image(type="pil", label="Input Image")
            with gr.Accordion("Settings", open=True):
                conf_slider = gr.Slider(
                    0.0, 1.0, value=0.4, step=0.05,
                    label="Confidence (Detection Only)",
                )
                res_slider = gr.Slider(
                    384, 1024, value=640, step=32,
                    label="Inference Resolution",
                )
                model_selector = gr.Dropdown(
                    choices=ALL_MODELS,
                    value=DEFAULT_MODEL,
                    label="Model Checkpoint",
                )
            run_btn = gr.Button("Analyze Image", variant="primary")
        with gr.Column(scale=1):
            output_img = gr.Image(label="Annotated Result")
            output_text = gr.Textbox(
                label="Analytics Summary", interactive=False, lines=6
            )
            with gr.Accordion("Raw Data (JSON)", open=False):
                output_json = gr.JSON(label="Detection Data")

    run_btn.click(
        fn=run_prediction,
        inputs=[input_img, conf_slider, res_slider, model_selector],
        outputs=[output_img, output_text, output_json],
    )

    gr.Markdown("### 💡 Try an Example")
    gr.Examples(
        inputs=[input_img, conf_slider, res_slider, model_selector],
        examples=[
            ["http://farm3.staticflickr.com/2547/3933456087_6a4dfb4736_z.jpg",
             0.4, 640, DEFAULT_MODEL],
            ["https://farm3.staticflickr.com/2294/2193565429_aed7c9ff98_z.jpg",
             0.4, 640, DEFAULT_MODEL],
            ["http://farm9.staticflickr.com/8092/8400332884_102a62b6c6_z.jpg",
             0.4, 640, "dinov3/vitb16-eomt-coco"],
        ],
        outputs=[output_img, output_text, output_json],
        fn=run_prediction,
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()