Spaces:
Paused
Paused
import functools

import matplotlib
import numpy as np
from PIL import Image, ImageDraw, ImageFont
# Categorical palette for per-class overlays; _label_color_map assigns these
# to sorted unique labels and wraps around (modulo) when there are more labels.
CLASS_COLORS = [
    "#e6194b",  # red
    "#3cb44b",  # green
    "#4363d8",  # blue
    "#f58231",  # orange
    "#911eb4",  # purple
    "#42d4f4",  # cyan
    "#f032e6",  # magenta
    "#bfef45",  # lime
    "#fabed4",  # pink
    "#469990",  # teal
]

# Default size handed to ImageFont.truetype for detection captions.
LABEL_FONT_SIZE = 48
@functools.lru_cache(maxsize=16)
def _get_label_font(size: int = LABEL_FONT_SIZE) -> ImageFont.FreeTypeFont | ImageFont.ImageFont:
    """Return a font for caption text at ``size``, falling back gracefully.

    Tries, in order: macOS Helvetica, ``DejaVuSans.ttf`` (a bare filename, which
    Pillow may also look up in platform font directories), then Pillow's
    built-in default font. The result is cached per ``size`` so repeated
    overlay calls do not re-read the font file from disk.

    Args:
        size: Requested font size passed to ``ImageFont.truetype`` /
            ``ImageFont.load_default``.

    Returns:
        A FreeType font when one could be loaded, otherwise Pillow's default font.
    """
    for candidate in ("/System/Library/Fonts/Helvetica.ttc", "DejaVuSans.ttf"):
        try:
            return ImageFont.truetype(candidate, size)
        except OSError:
            continue  # font not present on this system; try the next one
    try:
        return ImageFont.load_default(size=size)
    except TypeError:
        # Pillow < 10.1 has no ``size`` parameter on load_default(); use the
        # fixed-size bitmap default rather than crashing.
        return ImageFont.load_default()
| def _hex_to_rgb(hex_color: str) -> tuple[int, int, int]: | |
| h = hex_color.lstrip("#") | |
| return tuple(int(h[i:i+2], 16) for i in (0, 2, 4)) | |
def _label_color_map(labels: list[str]) -> dict[str, tuple[int, int, int]]:
    """Assign a stable RGB color to each distinct non-empty label.

    Distinct truthy labels are sorted and paired with CLASS_COLORS in order,
    cycling through the palette when there are more labels than colors. The
    empty label ``""`` is always present and mapped to neutral gray.
    """
    palette_size = len(CLASS_COLORS)
    distinct = sorted({name for name in labels if name})
    colors = {
        name: _hex_to_rgb(CLASS_COLORS[idx % palette_size])
        for idx, name in enumerate(distinct)
    }
    colors[""] = (180, 180, 180)  # unlabeled detections
    return colors
def _detection_style(det: dict, color_map: dict,
                     highlight_ids: set[int] | None) -> tuple[str, tuple, bool]:
    """Return (label, RGB color tuple, is_highlighted) for one detection dict."""
    label = det.get("label", "")
    # tuple() lets color_override supply lists as well as tuples.
    color = tuple(color_map.get(label, (180, 180, 180)))
    is_highlighted = bool(highlight_ids) and det.get("id") in highlight_ids
    return label, color, is_highlighted


def _detection_caption(index: int, label: str, show_index: bool) -> str:
    """Build the caption for the ``index``-th (0-based) detection; "" means none."""
    if show_index:
        return f"Detection {index+1}: {label}" if label else f"Detection {index+1}"
    return label or ""


def _draw_captions(image: Image.Image, captions: list[tuple],
                   font: ImageFont.FreeTypeFont | ImageFont.ImageFont) -> Image.Image:
    """Draw each (x1, y1, text, color) caption above its box on a translucent plate.

    Returns a new RGBA image.
    """
    # ImageDraw on an RGBA image replaces pixels instead of blending, so a
    # semi-transparent fill drawn directly would punch a hole in the image's
    # alpha. Draw the plates on their own layer and alpha-composite it.
    plate_layer = Image.new("RGBA", image.size, (0, 0, 0, 0))
    plate_draw = ImageDraw.Draw(plate_layer)
    placements = []
    for x1, y1, text, color in captions:
        left, top, right, bottom = font.getbbox(text)
        tw, th = right - left, bottom - top
        ty = max(y1 - th - 6, 0)
        plate_draw.rectangle([x1, ty, x1 + tw + 4, ty + th + 4],
                             fill=(0, 0, 0, 180))
        # getbbox is relative to the draw origin, so subtract its left/top
        # bearing to put the ink's top-left 2px inside the plate.
        placements.append(((x1 + 2 - left, ty + 2 - top), text, color))

    image = Image.alpha_composite(image, plate_layer)
    draw = ImageDraw.Draw(image)
    for xy, text, color in placements:
        draw.text(xy, text, fill=color, font=font)
    return image


