"""
utils/image_utils.py
--------------------
Shared image loading, resizing, normalization and augmentation utilities
used across all branches and training scripts.
"""
import cv2
import numpy as np
from PIL import Image
import io
from typing import Tuple, Optional
# ─────────────────────────────────────────────
# Constants
# ─────────────────────────────────────────────
# Standard (width, height) network input size shared by the loaders below.
DEFAULT_SIZE = (224, 224) # Standard input size for CNN / ViT
# Canonical per-channel RGB mean/std of the ImageNet training set, consumed by
# normalize_imagenet / denormalize_imagenet in this module.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
# ─────────────────────────────────────────────
# Core Loaders
# ─────────────────────────────────────────────
def load_image_from_path(path: str, size: Tuple[int, int] = DEFAULT_SIZE) -> np.ndarray:
    """
    Read an image file from disk as a float32 (H, W, 3) RGB array in [0, 1].

    Parameters
    ----------
    path : filesystem path to the image.
    size : (width, height) to resize to; defaults to the module standard.

    Raises
    ------
    FileNotFoundError
        If OpenCV cannot read/decode the file at ``path``.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image at: {path}")
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    resized = cv2.resize(rgb, size, interpolation=cv2.INTER_AREA)
    return resized.astype(np.float32) / 255.0
def load_image_from_bytes(data: bytes, size: Tuple[int, int] = DEFAULT_SIZE) -> np.ndarray:
    """
    Decode an in-memory image (e.g. an API upload) into a float32 RGB array in [0, 1].

    Parameters
    ----------
    data : raw encoded image bytes (JPEG, PNG, ...).
    size : (width, height) to resize to; defaults to the module standard.
    """
    buffer = io.BytesIO(data)
    rgb = Image.open(buffer).convert("RGB")
    resized = rgb.resize(size, Image.LANCZOS)
    return np.asarray(resized, dtype=np.float32) / 255.0
# ─────────────────────────────────────────────
# Normalization
# ─────────────────────────────────────────────
def normalize_imagenet(img: np.ndarray) -> np.ndarray:
    """
    Standardize an image with ImageNet channel statistics.

    Input  : float32 (H, W, 3) array in [0, 1].
    Output : float32 (H, W, 3) array, channel-wise (pixel - mean) / std.
    """
    centered = img - IMAGENET_MEAN
    return centered / IMAGENET_STD
def denormalize_imagenet(img: np.ndarray) -> np.ndarray:
    """Undo ImageNet normalization and clamp to [0, 1] so the image can be displayed."""
    restored = img * IMAGENET_STD + IMAGENET_MEAN
    return np.clip(restored, 0.0, 1.0)
# ─────────────────────────────────────────────
# Format Converters
# ─────────────────────────────────────────────
def to_uint8(img: np.ndarray) -> np.ndarray:
    """Map a float image in [0, 1] to uint8 in [0, 255], clipping out-of-range values first."""
    clipped = np.clip(img, 0.0, 1.0)
    return (clipped * 255).astype(np.uint8)
def to_grayscale(img: np.ndarray) -> np.ndarray:
    """Collapse a float32 RGB image (H, W, 3) to a luminance map (H, W) using BT.601 weights."""
    weights = np.array([0.2989, 0.5870, 0.1140])
    gray = img[..., :3] @ weights
    return gray.astype(np.float32)
# ─────────────────────────────────────────────
# Image to Tensor helpers
# ─────────────────────────────────────────────
def to_tf_tensor(img: np.ndarray) -> "tf.Tensor":
    """
    Turn an (H, W, 3) float32 array into a batched (1, H, W, 3) TensorFlow tensor.

    TensorFlow is imported inside the function so this module does not
    hard-require TF at import time.
    """
    import tensorflow as tf
    tensor = tf.constant(img, dtype=tf.float32)
    return tensor[tf.newaxis, ...]
def to_torch_tensor(img: np.ndarray) -> "torch.Tensor":
    """
    Turn an (H, W, 3) float32 array into a (1, 3, H, W) float32 PyTorch tensor.

    PyTorch is imported lazily so this module stays usable without it installed.
    """
    import torch
    chw = torch.from_numpy(img).permute(2, 0, 1)
    return chw.unsqueeze(0).to(torch.float32)
# ─────────────────────────────────────────────
# Overlay / Visualization Helpers
# ─────────────────────────────────────────────
def overlay_heatmap(
    img: np.ndarray,
    heatmap: np.ndarray,
    alpha: float = 0.5,
    colormap: int = cv2.COLORMAP_JET
) -> np.ndarray:
    """
    Overlay a heatmap on an image.

    Parameters
    ----------
    img      : float32 (H, W, 3) RGB image in [0, 1].
    heatmap  : float32 (H, W) activation map in [0, 1].
    alpha    : blend weight of the heatmap (0 = image only, 1 = heatmap only).
    colormap : OpenCV colormap constant used to colorize the heatmap.

    Returns
    -------
    uint8 (H, W, 3) RGB blended image.
    """
    img_u8 = to_uint8(img)
    # applyColorMap accepts a single-channel uint8 map directly — no need to
    # tile the heatmap to 3 channels and then use only channel 0.
    heat_colored = cv2.applyColorMap(to_uint8(heatmap), colormap)
    # applyColorMap emits BGR; convert so both blend operands are RGB.
    # (Previously the RGB image was blended with a BGR colormap, swapping
    # the red/blue ends of e.g. COLORMAP_JET in the output.)
    heat_colored = cv2.cvtColor(heat_colored, cv2.COLOR_BGR2RGB)
    return cv2.addWeighted(img_u8, 1 - alpha, heat_colored, alpha, 0)
def encode_image_to_base64(img_array: np.ndarray) -> str:
    """
    Encode a uint8 RGB numpy image (H, W, 3) as a base64 JPEG string for API responses.

    Raises
    ------
    ValueError
        If OpenCV fails to encode the image as JPEG (previously the success
        flag was discarded and a garbage buffer could be encoded).
    """
    import base64
    # imencode expects BGR channel order, so convert from the module's RGB convention.
    ok, buf = cv2.imencode(".jpg", cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
    if not ok:
        raise ValueError("JPEG encoding failed for input image")
    return base64.b64encode(buf.tobytes()).decode("utf-8")