# NOTE: Hugging Face Spaces status header ("Spaces: Sleeping") — page-extraction
# artifact, not part of the application code.
import cv2
import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image

import lightly_train
# --- CONFIGURATION ---

# Markdown shown at the top of the demo page.
MARKDOWN_HEADER = """
# LightlyTrain Detection & Segmentation Demo 🚀
[GitHub](https://github.com/lightly-ai/lightly-train) | [Documentation](https://docs.lightly.ai/train)
This demo showcases **LightlyTrain**, a powerful library for self-supervised learning and fine-tuning.
Uses **DINOv3** backbones to detect objects or segment scenes (**COCO Classes**).
"""

# 1. MODEL CHECKPOINTS
# Detection checkpoints use LT-DETR heads; segmentation checkpoints use EoMT
# heads. All are DINOv3 backbones fine-tuned on COCO.
DETECTION_MODELS = [
    "dinov3/vitt16-ltdetr-coco",
    "dinov3/convnext-base-ltdetr-coco",
    "dinov3/convnext-small-ltdetr-coco",
    "dinov3/convnext-tiny-ltdetr-coco",
]
SEGMENTATION_MODELS = [
    "dinov3/vitb16-eomt-coco",
    "dinov3/vitl16-eomt-coco",
    "dinov3/vits16-eomt-coco",
]
ALL_MODELS = DETECTION_MODELS + SEGMENTATION_MODELS
DEFAULT_MODEL = DETECTION_MODELS[0]
# 2. CLASS LISTS

# COCO Detection (80 classes). Index in this list == model class id.
COCO_DETECTION_CLASSES = [
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
    "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
    "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
    "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
    "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
    "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
    "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard",
    "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase",
    "scissors", "teddy bear", "hair drier", "toothbrush",
]

# COCO-Stuff (171 labeled classes: the 80 "thing" classes above followed by
# 91 "stuff" classes), plus an "unlabeled" placeholder at index 0 — 172
# entries in total.
COCO_STUFF_CLASSES = [
    "unlabeled",  # Index 0 (Background)
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
    "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
    "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
    "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
    "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
    "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
    "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard",
    "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase",
    "scissors", "teddy bear", "hair drier", "toothbrush", "banner", "blanket", "branch", "bridge", "building-other",
    "bush", "cabinet", "cage", "cardboard", "carpet", "ceiling-other", "ceiling-tile", "cloth", "clothes", "clouds",
    "counter", "cupboard", "curtain", "desk-stuff", "dirt", "door-stuff", "fence", "floor-marble", "floor-other",
    "floor-stone", "floor-tile", "floor-wood", "flower", "fog", "food-other", "fruit", "furniture-other", "grass",
    "gravel", "ground-other", "hill", "house", "leaves", "light", "mat", "metal", "mirror-stuff", "moss", "mountain",
    "mud", "napkin", "net", "paper", "pavement", "pillow", "plant-other", "plastic", "platform", "playingfield",
    "railing", "railroad", "river", "road", "rock", "roof", "rug", "salad", "sand", "sea", "shelf", "sky-other",
    "skyscraper", "snow", "solid-other", "stairs", "stone", "straw", "structural-other", "table", "tent", "textile-other",
    "towel", "tree", "vegetable", "wall-brick", "wall-concrete", "wall-other", "wall-panel", "wall-stone", "wall-tile",
    "wall-wood", "water-other", "waterdrops", "window-blind", "window-other", "wood",
]
# --- HELPER FUNCTIONS ---

# In-process cache: model name -> loaded model instance, so switching models
# in the dropdown does not re-download/re-load a checkpoint.
loaded_models = {}


def get_model(model_name):
    """Return the model for ``model_name``, loading and caching it on first use.

    Args:
        model_name: checkpoint identifier understood by ``lightly_train.load_model``.

    Returns:
        The loaded model, or ``None`` if loading failed (the error is printed),
        so callers can degrade gracefully instead of crashing the UI.
    """
    if model_name in loaded_models:
        return loaded_models[model_name]
    print(f"Loading model: {model_name}...")
    try:
        model = lightly_train.load_model(model_name)
    except Exception as e:
        # Broad on purpose: any load failure (network, bad name, OOM) should
        # surface as a UI message rather than take down the Space.
        print(f"Error loading model: {e}")
        return None
    loaded_models[model_name] = model
    return model
