Spaces:
Runtime error
Runtime error
| """ | |
| utils/image_utils.py | |
| -------------------- | |
| Image I/O, mask manipulation, and debug-image helpers. | |
| """ | |
| import os | |
| import math | |
| import zipfile | |
| import requests | |
| import urllib.request | |
| from pathlib import Path | |
| from typing import List, Tuple, Optional, Union | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| from tqdm import tqdm | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as mpatches | |
| # -- Image I/O ----------------------------------------------------------------- | |
def load_image_pil(path: str) -> Image.Image:
    """
    Load an image file as a PIL image in RGB mode.

    The source file is opened in a ``with`` block so its handle is released
    as soon as the RGB copy exists. ``convert`` loads the pixel data and
    returns a new image, so the result stays valid after the source is closed.
    """
    with Image.open(path) as src:
        return src.convert("RGB")
def load_image_cv2(path: str) -> np.ndarray:
    """
    Read ``path`` with OpenCV and return the decoded array (BGR channel order).

    Raises:
        FileNotFoundError: when ``cv2.imread`` returns ``None`` (missing or
            undecodable file).
    """
    bgr = cv2.imread(path)
    if bgr is not None:
        return bgr
    raise FileNotFoundError(f"Cannot read image: {path}")
def pil_to_cv2(img: Image.Image) -> np.ndarray:
    """
    Convert a PIL image into an OpenCV-style 3-channel BGR array.

    The image is normalised to RGB first. Single-channel or palette modes
    (``"L"``, ``"P"``, ...) give a 2-D array from ``np.array``, which
    ``cv2.cvtColor(..., COLOR_RGB2BGR)`` rejects. RGB images take the same
    path as before.
    """
    if img.mode != "RGB":
        img = img.convert("RGB")
    return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
def cv2_to_pil(img: np.ndarray) -> Image.Image:
    """
    Convert an OpenCV array into a PIL image.

    Multi-channel arrays are treated as BGR and swapped to RGB. 2-D
    (single-channel) arrays, such as grayscale images or masks, are wrapped
    directly, because ``COLOR_BGR2RGB`` does not accept 1-channel input.
    """
    if img.ndim == 2:
        return Image.fromarray(img)
    return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
# Alias for compatibility: `load_image` is the same function object as
# `load_image_pil` (returns a PIL RGB image). Presumably kept for callers
# that still use the `load_image` name — confirm before removing.
load_image = load_image_pil
def show_mask(mask, ax, random_color=False):
    """
    Placeholder for SAM mask visualisation.

    Intentionally draws nothing; every argument is ignored and ``None`` is
    returned.
    """
    return None
def show_box(box, ax):
    """
    Placeholder for SAM box visualisation.

    Intentionally draws nothing; both arguments are ignored and ``None`` is
    returned.
    """
    return None
def dilate_mask_with_sam_prediction(mask, dilation_px):
    """
    Placeholder for SAM-guided mask dilation.

    No dilation is performed: ``dilation_px`` is ignored and the very same
    ``mask`` object is handed back.
    """
    del dilation_px  # unused by this stub
    return mask
