"""
utils/image_utils.py
--------------------
Image I/O, mask manipulation, and debug-image helpers.
"""

import os
import math
import zipfile
import requests
import urllib.request
from pathlib import Path
from typing import List, Tuple, Optional, Union

import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches


# -- Image I/O -----------------------------------------------------------------

def load_image_pil(path: str) -> Image.Image:
    """Load the image at `path` and return it as a PIL image in RGB mode."""
    # `convert` returns a fully loaded copy, so the underlying file handle can
    # be released immediately instead of waiting for garbage collection.
    with Image.open(path) as img:
        return img.convert("RGB")


def load_image_cv2(path: str) -> np.ndarray:
    """
    Load the image at `path` with OpenCV (BGR channel order).

    Raises FileNotFoundError when `cv2.imread` returns None, i.e. the file is
    missing or cannot be decoded.
    """
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f"Cannot read image: {path}")
    return img


def pil_to_cv2(img: Image.Image) -> np.ndarray:
    """Convert a PIL image to an OpenCV BGR array."""
    # Normalise to RGB first so that non-RGB inputs (e.g. "L", "RGBA") do not
    # break the RGB->BGR colour conversion.
    return cv2.cvtColor(np.array(img.convert("RGB")), cv2.COLOR_RGB2BGR)


def cv2_to_pil(img: np.ndarray) -> Image.Image:
    """Convert an OpenCV BGR array to a PIL RGB image."""
    return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))


# Alias for compatibility
load_image = load_image_pil


def show_mask(mask, ax, random_color=False):
    """Stub for SAM visualization (intentionally does nothing)."""
    pass


def show_box(box, ax):
    """Stub for SAM visualization (intentionally does nothing)."""
    pass


def dilate_mask_with_sam_prediction(mask, dilation_px):
    """Stub for SAM-based dilation; returns `mask` unchanged."""
    return mask