# Eagerly load the default model at import time so the first request is fast.
get_model(DEFAULT_MODEL)
# --- INFERENCE LOGIC ---


def run_prediction(image, confidence_threshold, resolution, model_name):
    """Dispatch an input image to the detection or segmentation pipeline.

    Args:
        image: PIL image from the Gradio input (or ``None`` if nothing uploaded).
        confidence_threshold: score cutoff, used by detection models only.
        resolution: square side length the image is resized to for inference.
        model_name: checkpoint identifier; decides detection vs. segmentation.

    Returns:
        Tuple of (annotated image, analytics text, raw-data dict); all ``None``
        when no image was provided.
    """
    if image is None:
        return None, None, None
    model = get_model(model_name)
    if model is None:
        return image, "Error loading model", {}
    # Inference runs on a square resize; annotations are drawn back onto the
    # original-resolution image.
    image_input = image.resize((resolution, resolution))
    if model_name in SEGMENTATION_MODELS:
        return run_segmentation(model, image_input, image)
    return run_detection(model, image_input, image, confidence_threshold)
def run_detection(model, image_input, original_image, confidence_threshold):
    """Run object detection and draw annotated boxes on ``original_image``.

    Args:
        model: detection model whose ``predict`` returns a dict of tensors with
            keys 'bboxes', 'labels', 'scores' (boxes in ``image_input`` pixels).
        image_input: PIL image resized to the inference resolution.
        original_image: full-resolution PIL image used for drawing.
        confidence_threshold: keep detections with score strictly above this.

    Returns:
        Tuple of (annotated PIL image, analytics summary string, raw-data dict).
    """
    results = model.predict(image_input)
    boxes = results['bboxes'].cpu().numpy()
    labels = results['labels'].cpu().numpy()
    scores = results['scores'].cpu().numpy()

    # Drop low-confidence detections before building the Detections object.
    valid = scores > confidence_threshold
    boxes, labels, scores = boxes[valid], labels[valid], scores[valid]

    detections = sv.Detections(xyxy=boxes, confidence=scores, class_id=labels)

    # Boxes are in resized-input coordinates; rescale to the original image.
    w_input, h_input = image_input.size
    w_orig, h_orig = original_image.size
    scale_x, scale_y = w_orig / w_input, h_orig / h_input
    detections.xyxy[:, [0, 2]] *= scale_x
    detections.xyxy[:, [1, 3]] *= scale_y

    # Build per-box labels and an object-count summary in one pass.
    labels_text = []
    class_counts = {}
    for cid, conf in zip(detections.class_id, detections.confidence):
        name = COCO_DETECTION_CLASSES[cid] if cid < len(COCO_DETECTION_CLASSES) else f"Class {cid}"
        labels_text.append(f"{name} {conf:.2f}")
        class_counts[name] = class_counts.get(name, 0) + 1

    annotated = original_image.copy()
    annotated = sv.BoxAnnotator().annotate(scene=annotated, detections=detections)
    annotated = sv.LabelAnnotator().annotate(scene=annotated, detections=detections, labels=labels_text)

    summary_list = [f"{k}: {v}" for k, v in class_counts.items()]
    analytics_text = "Objects Found (Detection):\n" + (", ".join(summary_list) if summary_list else "None")
    return annotated, analytics_text, {"count": len(boxes), "objects": class_counts}