def save_image(img: Union[Image.Image, np.ndarray], path: str) -> None:
    """
    Write ``img`` to ``path``, creating the parent directory if needed.

    ``np.ndarray`` inputs are written with ``cv2.imwrite``, so they are
    expected in OpenCV (BGR) channel order. Anything else is saved through
    its own ``.save`` method (PIL images).

    Raises:
        OSError: if ``cv2.imwrite`` reports failure. OpenCV signals some write
            failures (e.g. an unwritable path) by returning ``False`` rather
            than raising, so without this check they would go unnoticed.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    if isinstance(img, np.ndarray):
        if not cv2.imwrite(path, img):
            raise OSError(f"cv2.imwrite failed to write image: {path}")
    else:
        img.save(path)
def list_images(directory: str) -> List[str]:
    """
    Return the sorted paths of image files directly inside ``directory``.

    Files are matched by case-insensitive extension (``.jpg``, ``.jpeg``,
    ``.png``, ``.bmp``, ``.webp``, ``.tif``, ``.tiff``). Only regular files
    are returned; a sub-directory whose name happens to end in an image
    extension is skipped, so every path can actually be opened. The search
    is not recursive.
    """
    exts = {".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff"}
    return sorted(
        str(p)
        for p in Path(directory).iterdir()
        if p.suffix.lower() in exts and p.is_file()
    )
| # -- Mask operations ----------------------------------------------------------- | |
def boxes_to_mask(
    boxes: List[Tuple[int, int, int, int]],
    h: int,
    w: int,
    dilation_px: int = 0,
) -> np.ndarray:
    """
    Rasterise ``(x1, y1, x2, y2)`` boxes into a binary uint8 mask of shape (h, w).

    Pixels inside any box become 255 and all others 0. ``x2``/``y2`` are
    exclusive, as in slicing. Boxes are clipped to the image, and a box lying
    completely outside it contributes nothing.

    Coordinates may be ints or floats (detector outputs are often float).
    ``x1``/``y1`` are floored and ``x2``/``y2`` are ceiled, so every pixel the
    box touches is covered.

    Args:
        boxes: box corners in pixel coordinates.
        h: mask height.
        w: mask width.
        dilation_px: if > 0, grow the mask by this many pixels using an
            elliptical structuring element.
    """
    mask = np.zeros((h, w), dtype=np.uint8)
    for x1, y1, x2, y2 in boxes:
        # Clamp every coordinate into [0, w] / [0, h]. A negative x2/y2 must
        # be clamped too: a negative slice end wraps around in NumPy and would
        # paint a large, unrelated region of the mask.
        left = min(w, max(0, math.floor(x1)))
        top = min(h, max(0, math.floor(y1)))
        right = min(w, max(0, math.ceil(x2)))
        bottom = min(h, max(0, math.ceil(y2)))
        if right > left and bottom > top:
            mask[top:bottom, left:right] = 255
    if dilation_px > 0:
        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE, (dilation_px * 2 + 1, dilation_px * 2 + 1)
        )
        mask = cv2.dilate(mask, kernel)
    return mask
def combine_masks(masks: List[np.ndarray]) -> np.ndarray:
    """
    Union several binary uint8 masks with a pixel-wise bitwise OR.

    The result starts as an all-zero array shaped like the first mask and
    accumulates every mask in order.

    Raises:
        ValueError: if ``masks`` is empty.
    """
    if not masks:
        raise ValueError("Empty mask list")
    union = np.zeros_like(masks[0])
    for layer in masks:
        union = cv2.bitwise_or(union, layer)
    return union
def refine_mask_with_sam_prediction(
    raw_mask: np.ndarray,
    sam_masks: List[np.ndarray],
) -> np.ndarray:
    """
    Choose the SAM candidate mask that best overlaps ``raw_mask``.

    Each candidate is compared to ``raw_mask`` by IoU (both cast to bool).
    The first candidate with the strictly highest positive IoU is returned
    as a uint8 0/255 mask. If no candidate overlaps at all (or the list is
    empty), the original ``raw_mask`` object is returned unchanged.
    """
    target = raw_mask.astype(bool)
    top_score = 0.0
    chosen = None
    for candidate in sam_masks:
        cand = candidate.astype(bool)
        overlap = np.logical_and(target, cand).sum()
        combined = np.logical_or(target, cand).sum()
        # The small epsilon keeps the ratio defined when both masks are empty.
        score = overlap / (combined + 1e-8)
        if score > top_score:
            top_score = score
            chosen = cand
    if chosen is None:
        return raw_mask
    return (chosen.astype(np.uint8)) * 255
def dilate_mask(mask: np.ndarray, px: int) -> np.ndarray:
    """
    Grow the non-zero region of ``mask`` by ``px`` pixels.

    Uses an elliptical structuring element of size (2*px+1, 2*px+1). For
    ``px <= 0`` the input object is returned untouched.
    """
    if px > 0:
        side = 2 * px + 1
        ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (side, side))
        return cv2.dilate(mask, ellipse)
    return mask
| # -- Debug visualisation ------------------------------------------------------- | |
def save_detection_debug(
    scene_path: str,
    detections: List[dict],
    output_path: str,
) -> None:
    """
    Draw labelled bounding boxes on the scene image and save the figure.

    Args:
        scene_path: image to annotate (loaded as RGB via ``load_image_pil``).
        detections: dicts with keys ``box`` ``(x1, y1, x2, y2)``, ``label``
            and ``score``. Each detection is coloured from the ``tab10``
            palette, cycling after ten entries.
        output_path: destination image file. Its parent directory is created
            if missing.
    """
    img = load_image_pil(scene_path)
    # `plt.cm.get_cmap` (i.e. `matplotlib.cm.get_cmap`) was deprecated in
    # Matplotlib 3.7 and removed in 3.9. `pyplot.get_cmap` works on both older
    # and current releases.
    colors = plt.get_cmap("tab10").colors
    fig, ax = plt.subplots(1, figsize=(12, 8))
    try:
        ax.imshow(img)
        for i, det in enumerate(detections):
            x1, y1, x2, y2 = det["box"]
            color = colors[i % len(colors)]
            rect = mpatches.FancyBboxPatch(
                (x1, y1), x2 - x1, y2 - y1,
                boxstyle="round,pad=2",
                linewidth=2, edgecolor=color, facecolor="none",
            )
            ax.add_patch(rect)
            # Label sits just above the box's top-left corner.
            ax.text(
                x1, y1 - 6,
                f"{det['label']} ({det['score']:.2f})",
                color="white", fontsize=9,
                bbox=dict(facecolor=color, alpha=0.7, pad=2, edgecolor="none"),
            )
        ax.axis("off")
        fig.tight_layout()
        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Close this specific figure even on error, so repeated calls do not
        # accumulate open figures.
        plt.close(fig)
def save_mask_debug(
    scene_path: str,
    mask: np.ndarray,
    output_path: str,
) -> None:
    """
    Save the scene with ``mask`` shown as a semi-transparent red tint.

    Pixels where ``mask > 0`` are blended 55% original / 45% colour
    (255, 80, 80). The result is written with ``save_image``.
    """
    rgb = np.array(load_image_pil(scene_path))
    tinted = rgb.copy()
    highlight = mask > 0
    tinted[highlight] = [255, 80, 80]
    mixed = cv2.addWeighted(rgb, 0.55, tinted, 0.45, 0)
    result = Image.fromarray(mixed)
    save_image(result, output_path)
def save_comparison(
    before: Union[Image.Image, np.ndarray],
    after: Union[Image.Image, np.ndarray],
    output_path: str,
    labels: Tuple[str, str] = ("Before", "After"),
) -> None:
    """
    Save a side-by-side before/after figure, with a title above each panel.

    ``np.ndarray`` inputs are treated as OpenCV images and converted with
    ``cv2_to_pil``; PIL images are used as-is. Titles are rendered by
    Matplotlib, so no font files are needed.

    Args:
        before: left-hand image.
        after: right-hand image.
        output_path: destination image file. Its parent directory is created
            if missing.
        labels: titles for the left and right panels.
    """
    if isinstance(before, np.ndarray):
        before = cv2_to_pil(before)
    if isinstance(after, np.ndarray):
        after = cv2_to_pil(after)
    fig, axes = plt.subplots(1, 2, figsize=(14, 7))
    try:
        panels = ((axes[0], before, labels[0]), (axes[1], after, labels[1]))
        for ax, image, title in panels:
            ax.imshow(image)
            ax.set_title(title, fontsize=14)
            ax.axis("off")
        fig.tight_layout()
        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Close this specific figure even on error, so figures do not pile up.
        plt.close(fig)
| # -- Checkpoint downloader ----------------------------------------------------- | |
def download_file(url: str, dest: str, desc: str = "") -> None:
    """
    Download ``url`` to ``dest`` with a progress bar; skip if ``dest`` exists.

    The body is streamed into ``dest + ".part"`` and only renamed to ``dest``
    once complete. The skip check is simply "does ``dest`` exist", so writing
    straight to ``dest`` would let an interrupted download leave a truncated
    file that every later call accepts as finished.

    Args:
        url: source URL.
        dest: destination path. Its parent directory is created if needed.
        desc: label for the log line and progress bar (defaults to the file
            name).

    Raises:
        requests.HTTPError: on a non-success status (``raise_for_status``).
    """
    os.makedirs(os.path.dirname(dest) or ".", exist_ok=True)
    if os.path.exists(dest):
        print(f" [DONE] Already downloaded: {os.path.basename(dest)}")
        return
    label = desc or os.path.basename(dest)
    print(f" v Downloading {label} ...")
    part_path = dest + ".part"
    try:
        # `with` closes the streamed connection even if writing fails.
        with requests.get(url, stream=True, timeout=120) as response:
            response.raise_for_status()
            # A missing or zero content-length becomes None ("unknown size")
            # for tqdm.
            total = int(response.headers.get("content-length", 0)) or None
            with open(part_path, "wb") as f, tqdm(
                total=total, unit="B", unit_scale=True, desc=label
            ) as bar:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
                    bar.update(len(chunk))
        os.replace(part_path, dest)
    finally:
        # After a successful rename this is a no-op. On any failure it removes
        # the partial file, so it cannot be mistaken for a finished download.
        if os.path.exists(part_path):
            os.remove(part_path)
def download_text_file(url: str, dest: str) -> None:
    """
    Fetch a small text/config file from ``url`` into ``dest`` unless it exists.

    The text is written as UTF-8 rather than the platform-default encoding,
    which on some systems cannot represent every character in
    ``resp.text``. It goes through a temporary ``.part`` file that is renamed
    into place, so a failed write never leaves a truncated ``dest`` that
    later calls would skip.

    Raises:
        requests.HTTPError: on a non-success status (``raise_for_status``).
    """
    os.makedirs(os.path.dirname(dest) or ".", exist_ok=True)
    if os.path.exists(dest):
        return
    print(f" Fetching config: {os.path.basename(dest)} ...")
    resp = requests.get(url, timeout=30)
    resp.raise_for_status()
    part_path = dest + ".part"
    try:
        with open(part_path, "w", encoding="utf-8") as f:
            f.write(resp.text)
        os.replace(part_path, dest)
    finally:
        # No-op after a successful rename; cleans up a partial file otherwise.
        if os.path.exists(part_path):
            os.remove(part_path)
def unzip(zip_path: str, dest_dir: str) -> None:
    """Extract every member of the ZIP archive at ``zip_path`` into ``dest_dir``."""
    archive_name = os.path.basename(zip_path)
    print(f" -> Extracting {archive_name} ...")
    with zipfile.ZipFile(zip_path, mode="r") as archive:
        archive.extractall(path=dest_dir)