# masakljun's picture
# obj detection seems ok, while segm is off
# 4a3ef8c
import gradio as gr
import numpy as np
import supervision as sv
import torch
import cv2
from PIL import Image
import lightly_train
# --- CONFIGURATION ---
# Markdown rendered at the top of the demo page.
MARKDOWN_HEADER = """
# LightlyTrain Detection & Segmentation Demo 🚀
[GitHub](https://github.com/lightly-ai/lightly-train) | [Documentation](https://docs.lightly.ai/train)
This demo showcases **LightlyTrain**, a powerful library for self-supervised learning and fine-tuning.
Uses **DINOv3** backbones to detect objects or segment scenes (**COCO Classes**).
"""

# Checkpoint names offered in the model dropdown, grouped by task.
# Membership in SEGMENTATION_MODELS is what routes a request to the
# segmentation path; everything else is treated as a detector.
DETECTION_MODELS = [
    "dinov3/vitt16-ltdetr-coco",
    "dinov3/convnext-base-ltdetr-coco",
    "dinov3/convnext-small-ltdetr-coco",
    "dinov3/convnext-tiny-ltdetr-coco",
]
SEGMENTATION_MODELS = [
    "dinov3/vitb16-eomt-coco",
    "dinov3/vitl16-eomt-coco",
    "dinov3/vits16-eomt-coco",
]

# Dropdown order: detectors first, then segmenters.
ALL_MODELS = [*DETECTION_MODELS, *SEGMENTATION_MODELS]
# The first detector is preselected in the UI.
DEFAULT_MODEL = DETECTION_MODELS[0]
# 2. CLASS LISTS
# COCO Detection (80 Classes)
# The 80 COCO "thing" classes; a detector label is used as an index into this list.
COCO_DETECTION_CLASSES = [
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
    "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
    "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
    "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
    "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
    "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
    "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard",
    "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase",
    "scissors", "teddy bear", "hair drier", "toothbrush",
]

# The 91 COCO-Stuff "stuff" classes that follow the thing classes.
_COCO_STUFF_ONLY_CLASSES = [
    "banner", "blanket", "branch", "bridge", "building-other",
    "bush", "cabinet", "cage", "cardboard", "carpet", "ceiling-other", "ceiling-tile", "cloth", "clothes", "clouds",
    "counter", "cupboard", "curtain", "desk-stuff", "dirt", "door-stuff", "fence", "floor-marble", "floor-other",
    "floor-stone", "floor-tile", "floor-wood", "flower", "fog", "food-other", "fruit", "furniture-other", "grass",
    "gravel", "ground-other", "hill", "house", "leaves", "light", "mat", "metal", "mirror-stuff", "moss", "mountain",
    "mud", "napkin", "net", "paper", "pavement", "pillow", "plant-other", "plastic", "platform", "playingfield",
    "railing", "railroad", "river", "road", "rock", "roof", "rug", "salad", "sand", "sea", "shelf", "sky-other",
    "skyscraper", "snow", "solid-other", "stairs", "stone", "straw", "structural-other", "table", "tent", "textile-other",
    "towel", "tree", "vegetable", "wall-brick", "wall-concrete", "wall-other", "wall-panel", "wall-stone", "wall-tile",
    "wall-wood", "water-other", "waterdrops", "window-blind", "window-other", "wood",
]

# COCO-Stuff: index 0 is the "unlabeled" background entry, followed by the
# 80 thing classes and the 91 stuff classes (171 classes, 172 entries).
COCO_STUFF_CLASSES = [
    "unlabeled",
    *COCO_DETECTION_CLASSES,
    *_COCO_STUFF_ONLY_CLASSES,
]
# --- HELPER FUNCTIONS ---
# Process-wide cache: checkpoint name -> loaded model object.
loaded_models = {}


def get_model(model_name):
    """Return the model for ``model_name``, loading it on first use.

    Successfully loaded models are kept in ``loaded_models`` and reused on
    later calls. If loading raises, the error is printed, nothing is cached
    (so a later call retries), and ``None`` is returned.
    """
    if model_name not in loaded_models:
        print(f"Loading model: {model_name}...")
        try:
            loaded_models[model_name] = lightly_train.load_model(model_name)
        except Exception as e:
            # Best-effort: callers treat None as "model unavailable".
            print(f"Error loading model: {e}")
            return None
    return loaded_models[model_name]
# Load the default checkpoint at import time so it is already in
# `loaded_models` when the UI starts (presumably to keep the first request
# fast). A failed load is only printed by get_model and does not abort startup.
get_model(DEFAULT_MODEL)
# --- INFERENCE LOGIC ---
def run_prediction(image, confidence_threshold, resolution, model_name):
    """Run the selected model on ``image`` and return the UI outputs.

    Args:
        image: Input image exposing ``resize`` (a PIL image in this app), or None.
        confidence_threshold: Score cut-off, forwarded to detection only.
        resolution: Side length of the square the image is resized to
            before inference.
        model_name: Checkpoint name; names listed in SEGMENTATION_MODELS take
            the segmentation path, all others the detection path.

    Returns:
        A 3-tuple ``(annotated image, summary text, data dict)``.
        ``(None, None, None)`` when no image is given, and
        ``(image, "Error loading model", {})`` when the model cannot be loaded.
    """
    if image is None:
        return None, None, None

    model = get_model(model_name)
    if model is None:
        return image, "Error loading model", {}

    # Inference runs on a square resize; the untouched original is passed
    # along so results can be drawn at full size.
    resized = image.resize((resolution, resolution))

    if model_name not in SEGMENTATION_MODELS:
        return run_detection(model, resized, image, confidence_threshold)
    return run_segmentation(model, resized, image)
