File size: 2,819 Bytes
2bc3168 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 | from __future__ import annotations
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from .obstacle_dataset import YoloBox
def draw_yolo_with_heatmap(
    image: Image.Image,
    yolo_boxes: list[YoloBox],
    heatmap: np.ndarray,
    saliency_threshold: float = 0.7,
) -> Image.Image:
    """Render the saliency heatmap over ``image`` and annotate every YOLO box.

    The heatmap tint is produced by ``overlay_heatmap``; each entry of
    ``yolo_boxes`` is then outlined in lime (4 px) and tagged with a
    ``"YOLO: <class_name>"`` label.

    Args:
        image: Picture to annotate.
        yolo_boxes: Detections to draw. Each is converted with
            ``to_xyxy(width, height)`` using the rendered canvas size
            (presumably yielding pixel corner coordinates -- confirm
            against ``YoloBox``).
        heatmap: Saliency map forwarded to ``overlay_heatmap``.
        saliency_threshold: Threshold forwarded to ``overlay_heatmap``.

    Returns:
        The annotated image returned by ``overlay_heatmap`` with the boxes
        and labels drawn on it.
    """
    annotated = overlay_heatmap(image, heatmap, saliency_threshold=saliency_threshold)
    pen = ImageDraw.Draw(annotated)
    label_font = ImageFont.load_default()

    for detection in yolo_boxes:
        corners = detection.to_xyxy(*annotated.size)
        pen.rectangle(corners, outline="lime", width=4)
        caption = f"YOLO: {detection.class_name}"
        draw_label(pen, corners, caption, "lime", label_font)

    return annotated
def overlay_heatmap(
    image: Image.Image,
    heatmap: np.ndarray,
    saliency_threshold: float = 0.7,
) -> Image.Image:
    """Tint the salient regions of ``image`` red according to ``heatmap``.

    The heatmap is clipped to ``[0, 1]``, quantised to 8 bits, and resized to
    the image size. Pixels whose resized heat is below ``saliency_threshold``
    are left untouched; from the threshold upwards the tint opacity ramps
    linearly from 0 to 210 (out of 255) at a heat of 1.

    Args:
        image: Picture to tint; it is converted to RGB first.
        heatmap: Saliency values, turned into a single-channel ("L") image,
            so presumably a 2-D array -- confirm against callers. NaN entries
            are treated as zero heat; infinities saturate at the clip bounds.
        saliency_threshold: Heat level at which the tint starts to appear.

    Returns:
        An RGB image with the red overlay composited on top.
    """
    base = image.convert("RGB")

    # NaN passes straight through np.clip, and casting NaN to uint8 is
    # undefined (NumPy warns and the result is platform dependent). Map it to
    # zero heat up front; this commonly shows up when a constant heatmap has
    # been min-max normalised (0 / 0).
    heat = np.clip(np.nan_to_num(heatmap, nan=0.0), 0, 1)

    heat_img = Image.fromarray((heat * 255).astype(np.uint8), mode="L").resize(base.size)
    heat_values = np.asarray(heat_img).astype(np.float32) / 255.0

    # The max() guard keeps the ramp finite when the threshold is 1 or more.
    ramp_span = max(1e-6, 1 - saliency_threshold)
    alpha_values = np.where(
        heat_values >= saliency_threshold,
        ((heat_values - saliency_threshold) / ramp_span) * 210,
        0,
    )
    alpha = Image.fromarray(alpha_values.astype(np.uint8), mode="L")

    # Solid red layer whose per-pixel opacity is the ramp computed above.
    color = Image.new("RGBA", base.size, (255, 20, 0, 0))
    color.putalpha(alpha)
    return Image.alpha_composite(base.convert("RGBA"), color).convert("RGB")
def draw_detection_overlay(
    image: Image.Image,
    yolo_boxes: list[YoloBox],
    ijepa_box: tuple[int, int, int, int] | None = None,
    ijepa_candidate_boxes: list[tuple[int, int, int, int]] | None = None,
) -> Image.Image:
    """Draw YOLO detections and I-JEPA boxes on an RGB copy of ``image``.

    Drawing order, so later items sit on top of earlier ones:

    1. every YOLO box: lime outline (4 px), label ``"YOLO: <class_name>"``;
    2. every candidate box: red outline (3 px), label ``"I-JEPA <n>"`` with
       ``n`` counting from 1 in iteration order;
    3. the single ``ijepa_box``, if given: red outline (4 px), label
       ``"I-JEPA estimate"``.

    Args:
        image: Source picture; drawing happens on a converted copy.
        yolo_boxes: Detections, converted with ``to_xyxy(width, height)`` of
            the canvas.
        ijepa_box: Optional box, passed to the drawing calls as given.
        ijepa_candidate_boxes: Optional candidate boxes, passed as given.

    Returns:
        The annotated RGB image.
    """
    annotated = image.convert("RGB").copy()
    pen = ImageDraw.Draw(annotated)
    label_font = ImageFont.load_default()

    def outline_and_tag(rect, caption, colour, thickness):
        # One annotation = rectangle outline plus its caption tab.
        pen.rectangle(rect, outline=colour, width=thickness)
        draw_label(pen, rect, caption, colour, label_font)

    for detection in yolo_boxes:
        corners = detection.to_xyxy(*annotated.size)
        outline_and_tag(corners, f"YOLO: {detection.class_name}", "lime", 4)

    if ijepa_candidate_boxes:
        for rank, rect in enumerate(ijepa_candidate_boxes, start=1):
            outline_and_tag(rect, f"I-JEPA {rank}", "red", 3)

    if ijepa_box is not None:
        outline_and_tag(ijepa_box, "I-JEPA estimate", "red", 4)

    return annotated
def draw_label(draw: ImageDraw.ImageDraw, box: tuple[int, int, int, int], text: str, color: str, font) -> None:
    """Paint ``text`` in black on a ``color``-filled tab at the box's top-left corner.

    The tab normally sits directly above the box. When there is not enough
    room above (the box starts too close to the top edge), the tab is pinned
    to ``y = 0`` instead, so it overlaps the box rather than leaving the canvas.

    Args:
        draw: Drawing context to paint on.
        box: ``(x1, y1, x2, y2)``; only the top-left corner is used.
        text: Caption to render.
        color: Fill colour of the tab behind the caption.
        font: Font used both for measuring and for rendering the caption.
    """
    left, top, _right, _bottom = box

    # Measure the caption to size the tab around it.
    extent = draw.textbbox((left, top), text, font=font)
    caption_height = extent[3] - extent[1]

    tab_top = max(0, top - caption_height - 6)
    tab_bottom = max(caption_height + 6, top)
    tab_right = left + extent[2] - extent[0] + 8

    draw.rectangle((left, tab_top, tab_right, tab_bottom), fill=color)
    draw.text((left + 4, tab_top + 3), text, fill="black", font=font)
|