| from __future__ import annotations |
|
|
| import numpy as np |
| from PIL import Image, ImageDraw, ImageFont |
|
|
| from .obstacle_dataset import YoloBox |
|
|
|
|
| def draw_yolo_with_heatmap( |
| image: Image.Image, |
| yolo_boxes: list[YoloBox], |
| heatmap: np.ndarray, |
| saliency_threshold: float = 0.7, |
| ) -> Image.Image: |
| canvas = overlay_heatmap(image, heatmap, saliency_threshold=saliency_threshold) |
| draw = ImageDraw.Draw(canvas) |
| font = ImageFont.load_default() |
|
|
| for box in yolo_boxes: |
| xyxy = box.to_xyxy(*canvas.size) |
| draw.rectangle(xyxy, outline="lime", width=4) |
| draw_label(draw, xyxy, f"YOLO: {box.class_name}", "lime", font) |
|
|
| return canvas |
|
|
|
|
| def overlay_heatmap( |
| image: Image.Image, |
| heatmap: np.ndarray, |
| saliency_threshold: float = 0.7, |
| ) -> Image.Image: |
| base = image.convert("RGB") |
| heat = np.clip(heatmap, 0, 1) |
| heat_img = Image.fromarray((heat * 255).astype(np.uint8), mode="L").resize(base.size) |
| heat_values = np.asarray(heat_img).astype(np.float32) / 255.0 |
| alpha_values = np.where( |
| heat_values >= saliency_threshold, |
| ((heat_values - saliency_threshold) / max(1e-6, 1 - saliency_threshold)) * 210, |
| 0, |
| ) |
| alpha = Image.fromarray(alpha_values.astype(np.uint8), mode="L") |
|
|
| color = Image.new("RGBA", base.size, (255, 20, 0, 0)) |
| color.putalpha(alpha) |
| return Image.alpha_composite(base.convert("RGBA"), color).convert("RGB") |
|
|
|
|
| def draw_detection_overlay( |
| image: Image.Image, |
| yolo_boxes: list[YoloBox], |
| ijepa_box: tuple[int, int, int, int] | None = None, |
| ijepa_candidate_boxes: list[tuple[int, int, int, int]] | None = None, |
| ) -> Image.Image: |
| canvas = image.convert("RGB").copy() |
| draw = ImageDraw.Draw(canvas) |
| font = ImageFont.load_default() |
|
|
| for box in yolo_boxes: |
| xyxy = box.to_xyxy(*canvas.size) |
| draw.rectangle(xyxy, outline="lime", width=4) |
| draw_label(draw, xyxy, f"YOLO: {box.class_name}", "lime", font) |
|
|
| for index, candidate_box in enumerate(ijepa_candidate_boxes or [], start=1): |
| draw.rectangle(candidate_box, outline="red", width=3) |
| draw_label(draw, candidate_box, f"I-JEPA {index}", "red", font) |
|
|
| if ijepa_box is not None: |
| draw.rectangle(ijepa_box, outline="red", width=4) |
| draw_label(draw, ijepa_box, "I-JEPA estimate", "red", font) |
|
|
| return canvas |
|
|
|
|
| def draw_label(draw: ImageDraw.ImageDraw, box: tuple[int, int, int, int], text: str, color: str, font) -> None: |
| x1, y1, _, _ = box |
| text_box = draw.textbbox((x1, y1), text, font=font) |
| text_h = text_box[3] - text_box[1] |
| label_box = (x1, max(0, y1 - text_h - 6), x1 + text_box[2] - text_box[0] + 8, max(text_h + 6, y1)) |
| draw.rectangle(label_box, fill=color) |
| draw.text((label_box[0] + 4, label_box[1] + 3), text, fill="black", font=font) |
|
|