Image-Forensics-Detect / utils /image_utils.py
dk2430098's picture
Upload folder using huggingface_hub
928b74f verified
"""
utils/image_utils.py
--------------------
Shared image loading, resizing, normalization and augmentation utilities
used across all branches and training scripts.
"""
import cv2
import numpy as np
from PIL import Image
import io
from typing import Tuple, Optional
# ─────────────────────────────────────────────
# Constants
# ─────────────────────────────────────────────
# Default resize target handed to cv2.resize / PIL resize, both of which take
# (width, height). 224x224 is the standard CNN / ViT input resolution.
DEFAULT_SIZE = (224, 224)
# Per-channel RGB mean and standard deviation (in [0, 1] units) used by
# normalize_imagenet / denormalize_imagenet.
IMAGENET_MEAN = np.asarray((0.485, 0.456, 0.406), dtype=np.float32)
IMAGENET_STD = np.asarray((0.229, 0.224, 0.225), dtype=np.float32)
# ─────────────────────────────────────────────
# Core Loaders
# ─────────────────────────────────────────────
def load_image_from_path(
    path: str,
    size: Optional[Tuple[int, int]] = DEFAULT_SIZE,
) -> np.ndarray:
    """
    Load an image from disk as a float32 RGB array of shape (H, W, 3) in [0, 1].

    Args:
        path: Filesystem path of the image. A ``str`` or any ``os.PathLike``
            (e.g. ``pathlib.Path``) is accepted.
        size: Target ``(width, height)`` -- the order ``cv2.resize`` expects.
            Pass ``None`` to keep the image's native resolution (no
            resampling at all).

    Returns:
        float32 array of shape (height, width, 3), RGB channel order,
        values in [0, 1].

    Raises:
        FileNotFoundError: if OpenCV cannot read the file. ``cv2.imread``
            signals a missing file, an unreadable file and an
            unsupported/corrupt format the same way (by returning ``None``),
            so all of these surface as this exception.
    """
    import os

    # os.fspath lets callers pass pathlib.Path objects; str is returned as-is.
    bgr = cv2.imread(os.fspath(path))
    if bgr is None:
        raise FileNotFoundError(f"Could not read image at: {path}")
    # OpenCV decodes to BGR; the rest of this module works in RGB.
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    if size is not None:
        rgb = cv2.resize(rgb, size, interpolation=cv2.INTER_AREA)
    return rgb.astype(np.float32) / 255.0
def load_image_from_bytes(
    data: bytes,
    size: Optional[Tuple[int, int]] = DEFAULT_SIZE,
) -> np.ndarray:
    """
    Decode an in-memory image (e.g. an API upload) to a float32 RGB array in [0, 1].

    Args:
        data: Encoded image bytes in any format Pillow can open.
        size: Target ``(width, height)`` -- the order PIL's ``resize``
            expects. Pass ``None`` to keep the image's native resolution.

    Returns:
        float32 array of shape (height, width, 3), RGB channel order,
        values in [0, 1].

    Raises:
        PIL.UnidentifiedImageError: (an ``OSError`` subclass) if Pillow
            cannot recognise the data as an image.
    """
    from PIL import ImageOps

    # The context manager closes Pillow's handle on the BytesIO once the
    # pixels have been materialised below.
    with Image.open(io.BytesIO(data)) as pil_img:
        # cv2.imread (used by load_image_from_path) rotates images according
        # to their EXIF Orientation tag by default; apply the same correction
        # here so uploads are oriented the same way as images read from disk.
        upright = ImageOps.exif_transpose(pil_img)
        rgb = upright.convert("RGB")
    if size is not None:
        # NOTE(review): the disk loader resizes with cv2.INTER_AREA while this
        # uses LANCZOS, so the same file can yield slightly different pixels
        # through the two paths -- confirm which one the models were trained on.
        rgb = rgb.resize(size, Image.LANCZOS)
    return np.asarray(rgb, dtype=np.float32) / 255.0