def run_detection(model, image_input, original_image, confidence_threshold):
    """Detect objects on ``image_input`` and draw them on ``original_image``.

    Args:
        model: Object whose ``predict`` returns a mapping with ``'bboxes'``,
            ``'labels'`` and ``'scores'`` entries, each supporting
            ``.cpu().numpy()``.
        image_input: Resized image handed to the model.
        original_image: Full-size image that gets annotated (a copy is drawn on).
        confidence_threshold: Only detections scoring strictly above this are kept.

    Returns:
        ``(annotated image, summary text, {"count": int, "objects": {name: n}})``.
    """
    prediction = model.predict(image_input)
    xyxy, class_ids, confidences = (
        prediction[key].cpu().numpy() for key in ("bboxes", "labels", "scores")
    )

    # Drop low-confidence detections.
    keep = confidences > confidence_threshold
    xyxy, class_ids, confidences = xyxy[keep], class_ids[keep], confidences[keep]

    detections = sv.Detections(xyxy=xyxy, confidence=confidences, class_id=class_ids)

    # Boxes are scaled from the resized input back to the original image:
    # x columns (0, 2) by the width ratio, y columns (1, 3) by the height ratio.
    in_w, in_h = image_input.size
    out_w, out_h = original_image.size
    detections.xyxy[:, 0::2] *= out_w / in_w
    detections.xyxy[:, 1::2] *= out_h / in_h

    labels_text = []
    class_counts = {}
    for class_id, score in zip(detections.class_id, detections.confidence):
        if class_id < len(COCO_DETECTION_CLASSES):
            name = COCO_DETECTION_CLASSES[class_id]
        else:
            # Label outside the known class list: show the raw id instead.
            name = f"Class {class_id}"
        labels_text.append(f"{name} {score:.2f}")
        class_counts[name] = class_counts.get(name, 0) + 1

    annotated = sv.BoxAnnotator().annotate(
        scene=original_image.copy(), detections=detections
    )
    annotated = sv.LabelAnnotator().annotate(
        scene=annotated, detections=detections, labels=labels_text
    )

    summary = ", ".join(f"{name}: {count}" for name, count in class_counts.items())
    analytics_text = "Objects Found (Detection):\n" + (summary or "None")
    return annotated, analytics_text, {"count": len(xyxy), "objects": class_counts}
