demomule / viz.py
righthook75's picture
Upload viz.py with huggingface_hub
76df059 verified
import functools

import matplotlib
import numpy as np
from PIL import Image, ImageDraw, ImageFont
# Palette cycled over the sorted unique class labels (index modulo length).
CLASS_COLORS = [
    "#e6194b",  # red
    "#3cb44b",  # green
    "#4363d8",  # blue
    "#f58231",  # orange
    "#911eb4",  # purple
    "#42d4f4",  # cyan
    "#f032e6",  # magenta
    "#bfef45",  # lime
    "#fabed4",  # pink
    "#469990",  # teal
]

# Default point size for caption text (default argument of _get_label_font).
LABEL_FONT_SIZE = 48
@functools.lru_cache(maxsize=8)
def _get_label_font(size: int = LABEL_FONT_SIZE) -> ImageFont.FreeTypeFont | ImageFont.ImageFont:
    """Return a font for caption text at *size*, falling back to Pillow's default.

    Candidates are tried in order:
    1. the macOS system Helvetica (absolute path);
    2. ``DejaVuSans.ttf``, resolved by Pillow's font search;
    3. Pillow's built-in default font.

    Results are cached per argument tuple. The overlay functions ask for the font
    on every call, and loading a TrueType file each time is wasted disk I/O.
    """
    for candidate in ("/System/Library/Fonts/Helvetica.ttc", "DejaVuSans.ttf"):
        try:
            return ImageFont.truetype(candidate, size)
        except OSError:
            continue  # font file not present on this system; try the next one
    try:
        return ImageFont.load_default(size=size)
    except TypeError:
        # ``load_default`` only accepts ``size`` from Pillow 10.1 on. Older
        # releases can only provide the fixed-size bitmap font.
        return ImageFont.load_default()
def _hex_to_rgb(hex_color: str) -> tuple[int, int, int]:
h = hex_color.lstrip("#")
return tuple(int(h[i:i+2], 16) for i in (0, 2, 4))
def _label_color_map(labels: list[str]) -> dict[str, tuple[int, int, int]]:
    """Assign each distinct non-empty label an RGB color from CLASS_COLORS.

    Labels are sorted before colors are handed out, so the same label set
    always gets the same colors regardless of input order. The palette wraps
    around when there are more labels than colors. The empty label ``""``
    always maps to gray.
    """
    palette_size = len(CLASS_COLORS)
    distinct = sorted({name for name in labels if name})
    colors = {
        name: _hex_to_rgb(CLASS_COLORS[idx % palette_size])
        for idx, name in enumerate(distinct)
    }
    colors[""] = (180, 180, 180)  # detections without a label
    return colors
def _composite_mask(image: Image.Image, mask, color, alpha_factor: float) -> Image.Image:
    """Return *image* (RGBA) with the nonzero pixels of *mask* tinted *color*.

    The mask is coerced to bool, so boolean, 0/1 and 0/255 masks all behave the
    same. Multiplying a 0/255 uint8 mask by 255 would wrap to 1 and make it
    nearly invisible. The mask is assumed to match ``image.size`` (``putalpha``
    requires that).
    """
    mask_alpha = np.where(np.asarray(mask, dtype=bool), int(255 * alpha_factor), 0)
    overlay = Image.new("RGBA", image.size, (*color, 0))
    overlay.putalpha(Image.fromarray(mask_alpha.astype(np.uint8)))
    return Image.alpha_composite(image, overlay)


def _shade_rect(image: Image.Image, box, rgba: tuple[int, int, int, int]) -> None:
    """Alpha-blend a filled rectangle onto the RGBA *image* in place.

    ``ImageDraw.rectangle(fill=(r, g, b, a))`` on an RGBA image stores ``a`` in
    the pixels instead of blending, which leaves a translucent hole in the
    output. Compositing a solid patch blends it over the existing pixels
    instead. *box* is ``[x0, y0, x1, y1]`` with inclusive corners, like
    ImageDraw. It is clipped to the image bounds.
    """
    x0, y0, x1, y1 = (int(round(v)) for v in box)
    x0, y0 = max(x0, 0), max(y0, 0)
    x1, y1 = min(x1, image.width - 1), min(y1, image.height - 1)
    if x1 < x0 or y1 < y0:
        return  # entirely outside the image
    patch = Image.new("RGBA", (x1 - x0 + 1, y1 - y0 + 1), rgba)
    image.alpha_composite(patch, dest=(x0, y0))