# ─────────────────────────────────────────────
# Normalization
# ─────────────────────────────────────────────
def normalize_imagenet(img: np.ndarray) -> np.ndarray:
    """
    Standardise an RGB image with the ImageNet channel statistics.

    Every channel c becomes (img[..., c] - IMAGENET_MEAN[c]) / IMAGENET_STD[c];
    the (3,)-shaped constants broadcast over the trailing axis.

    Args:
        img: float32 array of shape (H, W, 3) with values in [0, 1].

    Returns:
        The standardised array, same shape as ``img``.
    """
    centred = img - IMAGENET_MEAN
    return centred / IMAGENET_STD
def denormalize_imagenet(img: np.ndarray) -> np.ndarray:
    """
    Invert ``normalize_imagenet`` so the image can be displayed.

    Scales each channel back by IMAGENET_STD, re-adds IMAGENET_MEAN, and clamps
    the result to [0, 1] so overshoot cannot produce out-of-range pixels.
    """
    restored = img * IMAGENET_STD + IMAGENET_MEAN
    return np.clip(restored, 0.0, 1.0)
# ─────────────────────────────────────────────
# Format Converters
# ─────────────────────────────────────────────
def to_uint8(img: np.ndarray) -> np.ndarray:
    """
    Quantise an image with values nominally in [0, 1] to uint8 [0, 255].

    Values are clipped to [0, 1], scaled by 255 and rounded to the nearest
    integer (ties to even, via ``np.rint``). Rounding instead of the truncation
    a bare ``astype`` performs has two effects:

    - values a hair below k/255 (e.g. after a normalize/denormalize round
      trip) are no longer pushed down to k-1;
    - values close to 1.0, such as 0.999, reach 255 rather than 254.

    NaN entries (e.g. from a 0/0 heatmap normalisation) map to 0 instead of
    an undefined float->uint8 cast.

    Args:
        img: real-valued array of any shape.

    Returns:
        uint8 array of the same shape.
    """
    finite = np.nan_to_num(img, nan=0.0)
    scaled = np.clip(finite, 0.0, 1.0) * 255.0
    return np.rint(scaled).astype(np.uint8)
def to_grayscale(img: np.ndarray) -> np.ndarray:
    """
    Collapse an RGB image (H, W, 3) -> single-channel luminance (H, W), float32.

    Uses the Rec. 601 luma weights 0.2989 R + 0.5870 G + 0.1140 B. Only the
    first three channels are read, so a fourth (alpha) channel is ignored.
    """
    luma_weights = np.asarray([0.2989, 0.5870, 0.1140])
    rgb = img[..., :3]
    return np.dot(rgb, luma_weights).astype(np.float32)
# ─────────────────────────────────────────────
# Image to Tensor helpers
# ─────────────────────────────────────────────
def to_tf_tensor(img: np.ndarray) -> "tf.Tensor":
    """
    Wrap an (H, W, 3) image as a float32 TensorFlow tensor of shape (1, H, W, 3).

    TensorFlow is imported inside the function, so importing this module does
    not require TensorFlow to be installed.
    """
    import tensorflow as tf

    single = tf.constant(img, dtype=tf.float32)
    batched = tf.expand_dims(single, axis=0)
    return batched