def overlay_detections_by_class(image: Image.Image, detections: list[dict],
                                show_index: bool = False,
                                color_override: dict[str, tuple[int, int, int]] | None = None,
                                highlight_ids: set[int] | None = None) -> Image.Image:
    """Overlay detections colored by label.

    Each detection dict must have ``"box"`` as (x1, y1, x2, y2) and may have
    ``"label"``, ``"mask"`` (an array, presumably boolean, whose size must match
    the image for ``putalpha``) and ``"id"``.

    Args:
        image: Base image; converted to RGBA when there is anything to draw.
        detections: Detection dicts as described above.
        show_index: Prefix captions with ``"Detection {n}"`` (1-based).
        color_override: When provided, maps label strings to RGB tuples.
            Used to keep overlay colors matching class card colors exactly.
            Labels missing from it are drawn gray.
        highlight_ids: Set of detection ``id`` values to draw with a bright
            highlight (thicker border, higher-alpha mask) so the user can
            see which detection is selected in the class panel.

    Returns:
        ``image`` unchanged when ``detections`` is empty, otherwise a new RGBA
        image with masks, boxes and captions drawn. Captions are drawn after
        all boxes so they stay readable on top.
    """
    if not detections:
        return image
    image = image.convert("RGBA")
    labels = [d.get("label", "") for d in detections]
    color_map = color_override if color_override else _label_color_map(labels)
    font = _get_label_font()

    # Pass 1: translucent masks (50% alpha, 80% when highlighted).
    for det in detections:
        mask = det.get("mask")
        if mask is None:
            continue
        _, color, is_highlighted = _detection_style(det, color_map, highlight_ids)
        mask_img = Image.fromarray(255 * mask.astype(np.uint8))
        alpha_factor = 0.8 if is_highlighted else 0.5
        alpha = mask_img.point(lambda v, af=alpha_factor: int(v * af))
        overlay = Image.new("RGBA", image.size, color + (0,))
        overlay.putalpha(alpha)
        image = Image.alpha_composite(image, overlay)

    # Pass 2: box outlines (opaque, so drawing directly is fine).
    draw = ImageDraw.Draw(image)
    captions = []
    for i, det in enumerate(detections):
        label, color, is_highlighted = _detection_style(det, color_map, highlight_ids)
        x1, y1, x2, y2 = det["box"]
        draw.rectangle([x1, y1, x2, y2], outline=color,
                       width=8 if is_highlighted else 3)
        if is_highlighted:
            # Bright yellow highlight border just outside the box.
            draw.rectangle([x1 - 2, y1 - 2, x2 + 2, y2 + 2],
                           outline=(255, 255, 0), width=4)
        text = _detection_caption(i, label, show_index)
        if text:
            captions.append((x1, y1, text, color))

    # Pass 3: captions on translucent plates.
    if captions:
        image = _draw_captions(image, captions, font)
    return image
def overlay_masks(image: Image.Image, masks: list[np.ndarray]) -> Image.Image:
    """Tint each mask with its own color at 50% alpha and return an RGBA image.

    Colors come from matplotlib's ``rainbow`` colormap resampled to
    ``len(masks)`` entries, one per mask, in order. Masks are presumably
    boolean (or 0/1) arrays whose size matches ``image``. With no masks, the
    RGBA-converted image is returned as-is.
    """
    canvas = image.convert("RGBA")
    count = len(masks)
    if not count:
        return canvas
    palette = matplotlib.colormaps.get_cmap("rainbow").resampled(count)
    for idx, mask in enumerate(masks):
        red, green, blue = (int(channel * 255) for channel in palette(idx)[:3])
        coverage = Image.fromarray(mask.astype(np.uint8) * 255)
        layer = Image.new("RGBA", canvas.size, (red, green, blue, 0))
        layer.putalpha(coverage.point(lambda v: int(v * 0.5)))
        canvas = Image.alpha_composite(canvas, layer)
    return canvas
def overlay_boxes(image: Image.Image, boxes: list, labels: list[str] | None = None) -> Image.Image:
    """Outline every (x1, y1, x2, y2) box in red on an RGBA copy of ``image``.

    When ``labels`` has a non-empty entry at a box's position, that text is
    written in red 14px above the box's top edge (clamped to the image top).
    """
    canvas = image.convert("RGBA")
    pen = ImageDraw.Draw(canvas)
    for idx, (left, top, right, bottom) in enumerate(boxes):
        pen.rectangle([left, top, right, bottom], outline="red", width=3)
        has_caption = labels and idx < len(labels) and labels[idx]
        if not has_caption:
            continue
        pen.text((left, max(top - 14, 0)), labels[idx], fill="red")
    return canvas
def overlay_single_detection(image: Image.Image, detection: dict) -> Image.Image:
    """Render one detection for review: tinted mask, red box and yellow score.

    ``detection`` must provide ``"mask"``, ``"box"`` (x1, y1, x2, y2) and
    ``"score"``; the score is printed with three decimals 28px above the box.
    """
    masked = overlay_masks(image, [detection["mask"]])
    box = detection["box"]
    canvas = overlay_boxes(masked, [box])
    left, top = box[:2]
    pen = ImageDraw.Draw(canvas)
    pen.text((left, max(top - 28, 0)), f"score: {detection['score']:.3f}", fill="yellow")
    return canvas
def overlay_accepted(image: Image.Image, accepted_dets: list[dict]) -> Image.Image:
    """Outline accepted detections in lime on an RGBA copy of ``image``.

    Each detection's ``"box"`` (x1, y1, x2, y2) gets a 3px lime outline; a
    non-empty ``"label"`` is written in lime 14px above the box (clamped to the
    image top).
    """
    canvas = image.convert("RGBA")
    if accepted_dets:
        pen = ImageDraw.Draw(canvas)
        for det in accepted_dets:
            left, top, right, bottom = det["box"]
            pen.rectangle([left, top, right, bottom], outline="lime", width=3)
            caption = det.get("label", "")
            if caption:
                pen.text((left, max(top - 14, 0)), caption, fill="lime")
    return canvas