# Segm code opt 1
def run_segmentation(model, image_input, original_image):
    """Overlay the predicted class mask on ``original_image`` and label it.

    ``model.predict(image_input)`` is expected to yield a 2-D map of class ids
    supporting ``.cpu().numpy()``. The map is resized (nearest neighbour) to the
    original image size, every recognised class is painted with a fixed
    per-class colour, the overlay is blended 60/40 with the original, and the
    class name is written at the centroid of each class covering more than 200
    pixels.

    Args:
        model: Segmentation model exposing ``predict``.
        image_input: Resized image handed to the model.
        original_image: Full-size PIL image the overlay is blended onto.

    Returns:
        ``(blended PIL image, summary text, {"classes_found": [names]})`` with
        the names sorted alphabetically.
    """
    raw_mask = model.predict(image_input).cpu().numpy()

    # Validate ids before the uint8 cast: anything that does not fit
    # in a byte is mapped to 255 (ignored below) instead of silently wrapping
    # around onto a valid class id.
    out_of_range = (raw_mask < 0) | (raw_mask > 255)
    mask_np = np.where(out_of_range, 255, raw_mask).astype(np.uint8)
    # cv2 takes (width, height), which is what PIL's .size provides.
    mask_np = cv2.resize(mask_np, original_image.size, interpolation=cv2.INTER_NEAREST)

    # NOTE(review): this assumes mask ids 1-80 line up with
    # COCO_DETECTION_CLASSES; ids above 80 (stuff classes) are skipped. Confirm
    # against the checkpoint's class mapping, or switch to COCO_STUFF_CLASSES
    # to name stuff classes as well.
    current_classes = ["unlabeled"] + COCO_DETECTION_CLASSES

    height, width = mask_np.shape
    colored_mask = np.zeros((height, width, 3), dtype=np.uint8)
    found_classes = []
    labels_to_draw = []

    for cls_id in np.unique(mask_np):
        cls_id = int(cls_id)
        # Skip background (0), the ignore value (255) and unknown ids.
        if cls_id in (0, 255) or cls_id >= len(current_classes):
            continue

        class_name = current_classes[cls_id]
        found_classes.append(class_name)

        # A private generator seeded with the class id yields the same colour
        # as seeding the global RNG would, without clobbering global random
        # state (which is also shared between concurrent requests).
        color = np.random.RandomState(cls_id).randint(50, 255, size=3)

        region = mask_np == cls_id
        colored_mask[region] = color

        # Filter small noise: only label classes covering more than 200 pixels.
        y_indices, x_indices = np.nonzero(region)
        if len(y_indices) > 200:
            centroid = (int(np.mean(x_indices)), int(np.mean(y_indices)))
            labels_to_draw.append((*centroid, class_name))

    # Force 3 channels so the blend also works for greyscale / RGBA input.
    original_np = np.array(original_image.convert("RGB"))
    blended = cv2.addWeighted(original_np, 0.6, colored_mask, 0.4, 0)

    for cx, cy, text in labels_to_draw:
        # Thick black pass first, thin white pass on top: an outlined label.
        cv2.putText(blended, text, (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 3, cv2.LINE_AA)
        cv2.putText(blended, text, (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1, cv2.LINE_AA)

    found_sorted = sorted(found_classes)
    analytics_text = "Scene Contains (COCO Objects):\n" + (
        ", ".join(found_sorted) if found_sorted else "None"
    )
    return Image.fromarray(blended), analytics_text, {"classes_found": found_sorted}
# Disabled alternative segmentation implementation ("opt 2"). It is wrapped in
# a bare triple-quoted string, so it is never executed and defines nothing.
# Unlike the active run_segmentation it takes (model, image), colours every
# class id in the mask without drawing labels, and returns only the blended
# image.
'''
# Segm code opt 2
def run_segmentation(model, image):
"""
Handles Segmentation: Returns Tensor of shape (H, W) with class IDs.
"""
mask_tensor = model.predict(image)
mask_np = mask_tensor.cpu().numpy().astype(np.uint8)
h, w = mask_np.shape
colored_mask = np.zeros((h, w, 3), dtype=np.uint8)
unique_classes = np.unique(mask_np)
for cls_id in unique_classes:
if cls_id == -1: continue
np.random.seed(int(cls_id))
color = np.random.randint(50, 255, size=3)
colored_mask[mask_np == cls_id] = color
image_np = np.array(image)
if image_np.shape[:2] != colored_mask.shape[:2]:
colored_mask = cv2.resize(colored_mask, (image_np.shape[1], image_np.shape[0]), interpolation=cv2.INTER_NEAREST)
blended = cv2.addWeighted(image_np, 0.6, colored_mask, 0.4, 0)
return Image.fromarray(blended)
'''
# --- GRADIO UI ---
# NOTE(review): the nesting of the layout below was reconstructed from a copy
# of this file that had lost its indentation - confirm the component
# placement against the intended layout.
theme = gr.themes.Soft(font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"])

with gr.Blocks(theme=theme) as demo:
    gr.Markdown(MARKDOWN_HEADER)

    with gr.Row():
        # Left column: input image, settings and the trigger button.
        with gr.Column(scale=1):
            input_img = gr.Image(type="pil", label="Input Image")
            with gr.Accordion("Settings", open=True):
                conf_slider = gr.Slider(
                    0.0, 1.0, value=0.4, step=0.05, label="Confidence (Detection Only)"
                )
                res_slider = gr.Slider(
                    384, 1024, value=640, step=32, label="Inference Resolution"
                )
                model_selector = gr.Dropdown(
                    choices=ALL_MODELS,
                    value=DEFAULT_MODEL,
                    label="Model Checkpoint",
                )
            run_btn = gr.Button("Analyze Image", variant="primary")

        # Right column: annotated image, text summary and raw JSON.
        with gr.Column(scale=1):
            output_img = gr.Image(label="Annotated Result")
            output_text = gr.Textbox(label="Analytics Summary", interactive=False, lines=6)
            with gr.Accordion("Raw Data (JSON)", open=False):
                output_json = gr.JSON(label="Detection Data")

    # The button and the examples widget drive the same function with the
    # same components, so the wiring is declared once.
    prediction_inputs = [input_img, conf_slider, res_slider, model_selector]
    prediction_outputs = [output_img, output_text, output_json]

    run_btn.click(
        fn=run_prediction,
        inputs=prediction_inputs,
        outputs=prediction_outputs,
    )

    gr.Markdown("### 💡 Try an Example")
    gr.Examples(
        inputs=prediction_inputs,
        examples=[
            ["http://farm3.staticflickr.com/2547/3933456087_6a4dfb4736_z.jpg", 0.4, 640, DEFAULT_MODEL],
            ["https://farm3.staticflickr.com/2294/2193565429_aed7c9ff98_z.jpg", 0.4, 640, DEFAULT_MODEL],
            ["http://farm9.staticflickr.com/8092/8400332884_102a62b6c6_z.jpg", 0.4, 640, "dinov3/vitb16-eomt-coco"],
        ],
        outputs=prediction_outputs,
        fn=run_prediction,
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()