File size: 5,319 Bytes
928b74f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""
utils/image_utils.py
--------------------
Shared image loading, resizing, normalization and augmentation utilities
used across all branches and training scripts.
"""

import cv2
import numpy as np
from PIL import Image
import io
from typing import Tuple, Optional


# ─────────────────────────────────────────────
# Constants
# ─────────────────────────────────────────────
# Target resize shape, given as (width, height) -- the order both cv2.resize
# and PIL.Image.resize take. It is square, so the order is moot for the default.
DEFAULT_SIZE = (224, 224)   # Standard input size for CNN / ViT

# Per-channel RGB statistics used by the ImageNet normalization helpers below;
# they apply to images that have already been scaled to [0, 1].
IMAGENET_MEAN = np.asarray((0.485, 0.456, 0.406), dtype=np.float32)
IMAGENET_STD = np.asarray((0.229, 0.224, 0.225), dtype=np.float32)


# ─────────────────────────────────────────────
# Core Loaders
# ─────────────────────────────────────────────

def load_image_from_path(path: str, size: Tuple[int, int] = DEFAULT_SIZE) -> np.ndarray:
    """
    Load an image from disk as a float32 RGB array with values in [0, 1].

    Args:
        path: Filesystem path to the image. A plain ``str`` or any
            ``os.PathLike`` (e.g. ``pathlib.Path``) is accepted.
        size: Target size as ``(width, height)``, the order ``cv2.resize``
            expects.

    Returns:
        float32 array of shape (height, width, 3), RGB channel order,
        values in [0, 1].

    Raises:
        FileNotFoundError: if OpenCV cannot read or decode the file.
            ``cv2.imread`` signals a missing file, an unreadable file and an
            unsupported format alike by returning ``None``.
    """
    import os

    # Normalize PathLike objects to str; some OpenCV Python builds only
    # accept str filenames and raise TypeError for pathlib.Path.
    path = os.fspath(path)
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f"Could not read image at: {path}")
    # OpenCV decodes to BGR; the rest of this module works in RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, size, interpolation=cv2.INTER_AREA)
    return img.astype(np.float32) / 255.0


def load_image_from_bytes(data: bytes, size: Tuple[int, int] = DEFAULT_SIZE) -> np.ndarray:
    """
    Load an image from raw bytes (e.g., an API upload) as a float32 RGB array in [0, 1].

    Args:
        data: Encoded image bytes in any format Pillow can decode.
        size: Target size as ``(width, height)``, the order
            ``PIL.Image.resize`` expects.

    Returns:
        float32 array of shape (height, width, 3), RGB channel order,
        values in [0, 1].

    Raises:
        PIL.UnidentifiedImageError: if Pillow cannot identify the image format.
    """
    from PIL import ImageOps

    with Image.open(io.BytesIO(data)) as src:
        # Honour the EXIF Orientation tag. cv2.imread (used by
        # load_image_from_path) rotates according to EXIF by default, while
        # Pillow does not. Without this, the same JPEG (e.g. a phone photo)
        # would come out rotated differently depending on which loader is used.
        pil_img = ImageOps.exif_transpose(src).convert("RGB")
    # NOTE(review): this loader resamples with LANCZOS, while
    # load_image_from_path uses cv2 INTER_AREA, so the two give slightly
    # different pixels for the same file. Kept unchanged to avoid altering API
    # outputs -- confirm whether training and serving should share one filter.
    pil_img = pil_img.resize(size, Image.LANCZOS)
    return np.array(pil_img, dtype=np.float32) / 255.0


# ─────────────────────────────────────────────
# Normalization
# ─────────────────────────────────────────────

def normalize_imagenet(img: np.ndarray) -> np.ndarray:
    """
    Standardize an RGB image with the ImageNet per-channel statistics.

    Subtracts IMAGENET_MEAN and then divides by IMAGENET_STD. Both broadcast
    along the trailing (channel) axis.

    Input:  float32 (H, W, 3) with values in [0, 1].
    Output: normalized (H, W, 3) array; float32 for float32 input.
    """
    centred = img - IMAGENET_MEAN
    return centred / IMAGENET_STD


def denormalize_imagenet(img: np.ndarray) -> np.ndarray:
    """
    Undo ImageNet normalization so the image can be visualized.

    Multiplies by IMAGENET_STD and adds IMAGENET_MEAN. The result is then
    clipped to [0, 1].
    """
    restored = img * IMAGENET_STD + IMAGENET_MEAN
    return np.clip(restored, 0.0, 1.0)


# ─────────────────────────────────────────────
# Format Converters
# ─────────────────────────────────────────────

def to_uint8(img: np.ndarray) -> np.ndarray:
    """
    Convert a float image in [0, 1] to uint8 in [0, 255].

    Values outside [0, 1] are clipped first. The scaled values are then
    rounded to the nearest integer rather than truncated. As a result, values
    just below 1.0 map to 255, and a uint8 -> /255 -> uint8 round trip
    reproduces the original pixels.

    Args:
        img: Float array of any shape, nominally in [0, 1].

    Returns:
        uint8 array with the same shape as ``img``.
    """
    scaled = np.clip(img, 0.0, 1.0) * 255.0
    # np.rint rounds half to even. The values are already within [0, 255],
    # so the cast to uint8 cannot wrap around.
    return np.rint(scaled).astype(np.uint8)


def to_grayscale(img: np.ndarray) -> np.ndarray:
    """
    Convert an RGB image (H, W, 3) to a single-channel (H, W) float32 image.

    Uses the weights 0.2989 R + 0.5870 G + 0.1140 B. Any channels beyond the
    first three (e.g. alpha) are ignored.
    """
    weights = np.array([0.2989, 0.5870, 0.1140])
    rgb = img[..., :3]
    gray = np.dot(rgb, weights)
    return gray.astype(np.float32)


# ─────────────────────────────────────────────
# Image to Tensor helpers
# ─────────────────────────────────────────────

def to_tf_tensor(img: np.ndarray) -> "tf.Tensor":
    """
    Wrap an (H, W, 3) image as a batched float32 TensorFlow tensor of shape (1, H, W, 3).

    TensorFlow is imported inside the function, so importing this module
    does not require TensorFlow to be installed.
    """
    import tensorflow as tf

    single = tf.constant(img, dtype=tf.float32)
    batched = tf.expand_dims(single, axis=0)
    return batched


def to_torch_tensor(img: np.ndarray) -> "torch.Tensor":
    """
    Convert an (H, W, 3) image array to a float32 PyTorch tensor of shape (1, 3, H, W).

    The input is first made C-contiguous float32 via ``np.ascontiguousarray``.
    This allows views with negative strides, such as a horizontal flip
    ``img[:, ::-1]``; ``torch.from_numpy`` rejects those with ValueError.
    After the HWC -> CHW permute, the tensor is made contiguous so downstream
    ops that require contiguous memory work directly.

    PyTorch is imported lazily so this module doesn't hard-require it.

    Returns:
        torch.float32 tensor of shape (1, 3, H, W).
    """
    import torch

    arr = np.ascontiguousarray(img, dtype=np.float32)
    # HWC -> CHW, then prepend the batch axis.
    return torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).contiguous()


# ─────────────────────────────────────────────
# Overlay / Visualization Helpers
# ─────────────────────────────────────────────

def overlay_heatmap(
    img: np.ndarray,
    heatmap: np.ndarray,
    alpha: float = 0.5,
    colormap: int = cv2.COLORMAP_JET
) -> np.ndarray:
    """
    Blend a colour-mapped heatmap over an RGB image.

    Args:
        img: float (H, W, 3) RGB image in [0, 1].
        heatmap: float (h, w) map in [0, 1]. If its spatial size differs from
            the image (e.g. a low-resolution saliency map), it is resized to
            (H, W) with bilinear interpolation first.
        alpha: Weight of the heatmap in the blend; the image gets ``1 - alpha``.
        colormap: OpenCV colormap id passed to ``cv2.applyColorMap``.

    Returns:
        uint8 (H, W, 3) blended image in RGB order, matching ``img``.
    """
    img_u8 = to_uint8(img)
    height, width = img_u8.shape[:2]

    heat = np.asarray(heatmap, dtype=np.float32)
    if heat.shape[:2] != (height, width):
        # cv2.resize takes dsize as (width, height).
        heat = cv2.resize(heat, (width, height), interpolation=cv2.INTER_LINEAR)

    # applyColorMap takes a single-channel uint8 map and returns a BGR image.
    # Convert it to RGB before blending with the RGB image; otherwise the red
    # and blue channels of the colormap end up swapped in the result.
    heat_bgr = cv2.applyColorMap(to_uint8(heat), colormap)
    heat_rgb = cv2.cvtColor(heat_bgr, cv2.COLOR_BGR2RGB)

    return cv2.addWeighted(img_u8, 1 - alpha, heat_rgb, alpha, 0)


def encode_image_to_base64(img_array: np.ndarray) -> str:
    """
    Encode a uint8 image as a base64 JPEG string for API responses.

    Args:
        img_array: uint8 array, either (H, W, 3) in RGB order or (H, W)
            grayscale.

    Returns:
        Base64 text of the JPEG-encoded bytes.

    Raises:
        ValueError: if OpenCV reports that JPEG encoding failed.
    """
    import base64

    # OpenCV encoders expect BGR for colour images. A 2-D grayscale array is
    # encoded directly, because COLOR_RGB2BGR rejects single-channel input.
    if img_array.ndim == 2:
        to_encode = img_array
    else:
        to_encode = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
    ok, buf = cv2.imencode(".jpg", to_encode)
    if not ok:
        raise ValueError("JPEG encoding failed")
    return base64.b64encode(buf.tobytes()).decode("utf-8")