def save_image(img: Union[Image.Image, np.ndarray], path: str) -> None:
    """
    Save `img` to `path`, creating the parent directory if needed.

    NumPy arrays are written with `cv2.imwrite`; anything else is saved via
    its own `.save(path)` method (PIL images).

    Raises OSError if `cv2.imwrite` reports that the file was not written.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    if isinstance(img, np.ndarray):
        # cv2.imwrite signals failure through its return value; surface it
        # instead of silently producing no file.
        if not cv2.imwrite(path, img):
            raise OSError(f"Failed to write image: {path}")
    else:
        img.save(path)


def list_images(directory: str) -> List[str]:
    """Return sorted list of image file paths in a directory."""
    exts = {".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tiff"}
    # Only regular files: a sub-directory named e.g. "crops.png" is not an image.
    return sorted(
        str(p)
        for p in Path(directory).iterdir()
        if p.is_file() and p.suffix.lower() in exts
    )


# -- Mask operations -----------------------------------------------------------

def boxes_to_mask(
    boxes: List[Tuple[int, int, int, int]],
    h: int,
    w: int,
    dilation_px: int = 0,
) -> np.ndarray:
    """
    Convert list of (x1,y1,x2,y2) boxes to a binary uint8 mask (HW).

    Box interiors are set to 255, everything else is 0. Coordinates are
    rounded to integers and clamped to the image bounds; boxes that end up
    empty after clamping are ignored.
    Optionally dilate the mask by `dilation_px` pixels.
    """
    mask = np.zeros((h, w), dtype=np.uint8)
    for x1, y1, x2, y2 in boxes:
        # Clamp both ends to [0, size]. Without the lower clamp on x2/y2 a
        # box lying fully left of / above the image would yield a negative
        # slice end, which NumPy interprets as "count from the far edge".
        x1, x2 = (min(max(int(round(v)), 0), w) for v in (x1, x2))
        y1, y2 = (min(max(int(round(v)), 0), h) for v in (y1, y2))
        if x2 > x1 and y2 > y1:
            mask[y1:y2, x1:x2] = 255
    return dilate_mask(mask, dilation_px)


def combine_masks(masks: List[np.ndarray]) -> np.ndarray:
    """
    OR-combine a list of binary uint8 masks.

    Raises ValueError when `masks` is empty.
    """
    if not masks:
        raise ValueError("Empty mask list")
    out = np.zeros_like(masks[0])
    for m in masks:
        out = cv2.bitwise_or(out, m)
    return out


def refine_mask_with_sam_prediction(
    raw_mask: np.ndarray,
    sam_masks: List[np.ndarray],
) -> np.ndarray:
    """
    Given SAM predicted masks (each boolean HW), pick the one with the
    highest IoU against the raw_mask and return it as uint8 (0/255).

    If no candidate has a positive IoU (including an empty `sam_masks`),
    `raw_mask` itself is returned unchanged.
    """
    best_mask = raw_mask
    best_iou = 0.0
    raw_bool = raw_mask.astype(bool)
    for m in sam_masks:
        m_bool = m.astype(bool)
        intersection = (raw_bool & m_bool).sum()
        union = (raw_bool | m_bool).sum()
        # Epsilon keeps the ratio defined when both masks are empty.
        iou = intersection / (union + 1e-8)
        if iou > best_iou:
            best_iou = iou
            best_mask = m_bool.astype(np.uint8) * 255
    return best_mask


def dilate_mask(mask: np.ndarray, px: int) -> np.ndarray:
    """
    Dilate `mask` with an elliptical kernel of size (2*px + 1) per side.

    Returns `mask` itself (no copy) when `px` is zero or negative.
    """
    if px <= 0:
        return mask
    kernel = cv2.getStructuringElement(
        cv2.MORPH_ELLIPSE, (px * 2 + 1, px * 2 + 1)
    )
    return cv2.dilate(mask, kernel)


# -- Debug visualisation -------------------------------------------------------

def save_detection_debug(
    scene_path: str,
    detections: List[dict],
    output_path: str,
) -> None:
    """
    Draw bounding boxes + labels on the scene image and save.

    `detections` is a list of dicts with keys:
        box (x1,y1,x2,y2), label, score.
    The parent directory of `output_path` is created if it does not exist.
    """
    img = load_image_pil(scene_path)
    fig, ax = plt.subplots(1, figsize=(12, 8))
    ax.imshow(img)
    # pyplot.get_cmap is available across Matplotlib versions, whereas
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
    colors = plt.get_cmap("tab10").colors
    for i, det in enumerate(detections):
        x1, y1, x2, y2 = det["box"]
        color = colors[i % len(colors)]  # cycle through the palette
        rect = mpatches.FancyBboxPatch(
            (x1, y1),
            x2 - x1,
            y2 - y1,
            boxstyle="round,pad=2",
            linewidth=2,
            edgecolor=color,
            facecolor="none",
        )
        ax.add_patch(rect)
        ax.text(
            x1,
            y1 - 6,
            f"{det['label']} ({det['score']:.2f})",
            color="white",
            fontsize=9,
            bbox=dict(facecolor=color, alpha=0.7, pad=2, edgecolor="none"),
        )
    ax.axis("off")
    fig.tight_layout()
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)


def save_mask_debug(
    scene_path: str,
    mask: np.ndarray,
    output_path: str,
) -> None:
    """Overlay the combined mask on the scene image (red, semi-transparent)."""
    img = np.array(load_image_pil(scene_path))
    overlay = img.copy()
    overlay[mask > 0] = [255, 80, 80]
    # 55% original + 45% tinted copy; pixels outside the mask are unchanged.
    blended = cv2.addWeighted(img, 0.55, overlay, 0.45, 0)
    save_image(Image.fromarray(blended), output_path)


def save_comparison(
    before: Union[Image.Image, np.ndarray],
    after: Union[Image.Image, np.ndarray],
    output_path: str,
    labels: Tuple[str, str] = ("Before", "After"),
) -> None:
    """
    Save a side-by-side before/after comparison image.

    NumPy inputs are treated as OpenCV BGR arrays and converted to PIL RGB.
    The parent directory of `output_path` is created if it does not exist.
    """
    if isinstance(before, np.ndarray):
        before = cv2_to_pil(before)
    if isinstance(after, np.ndarray):
        after = cv2_to_pil(after)

    # Titles are rendered by matplotlib to avoid a font dependency.
    fig, axes = plt.subplots(1, 2, figsize=(14, 7))
    for ax, picture, title in zip(axes, (before, after), labels):
        ax.imshow(picture)
        ax.set_title(title, fontsize=14)
        ax.axis("off")
    fig.tight_layout()
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close(fig)


# -- Checkpoint downloader
# -----------------------------------------------------------------------------

def download_file(url: str, dest: str, desc: str = "") -> None:
    """
    Download a file with a progress bar.

    Does nothing (apart from printing a notice) when `dest` already exists.
    The payload is streamed into `dest + ".part"` and only renamed to `dest`
    once the transfer has completed, so an interrupted download never leaves
    a truncated file that a later call would mistake for a finished one.

    Raises whatever `requests` raises for connection problems or non-2xx
    responses (via `raise_for_status`).
    """
    os.makedirs(os.path.dirname(dest) or ".", exist_ok=True)
    name = os.path.basename(dest)
    if os.path.exists(dest):
        print(f" [DONE] Already downloaded: {name}")
        return

    print(f" v Downloading {desc or name} ...")
    tmp_path = dest + ".part"
    try:
        # The context manager returns the connection to the pool even when
        # the transfer fails part-way through.
        with requests.get(url, stream=True, timeout=120) as response:
            response.raise_for_status()
            # Missing content-length -> None, so tqdm shows a running byte
            # count instead of a progress bar with a bogus total of 0.
            total = int(response.headers.get("content-length", 0)) or None
            with open(tmp_path, "wb") as f, tqdm(
                total=total,
                unit="B",
                unit_scale=True,
                desc=desc or name,
            ) as bar:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
                        bar.update(len(chunk))
        os.replace(tmp_path, dest)
    finally:
        # After a successful rename the temp file no longer exists, so this
        # only removes left-overs from a failed or interrupted download.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def download_text_file(url: str, dest: str) -> None:
    """
    Download a small text/config file.

    Returns silently when `dest` already exists. The response body is written
    as UTF-8 regardless of the platform's default encoding.
    """
    os.makedirs(os.path.dirname(dest) or ".", exist_ok=True)
    if os.path.exists(dest):
        return
    print(f" Fetching config: {os.path.basename(dest)} ...")
    resp = requests.get(url, timeout=30)
    resp.raise_for_status()
    with open(dest, "w", encoding="utf-8") as f:
        f.write(resp.text)


def unzip(zip_path: str, dest_dir: str) -> None:
    """Extract every member of the archive at `zip_path` into `dest_dir`."""
    print(f" -> Extracting {os.path.basename(zip_path)} ...")
    with zipfile.ZipFile(zip_path, "r") as z:
        z.extractall(dest_dir)