""" utils/image_utils.py -------------------- Shared image loading, resizing, normalization and augmentation utilities used across all branches and training scripts. """ import cv2 import numpy as np from PIL import Image import io from typing import Tuple, Optional # ───────────────────────────────────────────── # Constants # ───────────────────────────────────────────── DEFAULT_SIZE = (224, 224) # Standard input size for CNN / ViT IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32) IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32) # ───────────────────────────────────────────── # Core Loaders # ───────────────────────────────────────────── def load_image_from_path(path: str, size: Tuple[int, int] = DEFAULT_SIZE) -> np.ndarray: """ Load an image from disk and return as float32 numpy array (H, W, 3) in [0, 1]. """ img = cv2.imread(path) if img is None: raise FileNotFoundError(f"Could not read image at: {path}") img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img = cv2.resize(img, size, interpolation=cv2.INTER_AREA) return img.astype(np.float32) / 255.0 def load_image_from_bytes(data: bytes, size: Tuple[int, int] = DEFAULT_SIZE) -> np.ndarray: """ Load an image from raw bytes (e.g., API upload) and return float32 [0, 1] array. """ pil_img = Image.open(io.BytesIO(data)).convert("RGB") pil_img = pil_img.resize(size, Image.LANCZOS) arr = np.array(pil_img, dtype=np.float32) / 255.0 return arr # ───────────────────────────────────────────── # Normalization # ───────────────────────────────────────────── def normalize_imagenet(img: np.ndarray) -> np.ndarray: """ Apply ImageNet normalization: (pixel - mean) / std. Input: float32 (H, W, 3) in [0, 1]. Output: normalized float32 (H, W, 3). """ return (img - IMAGENET_MEAN) / IMAGENET_STD def denormalize_imagenet(img: np.ndarray) -> np.ndarray: """Reverse ImageNet normalization for visualization.""" return np.clip(img * IMAGENET_STD + IMAGENET_MEAN, 0.0, 1.0) # ───────────────────────────────────────────── # Format Converters # ───────────────────────────────────────────── def to_uint8(img: np.ndarray) -> np.ndarray: """Convert float32 [0,1] to uint8 [0,255].""" return (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8) def to_grayscale(img: np.ndarray) -> np.ndarray: """Convert float32 RGB (H,W,3) → grayscale (H,W) float32.""" return np.dot(img[..., :3], [0.2989, 0.5870, 0.1140]).astype(np.float32) # ───────────────────────────────────────────── # Image to Tensor helpers # ───────────────────────────────────────────── def to_tf_tensor(img: np.ndarray) -> "tf.Tensor": """ Convert (H,W,3) float32 to TensorFlow tensor with batch dim (1,H,W,3). Import TF lazily so this module doesn't hard-require TF. """ import tensorflow as tf return tf.expand_dims(tf.constant(img, dtype=tf.float32), axis=0) def to_torch_tensor(img: np.ndarray) -> "torch.Tensor": """ Convert (H,W,3) float32 to PyTorch tensor with batch+channel dim (1,3,H,W). """ import torch tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float() return tensor # ───────────────────────────────────────────── # Overlay / Visualization Helpers # ───────────────────────────────────────────── def overlay_heatmap( img: np.ndarray, heatmap: np.ndarray, alpha: float = 0.5, colormap: int = cv2.COLORMAP_JET ) -> np.ndarray: """ Overlay a heatmap on an image. img : float32 (H,W,3) in [0,1] heatmap : float32 (H,W) in [0,1] Returns : uint8 (H,W,3) blended image """ img_u8 = to_uint8(img) heat_u8 = to_uint8(heatmap[:, :, np.newaxis].repeat(3, axis=2)) heat_colored = cv2.applyColorMap(heat_u8[:, :, 0], colormap) blended = cv2.addWeighted(img_u8, 1 - alpha, heat_colored, alpha, 0) return blended def encode_image_to_base64(img_array: np.ndarray) -> str: """ Encode a uint8 numpy image (H,W,3) as a base64 JPEG string for API responses. """ import base64 _, buf = cv2.imencode(".jpg", cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)) return base64.b64encode(buf.tobytes()).decode("utf-8")