def _detection_caption(index: int, label, show_index: bool) -> str:
    """Build the caption drawn above detection *index*; ``""`` means no caption."""
    if show_index:
        return f"Detection {index + 1}: {label}" if label else f"Detection {index + 1}"
    return label or ""


def overlay_detections_by_class(image: Image.Image, detections: list[dict],
                                show_index: bool = False,
                                color_override: dict[str, tuple[int, int, int]] | None = None,
                                highlight_ids: set[int] | None = None) -> Image.Image:
    """Overlay detections colored by label.

    Drawing happens in two passes: first every mask, then every box and
    caption. This keeps boxes and text on top of all masks.

    Args:
        image: Base image. It is converted to RGBA; the input is not modified.
        detections: Dicts with a required ``"box"`` of ``(x1, y1, x2, y2)`` and
            optional ``"label"`` (str), ``"mask"`` (array-like; nonzero pixels
            are foreground, presumably shaped like the image) and ``"id"``.
        show_index: Prefix captions with ``"Detection <n>"`` (1-based).
        color_override: When provided, maps label strings to RGB tuples.
            Used to keep overlay colors matching class card colors exactly.
            Labels missing from it are drawn gray.
        highlight_ids: Set of detection ``id`` values to draw with a bright
            highlight (thicker border, higher-alpha mask) so the user can
            see which detection is selected in the class panel.

    Returns:
        A new RGBA image. It is RGBA even when *detections* is empty, matching
        ``overlay_masks`` and ``overlay_accepted``.
    """
    image = image.convert("RGBA")
    if not detections:
        return image

    labels = [d.get("label", "") for d in detections]
    color_map = color_override if color_override else _label_color_map(labels)
    selected = highlight_ids or set()
    unlabeled_gray = (180, 180, 180)
    font = _get_label_font()

    # Pass 1: masks. Selected detections get a more opaque fill.
    for det in detections:
        mask = det.get("mask")
        if mask is None:
            continue
        color = color_map.get(det.get("label", ""), unlabeled_gray)
        alpha_factor = 0.8 if det.get("id") in selected else 0.5
        image = _composite_mask(image, mask, color, alpha_factor)

    # Pass 2: boxes and captions, drawn in place on the composited image.
    draw = ImageDraw.Draw(image)
    for i, det in enumerate(detections):
        label = det.get("label", "")
        color = color_map.get(label, unlabeled_gray)
        is_highlighted = det.get("id") in selected
        x1, y1, x2, y2 = det["box"]
        draw.rectangle([x1, y1, x2, y2], outline=color, width=8 if is_highlighted else 3)
        if is_highlighted:
            # Bright yellow ring just outside the box marks the selection.
            draw.rectangle([x1 - 2, y1 - 2, x2 + 2, y2 + 2],
                           outline=(255, 255, 0), width=4)

        text = _detection_caption(i, label, show_index)
        if not text:
            continue
        left, top, right, bottom = font.getbbox(text)
        tw, th = right - left, bottom - top
        ty = max(y1 - th - 6, 0)  # caption sits above the box, clamped to the top edge
        _shade_rect(image, [x1, ty, x1 + tw + 4, ty + th + 4], (0, 0, 0, 180))
        # getbbox() is relative to the draw origin and may have a nonzero
        # left/top offset. Shift by it so the ink lands inside the 2px-padded box.
        draw.text((x1 + 2 - left, ty + 2 - top), text, fill=color, font=font)
    return image