def to_torch_tensor(img: np.ndarray) -> "torch.Tensor":
    """
    Convert an HWC image to a float32 PyTorch tensor in NCHW layout (1, C, H, W).

    PyTorch is imported inside the function, so importing this module does not
    require PyTorch to be installed.

    Args:
        img: array of shape (H, W, C) -- typically (H, W, 3) float32 -- or a
            2-D (H, W) grayscale array, which is treated as a single channel.
            Any real dtype and any memory layout are accepted, including
            negative-stride views such as ``img[..., ::-1]``.

    Returns:
        Contiguous float32 tensor of shape (1, C, H, W) that owns its memory,
        so in-place edits to it never write back into ``img``.

    Raises:
        ValueError: if ``img`` is neither 2-D nor 3-D.
    """
    import torch

    arr = np.asarray(img)
    if arr.ndim == 2:
        arr = arr[:, :, np.newaxis]
    if arr.ndim != 3:
        raise ValueError(
            f"Expected an (H, W) or (H, W, C) array, got shape {arr.shape}"
        )
    # np.array always copies. The copy is a native-endian, writable,
    # C-contiguous CHW float32 buffer, which torch.from_numpy can wrap. Wrapping
    # the original view directly fails for negative strides (e.g. channel flips).
    chw = np.array(arr.transpose(2, 0, 1), dtype=np.float32, order="C")
    return torch.from_numpy(chw).unsqueeze(0)
# ─────────────────────────────────────────────
# Overlay / Visualization Helpers
# ─────────────────────────────────────────────
def overlay_heatmap(
    img: np.ndarray,
    heatmap: np.ndarray,
    alpha: float = 0.5,
    colormap: int = cv2.COLORMAP_JET
) -> np.ndarray:
    """
    Blend a colour-mapped heatmap over an RGB image.

    Args:
        img: float32 (H, W, 3) RGB image in [0, 1].
        heatmap: float32 (h, w) map in [0, 1]; an (h, w, 1) array is also
            accepted. If (h, w) differs from (H, W), the map is bilinearly
            resized to the image first (e.g. a low-resolution activation map).
        alpha: weight of the heatmap in the blend; the image gets 1 - alpha.
        colormap: OpenCV colormap id, e.g. cv2.COLORMAP_JET.

    Returns:
        uint8 (H, W, 3) blended image, in RGB order like ``img``.
    """
    img_u8 = to_uint8(img)
    height, width = img_u8.shape[:2]

    heat = np.asarray(heatmap, dtype=np.float32)
    if heat.ndim == 3 and heat.shape[2] == 1:
        heat = heat[:, :, 0]
    if heat.shape != (height, width):
        # cv2.addWeighted requires both inputs to have identical shapes.
        heat = cv2.resize(heat, (width, height), interpolation=cv2.INTER_LINEAR)

    heat_bgr = cv2.applyColorMap(to_uint8(heat), colormap)
    # applyColorMap returns BGR. Convert it so both blend inputs share img's
    # RGB order; otherwise hot regions would show up with red/blue swapped.
    heat_rgb = cv2.cvtColor(heat_bgr, cv2.COLOR_BGR2RGB)
    return cv2.addWeighted(img_u8, 1 - alpha, heat_rgb, alpha, 0)
def encode_image_to_base64(img_array: np.ndarray, quality: int = 95) -> str:
    """
    Encode a uint8 image as a base64 JPEG string for API responses.

    Args:
        img_array: uint8 array, either (H, W, 3) in RGB order or single-channel
            grayscale of shape (H, W) / (H, W, 1).
        quality: JPEG quality, 0-100. The default of 95 is OpenCV's own default
            for IMWRITE_JPEG_QUALITY, so omitting it reproduces the earlier
            output.

    Returns:
        Base64 text of the JPEG bytes (no ``data:`` URI prefix).

    Raises:
        ValueError: if OpenCV reports that JPEG encoding failed.
    """
    import base64

    if img_array.ndim == 2 or img_array.shape[-1] == 1:
        # Single-channel data has no channel order to fix, and cvtColor's
        # RGB2BGR conversion rejects it.
        to_encode = img_array
    else:
        # OpenCV's encoders interpret 3-channel input as BGR.
        to_encode = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
    ok, buf = cv2.imencode(
        ".jpg", to_encode, [int(cv2.IMWRITE_JPEG_QUALITY), int(quality)]
    )
    if not ok:
        raise ValueError("JPEG encoding failed")
    return base64.b64encode(buf.tobytes()).decode("utf-8")