# JEPA-demo / src /visualization.py
# ddebree's picture
# Upload folder using huggingface_hub
# 2bc3168 verified
from __future__ import annotations
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from .obstacle_dataset import YoloBox
def draw_yolo_with_heatmap(
    image: Image.Image,
    yolo_boxes: list[YoloBox],
    heatmap: np.ndarray,
    saliency_threshold: float = 0.7,
) -> Image.Image:
    """Blend the saliency heatmap onto ``image`` and outline every YOLO box.

    The heatmap is composited first (via ``overlay_heatmap``), so the box
    outlines and their labels are painted on top of the tint.

    Args:
        image: Picture to annotate.
        yolo_boxes: Detections to outline; each must provide
            ``to_xyxy(width, height)`` and ``class_name``.
        heatmap: Saliency scores handed to ``overlay_heatmap``.
        saliency_threshold: Forwarded to ``overlay_heatmap``.

    Returns:
        The annotated image produced from the heatmap overlay.
    """
    annotated = overlay_heatmap(image, heatmap, saliency_threshold=saliency_threshold)
    painter = ImageDraw.Draw(annotated)
    label_font = ImageFont.load_default()

    # Drawing never resizes the canvas, so read its dimensions once.
    width, height = annotated.size
    for detection in yolo_boxes:
        # presumably pixel-space (x1, y1, x2, y2) corners -- confirm in YoloBox
        corners = detection.to_xyxy(width, height)
        painter.rectangle(corners, outline="lime", width=4)
        draw_label(painter, corners, f"YOLO: {detection.class_name}", "lime", label_font)

    return annotated
def overlay_heatmap(
    image: Image.Image,
    heatmap: np.ndarray,
    saliency_threshold: float = 0.7,
) -> Image.Image:
    """Tint the salient regions of ``image`` with a translucent red layer.

    The heatmap is clipped to ``[0, 1]``, quantised to 8 bits, resized to the
    image size and turned into an alpha mask: cells below
    ``saliency_threshold`` stay fully transparent, and opacity ramps linearly
    from 0 at the threshold up to 210 (out of 255) at a score of 1.

    Args:
        image: Picture to tint; it is converted to RGB, not modified in place.
        heatmap: Saliency scores. It is converted to a single-channel ("L")
            image, so a 2-D array is expected. Values outside ``[0, 1]`` are
            clipped and NaN cells are treated as 0 (not salient).
        saliency_threshold: Minimum score that receives any tint.

    Returns:
        A new RGB image with the red overlay composited on top.
    """
    base = image.convert("RGB")

    # NaN passes through np.clip unchanged and has no defined uint8 value, so
    # the cast below could yield arbitrary opacity. Zero those cells first;
    # +/-inf still end up at 1 and 0 after clipping, exactly as before.
    heat = np.clip(np.nan_to_num(heatmap, nan=0.0), 0, 1)

    # Round-trip through an 8-bit image so Pillow can resize the map to the
    # picture's dimensions, then go back to floats in [0, 1].
    heat_img = Image.fromarray((heat * 255).astype(np.uint8), mode="L").resize(base.size)
    heat_values = np.asarray(heat_img).astype(np.float32) / 255.0

    # The 1e-6 floor avoids a division by zero when the threshold is >= 1.
    ramp_span = max(1e-6, 1 - saliency_threshold)
    alpha_values = np.where(
        heat_values >= saliency_threshold,
        ((heat_values - saliency_threshold) / ramp_span) * 210,
        0,
    )
    alpha = Image.fromarray(alpha_values.astype(np.uint8), mode="L")

    # Solid red-orange layer whose per-pixel transparency is the mask above.
    color = Image.new("RGBA", base.size, (255, 20, 0, 0))
    color.putalpha(alpha)
    return Image.alpha_composite(base.convert("RGBA"), color).convert("RGB")
def draw_detection_overlay(
    image: Image.Image,
    yolo_boxes: list[YoloBox],
    ijepa_box: tuple[int, int, int, int] | None = None,
    ijepa_candidate_boxes: list[tuple[int, int, int, int]] | None = None,
) -> Image.Image:
    """Draw YOLO detections and I-JEPA boxes on an RGB copy of ``image``.

    Painting order: YOLO boxes (lime, width 4), then the numbered I-JEPA
    candidates (red, width 3), then the single I-JEPA estimate (red, width 4).

    Args:
        image: Picture to annotate; the drawing happens on a copy.
        yolo_boxes: Detections to outline; each must provide
            ``to_xyxy(width, height)`` and ``class_name``.
        ijepa_box: Optional box labelled "I-JEPA estimate".
        ijepa_candidate_boxes: Optional boxes labelled "I-JEPA 1",
            "I-JEPA 2", ... in list order.

    Returns:
        The annotated RGB image.
    """
    rgb_view = image.convert("RGB")
    canvas = rgb_view.copy()
    painter = ImageDraw.Draw(canvas)
    label_font = ImageFont.load_default()

    width, height = canvas.size
    for detection in yolo_boxes:
        # presumably pixel-space (x1, y1, x2, y2) corners -- confirm in YoloBox
        corners = detection.to_xyxy(width, height)
        painter.rectangle(corners, outline="lime", width=4)
        draw_label(painter, corners, f"YOLO: {detection.class_name}", "lime", label_font)

    # A missing or empty candidate list simply draws nothing.
    if ijepa_candidate_boxes:
        for rank, candidate in enumerate(ijepa_candidate_boxes, start=1):
            painter.rectangle(candidate, outline="red", width=3)
            draw_label(painter, candidate, f"I-JEPA {rank}", "red", label_font)

    if ijepa_box is None:
        return canvas

    painter.rectangle(ijepa_box, outline="red", width=4)
    draw_label(painter, ijepa_box, "I-JEPA estimate", "red", label_font)
    return canvas
def draw_label(draw: ImageDraw.ImageDraw, box: tuple[int, int, int, int], text: str, color: str, font) -> None:
    """Paint a filled caption strip with black text at a box's top-left corner.

    The strip sits directly above the box's top edge. When there is not
    enough room above, it is pinned to y = 0 and extends down to the text
    height plus padding instead.

    Args:
        draw: Drawing context that receives the rectangle and the text.
        box: ``(x1, y1, x2, y2)``; only the first two values are used.
        text: Caption to render.
        color: Fill colour of the strip.
        font: Font used both for measuring and for rendering the caption.
    """
    left, top, _, _ = box

    # Measure the caption at the anchor to learn its rendered extent.
    bbox_left, bbox_top, bbox_right, bbox_bottom = draw.textbbox((left, top), text, font=font)
    text_height = bbox_bottom - bbox_top

    # 6 px of vertical and 8 px of horizontal padding in total.
    strip_top = max(0, top - text_height - 6)
    strip_bottom = max(text_height + 6, top)
    strip_right = left + bbox_right - bbox_left + 8

    draw.rectangle((left, strip_top, strip_right, strip_bottom), fill=color)
    draw.text((left + 4, strip_top + 3), text, fill="black", font=font)