def overlay_masks(image: Image.Image, masks: list[np.ndarray], *,
                  alpha: float = 0.5) -> Image.Image:
    """Overlay masks on an image, each in its own rainbow color.

    Args:
        image: Base image. It is converted to RGBA; the input is not modified.
        masks: Masks to paint. Colors are sampled evenly from matplotlib's
            "rainbow" colormap, one per mask. Any nonzero pixel counts as
            foreground, so boolean, 0/1 and 0/255 masks all render. Each mask
            is assumed to match the image size (``putalpha`` requires it).
        alpha: Opacity of painted mask pixels in ``[0, 1]`` (default 50%).

    Returns:
        A new RGBA image. It is returned even when *masks* is empty.

    Raises:
        ValueError: if *alpha* is outside ``[0, 1]``.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be within [0, 1], got {alpha!r}")
    image = image.convert("RGBA")
    n_masks = len(masks)
    if n_masks == 0:
        return image
    cmap = matplotlib.colormaps.get_cmap("rainbow").resampled(n_masks)
    alpha_value = int(255 * alpha)
    for i, mask in enumerate(masks):
        color = tuple(int(c * 255) for c in cmap(i)[:3])
        # Coerce to bool first. ``255 * uint8_mask`` wraps 255 -> 1, which
        # makes 0/255 masks effectively transparent.
        mask_alpha = np.where(np.asarray(mask, dtype=bool), alpha_value, 0).astype(np.uint8)
        overlay = Image.new("RGBA", image.size, color + (0,))
        overlay.putalpha(Image.fromarray(mask_alpha))
        image = Image.alpha_composite(image, overlay)
    return image
def overlay_boxes(image: Image.Image, boxes: list, labels: list[str] | None = None) -> Image.Image:
    """Draw 3px red boxes, each with an optional red caption just above it.

    Args:
        image: Base image. It is converted to RGBA; the input is not modified.
        boxes: ``(x1, y1, x2, y2)`` boxes.
        labels: Optional captions aligned with *boxes* by index. Missing or
            empty entries get no caption.

    Returns:
        A new RGBA image.
    """
    canvas = image.convert("RGBA")
    pen = ImageDraw.Draw(canvas)
    for idx, (x1, y1, x2, y2) in enumerate(boxes):
        pen.rectangle([x1, y1, x2, y2], outline="red", width=3)
        caption = labels[idx] if labels and idx < len(labels) else None
        if caption:
            pen.text((x1, max(y1 - 14, 0)), caption, fill="red")
    return canvas
def overlay_single_detection(image: Image.Image, detection: dict) -> Image.Image:
    """Render one detection for review: its mask, a red box, and a yellow score.

    Args:
        image: Base image; the input is not modified.
        detection: Dict with ``"mask"``, ``"box"`` ``(x1, y1, x2, y2)`` and a
            numeric ``"score"`` (formatted to three decimals).

    Returns:
        A new RGBA image.
    """
    masked = overlay_masks(image, [detection["mask"]])
    box = detection["box"]
    composed = overlay_boxes(masked, [box])
    left, top = box[:2]
    caption = "score: {:.3f}".format(detection["score"])
    # Place the score above any label row drawn by overlay_boxes (14px), clamped to the top edge.
    ImageDraw.Draw(composed).text((left, max(top - 28, 0)), caption, fill="yellow")
    return composed
def overlay_accepted(image: Image.Image, accepted_dets: list[dict]) -> Image.Image:
    """Outline accepted detections in lime, with their label drawn above each box.

    Args:
        image: Base image. It is converted to RGBA; the input is not modified.
        accepted_dets: Dicts with a ``"box"`` ``(x1, y1, x2, y2)`` and an
            optional ``"label"``. Empty labels get no caption.

    Returns:
        A new RGBA image. It is returned even when *accepted_dets* is empty.
    """
    canvas = image.convert("RGBA")
    if accepted_dets:
        pen = ImageDraw.Draw(canvas)
        for det in accepted_dets:
            x1, y1, x2, y2 = det["box"]
            pen.rectangle((x1, y1, x2, y2), outline="lime", width=3)
            name = det.get("label", "")
            if name:
                pen.text((x1, max(y1 - 14, 0)), name, fill="lime")
    return canvas