def run_segmentation(model, image_input, original_image):
    """Run semantic segmentation and overlay a colored class mask.

    Args:
        model: segmentation model whose ``predict`` returns an (H, W) tensor
            of integer class ids.
        image_input: PIL image resized to the inference resolution.
        original_image: full-resolution PIL image the overlay is blended onto.

    Returns:
        Tuple of (blended PIL image, analytics summary string, raw-data dict).
    """
    mask_tensor = model.predict(image_input)
    mask_np = mask_tensor.cpu().numpy().astype(np.uint8)
    # PIL .size is (w, h), matching cv2.resize's dsize ordering; nearest
    # neighbor keeps class ids intact.
    mask_np = cv2.resize(mask_np, original_image.size, interpolation=cv2.INTER_NEAREST)

    # NOTE(review): the mask ids are mapped to detection ("thing") classes
    # shifted by one, not the full COCO-Stuff palette — presumably the eomt-coco
    # checkpoints emit thing ids with 0 as background; confirm against the
    # model card before switching to COCO_STUFF_CLASSES.
    # current_classes = COCO_STUFF_CLASSES
    current_classes = ["unlabeled"] + COCO_DETECTION_CLASSES

    h, w = mask_np.shape
    colored_mask = np.zeros((h, w, 3), dtype=np.uint8)
    unique_classes = np.unique(mask_np)
    found_classes = set()
    labels_to_draw = []
    for cls_id in unique_classes:
        # Skip background (index 0) and the common ignore value 255.
        if cls_id == 0 or cls_id == 255:
            continue
        if cls_id < 0 or cls_id >= len(current_classes):
            continue
        class_name = current_classes[cls_id]
        found_classes.add(class_name)
        # Seed per class id so each class keeps a stable pseudo-random color.
        np.random.seed(int(cls_id))
        color = np.random.randint(50, 255, size=3)
        colored_mask[mask_np == cls_id] = color
        y_indices, x_indices = np.where(mask_np == cls_id)
        # Only label regions big enough to matter (filters pixel noise).
        if len(y_indices) > 200:
            centroid_y = int(np.mean(y_indices))
            centroid_x = int(np.mean(x_indices))
            labels_to_draw.append((centroid_x, centroid_y, class_name))

    original_np = np.array(original_image)
    blended = cv2.addWeighted(original_np, 0.6, colored_mask, 0.4, 0)
    # Draw each label twice: thick black outline, then thin white fill.
    for (cx, cy, text) in labels_to_draw:
        cv2.putText(blended, text, (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 3, cv2.LINE_AA)
        cv2.putText(blended, text, (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1, cv2.LINE_AA)

    analytics_text = f"Scene Contains (COCO Objects):\n" + (", ".join(sorted(list(found_classes))) if found_classes else "None")
    return Image.fromarray(blended), analytics_text, {"classes_found": list(found_classes)}
# NOTE: a dead alternative run_segmentation implementation ("opt 2", kept here
# as a module-level triple-quoted string) was removed; see version control for
# its history.
# --- GRADIO UI ---
theme = gr.themes.Soft(
    font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"]
)

# NOTE(review): the widget nesting below is reconstructed from context (the
# original file's indentation was lost) — verify the intended layout.
with gr.Blocks(theme=theme) as demo:
    gr.Markdown(MARKDOWN_HEADER)
    with gr.Row():
        # Left column: input image, settings, and the run button.
        with gr.Column(scale=1):
            input_img = gr.Image(type="pil", label="Input Image")
            with gr.Accordion("Settings", open=True):
                conf_slider = gr.Slider(0.0, 1.0, value=0.4, step=0.05, label="Confidence (Detection Only)")
                res_slider = gr.Slider(384, 1024, value=640, step=32, label="Inference Resolution")
                model_selector = gr.Dropdown(
                    choices=ALL_MODELS,
                    value=DEFAULT_MODEL,
                    label="Model Checkpoint",
                )
            run_btn = gr.Button("Analyze Image", variant="primary")
        # Right column: annotated output, text summary, and raw JSON.
        with gr.Column(scale=1):
            output_img = gr.Image(label="Annotated Result")
            output_text = gr.Textbox(label="Analytics Summary", interactive=False, lines=6)
            with gr.Accordion("Raw Data (JSON)", open=False):
                output_json = gr.JSON(label="Detection Data")

    run_btn.click(
        fn=run_prediction,
        inputs=[input_img, conf_slider, res_slider, model_selector],
        outputs=[output_img, output_text, output_json],
    )

    gr.Markdown("### 💡 Try an Example")
    gr.Examples(
        inputs=[input_img, conf_slider, res_slider, model_selector],
        examples=[
            ["http://farm3.staticflickr.com/2547/3933456087_6a4dfb4736_z.jpg", 0.4, 640, DEFAULT_MODEL],
            ["https://farm3.staticflickr.com/2294/2193565429_aed7c9ff98_z.jpg", 0.4, 640, DEFAULT_MODEL],
            ["http://farm9.staticflickr.com/8092/8400332884_102a62b6c6_z.jpg", 0.4, 640, "dinov3/vitb16-eomt-coco"],
        ],
        outputs=[output_img, output_text, output_json],
        fn=run_prediction,
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()