# mansi-object-detector / image_utils.py
# mansi-2's picture
# Upload 11 files
# eab4d9b verified
"""
utils/image_utils.py
--------------------
Image I/O, mask manipulation, and debug-image helpers.
"""
import os
import math
import zipfile
import requests
import urllib.request
from pathlib import Path
from typing import List, Tuple, Optional, Union
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
# -- Image I/O -----------------------------------------------------------------
def load_image_pil(path: str) -> Image.Image:
    """Load the image at *path* and return it as a PIL image in RGB mode.

    The file is opened in a ``with`` block so its handle is released as soon
    as the pixels have been decoded by ``convert``. ``Image.open`` is lazy, and
    for multi-frame formats the handle would otherwise stay open until
    garbage collection. The returned image does not depend on the file.
    """
    with Image.open(path) as img:
        return img.convert("RGB")
def load_image_cv2(path: str) -> np.ndarray:
    """Read an image file with OpenCV and return its pixels as a BGR array.

    Raises:
        FileNotFoundError: if ``cv2.imread`` returns None for *path*.
    """
    bgr = cv2.imread(path)
    if bgr is not None:
        return bgr
    raise FileNotFoundError(f"Cannot read image: {path}")
def pil_to_cv2(img: Image.Image) -> np.ndarray:
    """Convert a PIL image to an OpenCV-style 3-channel BGR array.

    Images not already in RGB mode (e.g. ``L``, ``P``, ``RGBA``, ``CMYK``) are
    first converted to RGB. Otherwise single-channel modes would make
    ``cvtColor`` fail and 4-channel CMYK data would be misread as RGBA.
    RGB input is handled exactly as before.
    """
    if img.mode != "RGB":
        img = img.convert("RGB")
    return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
def cv2_to_pil(img: np.ndarray) -> Image.Image:
    """Convert an OpenCV-style array to a PIL image.

    Multi-channel arrays are treated as BGR(A) and converted to RGB.
    Single-channel arrays, either (H, W) or (H, W, 1), have no channel order
    to swap and are passed to PIL directly. ``cv2.COLOR_BGR2RGB`` would
    reject them.
    """
    if img.ndim == 2 or (img.ndim == 3 and img.shape[2] == 1):
        return Image.fromarray(img.reshape(img.shape[:2]))
    return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
# Alias for compatibility: ``load_image`` is bound to the same function object
# as ``load_image_pil`` (loads an image file and returns it as PIL RGB).
load_image = load_image_pil
def show_mask(mask, ax, random_color=False):
    """Placeholder for SAM mask visualisation.

    Accepts a mask, an axes object and a ``random_color`` flag, but
    intentionally draws nothing and always returns None.
    """
    return None
def show_box(box, ax):
    """Placeholder for SAM box visualisation.

    Accepts a box and an axes object but intentionally draws nothing and
    always returns None.
    """
    return None
def dilate_mask_with_sam_prediction(mask, dilation_px):
    """Placeholder for SAM-guided mask dilation.

    No dilation is performed yet. *dilation_px* is ignored and the very same
    *mask* object (not a copy) is handed back.
    """
    unchanged = mask
    return unchanged
def save_image(img: Union[Image.Image, np.ndarray], path: str) -> None:
    """Write *img* to *path*, creating missing parent directories first.

    NumPy arrays are written with ``cv2.imwrite`` and are therefore expected
    in OpenCV channel order (BGR). Anything else is saved via its ``save``
    method (PIL images).

    Raises:
        OSError: if ``cv2.imwrite`` reports failure for an array.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    if isinstance(img, np.ndarray):
        # cv2.imwrite signals some failures (e.g. an unwritable path) by
        # returning False rather than raising. Surface them like PIL would.
        if not cv2.imwrite(path, img):
            raise OSError(f"cv2.imwrite failed to write image: {path}")
    else:
        img.save(path)
def list_images(directory: str) -> List[str]:
    """Return the sorted paths of image files directly inside *directory*.

    Files are matched by case-insensitive extension. The search is not
    recursive, and sub-directories are skipped even when their names end in
    an image extension (e.g. a folder called ``frames.png``).
    """
    # ".tif" is the common spelling of ".tiff" and was previously missed.
    exts = {".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff"}
    return sorted(
        str(p)
        for p in Path(directory).iterdir()
        # Check the cheap suffix first so is_file() only stats candidates.
        if p.suffix.lower() in exts and p.is_file()
    )
# -- Mask operations -----------------------------------------------------------
def boxes_to_mask(
    boxes: List[Tuple[int, int, int, int]],
    h: int,
    w: int,
    dilation_px: int = 0,
) -> np.ndarray:
    """
    Convert list of (x1,y1,x2,y2) boxes to a binary uint8 mask (HW).

    Pixels covered by at least one box are set to 255, all others to 0.
    x2/y2 are exclusive. Coordinates may be ints or floats. They are clipped
    to the image bounds, and fractional values are rounded outward (floor for
    x1/y1, ceil for x2/y2) so the mask fully covers each box. Boxes lying
    completely outside the image contribute nothing.
    Optionally dilate the mask by `dilation_px` pixels (elliptical kernel).
    """
    mask = np.zeros((h, w), dtype=np.uint8)
    for x1, y1, x2, y2 in boxes:
        # Clamp every coordinate into [0, w] / [0, h]. Clamping only one side
        # lets a negative x2/y2 act as a Python "from the end" slice index,
        # which would paint most of the image for an off-screen box.
        left = min(max(math.floor(x1), 0), w)
        top = min(max(math.floor(y1), 0), h)
        right = min(max(math.ceil(x2), 0), w)
        bottom = min(max(math.ceil(y2), 0), h)
        if right > left and bottom > top:
            mask[top:bottom, left:right] = 255
    if dilation_px > 0:
        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE, (dilation_px * 2 + 1, dilation_px * 2 + 1)
        )
        mask = cv2.dilate(mask, kernel)
    return mask
def combine_masks(masks: List[np.ndarray]) -> np.ndarray:
    """Return the pixel-wise OR of all masks in *masks*.

    The result starts as zeros shaped/typed like the first mask, and each
    mask is folded in with ``cv2.bitwise_or``.

    Raises:
        ValueError: if *masks* is empty.
    """
    if not masks:
        raise ValueError("Empty mask list")
    merged = np.zeros_like(masks[0])
    for current in masks:
        merged = cv2.bitwise_or(merged, current)
    return merged
def refine_mask_with_sam_prediction(
    raw_mask: np.ndarray,
    sam_masks: List[np.ndarray],
) -> np.ndarray:
    """
    Pick the SAM mask that best matches *raw_mask* and return it as uint8.

    Each candidate in *sam_masks* is compared to *raw_mask* by
    intersection-over-union after both are cast to bool (any non-zero pixel is
    foreground). The candidate with the highest positive IoU wins; on ties the
    earliest one is kept. If *sam_masks* is empty or no candidate overlaps
    *raw_mask*, *raw_mask* itself is used. Either way the result is a new
    uint8 array with values 0/255, so its format does not depend on which
    branch was taken.
    """
    raw_bool = raw_mask.astype(bool)
    best_bool = raw_bool
    best_iou = 0.0
    for m in sam_masks:
        m_bool = m.astype(bool)
        intersection = (raw_bool & m_bool).sum()
        union = (raw_bool | m_bool).sum()
        # The epsilon avoids 0/0 when both masks are empty.
        iou = intersection / (union + 1e-8)
        if iou > best_iou:
            best_iou = iou
            best_bool = m_bool
    return best_bool.astype(np.uint8) * 255
def dilate_mask(mask: np.ndarray, px: int) -> np.ndarray:
    """Grow the foreground of *mask* by *px* pixels with an elliptical kernel.

    When *px* <= 0 the input array itself (not a copy) is returned.
    """
    if px > 0:
        diameter = 2 * px + 1
        ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (diameter, diameter))
        return cv2.dilate(mask, ellipse)
    return mask
# -- Debug visualisation -------------------------------------------------------
def save_detection_debug(
    scene_path: str,
    detections: List[dict],
    output_path: str,
) -> None:
    """
    Draw bounding boxes + labels on the scene image and save.

    `detections` is a list of dicts with keys: box (x1,y1,x2,y2), label, score.
    Each detection gets a rounded rectangle in a tab10 colour (cycling after
    10 detections) and a "label (score)" caption just above its top-left
    corner. The figure is saved to `output_path` at 150 dpi, and missing parent
    directories are created.
    """
    img = load_image_pil(scene_path)
    # pyplot.get_cmap works on all matplotlib versions, whereas plt.cm.get_cmap
    # (matplotlib.cm.get_cmap) was removed in matplotlib 3.9.
    colors = plt.get_cmap("tab10").colors
    fig, ax = plt.subplots(1, figsize=(12, 8))
    try:
        ax.imshow(img)
        for i, det in enumerate(detections):
            x1, y1, x2, y2 = det["box"]
            color = colors[i % len(colors)]
            rect = mpatches.FancyBboxPatch(
                (x1, y1), x2 - x1, y2 - y1,
                boxstyle="round,pad=2",
                linewidth=2, edgecolor=color, facecolor="none",
            )
            ax.add_patch(rect)
            ax.text(
                x1, y1 - 6,
                f"{det['label']} ({det['score']:.2f})",
                color="white", fontsize=9,
                bbox=dict(facecolor=color, alpha=0.7, pad=2, edgecolor="none"),
            )
        ax.axis("off")
        fig.tight_layout()
        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Close exactly this figure, even if drawing or saving raised, so that
        # repeated calls do not accumulate open figures.
        plt.close(fig)
def save_mask_debug(
    scene_path: str,
    mask: np.ndarray,
    output_path: str,
) -> None:
    """Save the scene with *mask* drawn as a translucent red tint.

    Pixels where ``mask > 0`` are painted (255, 80, 80) on a copy of the RGB
    scene. That copy is blended with the untouched scene (55% original, 45%
    painted) and written to *output_path*. Assumes *mask* has the scene's
    height and width — TODO confirm with callers.
    """
    scene_rgb = np.array(load_image_pil(scene_path))
    painted = scene_rgb.copy()
    painted[mask > 0] = [255, 80, 80]
    result = cv2.addWeighted(scene_rgb, 0.55, painted, 0.45, 0)
    save_image(Image.fromarray(result), output_path)
def save_comparison(
    before: Union[Image.Image, np.ndarray],
    after: Union[Image.Image, np.ndarray],
    output_path: str,
    labels: Tuple[str, str] = ("Before", "After"),
) -> None:
    """Save a side-by-side before/after comparison figure.

    NumPy inputs are treated as OpenCV (BGR) arrays and converted with
    ``cv2_to_pil``. PIL inputs are shown as-is. The left panel shows *before*
    titled ``labels[0]``, and the right panel shows *after* titled
    ``labels[1]``. The figure is saved to *output_path* at 150 dpi, and missing
    parent directories are created.
    """
    if isinstance(before, np.ndarray):
        before = cv2_to_pil(before)
    if isinstance(after, np.ndarray):
        after = cv2_to_pil(after)
    # Titles are drawn with matplotlib to avoid depending on a font file.
    fig, axes = plt.subplots(1, 2, figsize=(14, 7))
    try:
        panels = ((axes[0], before, labels[0]), (axes[1], after, labels[1]))
        for ax, image, title in panels:
            ax.imshow(image)
            ax.set_title(title, fontsize=14)
            ax.axis("off")
        fig.tight_layout()
        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Always release this figure, even if drawing or saving fails.
        plt.close(fig)
# -- Checkpoint downloader -----------------------------------------------------
def download_file(url: str, dest: str, desc: str = "") -> None:
    """Download *url* to *dest* with a progress bar, unless *dest* exists.

    The body is streamed into ``<dest>.part`` and renamed to *dest* only after
    the transfer completes. An interrupted or failed download therefore never
    leaves a truncated file at *dest*, which a later call would otherwise skip
    as "Already downloaded". The partial file is removed on any error.

    Raises:
        requests.HTTPError: for a non-success HTTP status.
    """
    os.makedirs(os.path.dirname(dest) or ".", exist_ok=True)
    if os.path.exists(dest):
        print(f" [DONE] Already downloaded: {os.path.basename(dest)}")
        return
    label = desc or os.path.basename(dest)
    print(f" v Downloading {label} ...")
    tmp_path = dest + ".part"
    try:
        with requests.get(url, stream=True, timeout=120) as response:
            response.raise_for_status()
            # A missing content-length gives 0; None makes tqdm show an
            # open-ended counter instead of a bogus 0-byte total.
            total = int(response.headers.get("content-length", 0)) or None
            with open(tmp_path, "wb") as f, tqdm(
                total=total, unit="B", unit_scale=True, desc=label
            ) as bar:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
                        bar.update(len(chunk))
        os.replace(tmp_path, dest)
    except BaseException:
        # Also covers KeyboardInterrupt: drop the partial file, then re-raise.
        try:
            os.remove(tmp_path)
        except FileNotFoundError:
            pass
        raise
def download_text_file(url: str, dest: str) -> None:
    """Fetch a small text/config file from *url* into *dest* unless it exists.

    The response body is written as raw bytes. Decoding via ``resp.text``
    depends on requests' charset guess (it may assume ISO-8859-1 for ``text/*``
    responses without a charset), and re-encoding with the platform default
    can corrupt non-ASCII content. Writing bytes keeps the file identical to
    what the server sent. The data goes to ``<dest>.part`` first and is then
    renamed, so a failed write never leaves a partial *dest* behind.

    Raises:
        requests.HTTPError: for a non-success HTTP status.
    """
    os.makedirs(os.path.dirname(dest) or ".", exist_ok=True)
    if os.path.exists(dest):
        return
    print(f" Fetching config: {os.path.basename(dest)} ...")
    resp = requests.get(url, timeout=30)
    resp.raise_for_status()
    tmp_path = dest + ".part"
    with open(tmp_path, "wb") as f:
        f.write(resp.content)
    os.replace(tmp_path, dest)
def unzip(zip_path: str, dest_dir: str) -> None:
    """Extract every member of the archive at *zip_path* into *dest_dir*.

    Prints a one-line progress message before extraction starts.
    """
    archive_name = os.path.basename(zip_path)
    print(f" -> Extracting {archive_name} ...")
    with zipfile.ZipFile(zip_path, mode="r") as archive:
        archive.extractall(path=dest_dir)