File size: 3,539 Bytes
3fffbdc
 
 
 
 
 
 
4b27b63
 
 
 
 
 
3fffbdc
4b27b63
3fffbdc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b27b63
3fffbdc
 
 
 
 
 
 
4b27b63
 
3fffbdc
4b27b63
 
3fffbdc
4b27b63
3fffbdc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b27b63
3fffbdc
 
 
 
 
 
 
4b27b63
 
3fffbdc
4b27b63
 
3fffbdc
4b27b63
 
3fffbdc
4b27b63
 
 
3fffbdc
4b27b63
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""
image_processor.py

Utility functions for image preprocessing used in the mosaic generator:
- Cropping an image so it's divisible by the grid
- Computing LAB cell means for FAISS-based tile matching
"""

import numpy as np
import cv2

from .utils import fast_rgb2lab


def crop_to_multiple(img, grid_n):
    """
    Crop an RGB image so that its width and height are perfectly divisible
    by the chosen grid size.

    Parameters
    ----------
    img : np.ndarray
        RGB image array of shape (H, W, 3).
    grid_n : int
        Number of cells per side in the mosaic grid. NumPy integer scalars
        (e.g. ``np.int64``) are accepted; booleans are rejected.

    Returns
    -------
    np.ndarray
        Cropped RGB image whose dimensions are multiples of `grid_n`. This is
        a slice (view) of `img`, not a copy.

    Raises
    ------
    ValueError
        If `img` is not a valid (H, W, 3) array, if `grid_n` is not a positive
        integer, or if the image is smaller than `grid_n` in either dimension.

    Notes
    -----
    This does NOT resize the image — it simply trims extra pixels from the
    bottom and right edges so that (H % grid_n == 0) and (W % grid_n == 0).
    """
    if img is None or not isinstance(img, np.ndarray):
        raise ValueError("Input image must be a valid NumPy RGB array.")

    if img.ndim != 3 or img.shape[2] != 3:
        raise ValueError(f"Expected image shape (H, W, 3), got {img.shape}.")

    # bool is a subclass of int, so True would otherwise be accepted as a
    # grid size of 1; NumPy integer scalars are legitimate integers, though.
    if (
        isinstance(grid_n, bool)
        or not isinstance(grid_n, (int, np.integer))
        or grid_n <= 0
    ):
        raise ValueError("grid_n must be a positive integer.")
    grid_n = int(grid_n)

    h, w = img.shape[:2]

    if h < grid_n or w < grid_n:
        raise ValueError(
            f"Image too small for grid size {grid_n}. "
            f"Received image of size {w}x{h}."
        )

    # Largest multiples of grid_n that fit inside the original dimensions.
    new_w = (w // grid_n) * grid_n
    new_h = (h // grid_n) * grid_n

    return img[:new_h, :new_w]


def compute_cell_means_lab(img, grid_n):
    """
    Compute LAB mean color for each grid cell in the image.

    Parameters
    ----------
    img : np.ndarray
        Cropped RGB image array (H, W, 3). H and W must be non-zero multiples
        of `grid_n` (see `crop_to_multiple`).
    grid_n : int
        Grid size — number of cells per side. NumPy integer scalars are
        accepted; booleans are rejected.

    Returns
    -------
    means : np.ndarray
        float32 array of shape (grid_n * grid_n, 3): the LAB mean of each
        cell in row-major order, i.e. cell (gy, gx) is at index
        ``gy * grid_n + gx``.
    dims : tuple
        (W, H, cell_w, cell_h)

        - W, H  : final image dimensions
        - cell_w, cell_h : size of each grid cell in pixels

    Raises
    ------
    ValueError
        If the image has unexpected shape, is empty, or is not divisible by
        grid_n, or if grid_n is not a positive integer.

    Notes
    -----
    The function converts the full image to LAB **once**, then computes all
    cell means in a single vectorized reduction rather than a Python loop
    over cells.
    """
    if img is None or not isinstance(img, np.ndarray):
        raise ValueError("Input image must be a valid NumPy RGB array.")

    if img.ndim != 3 or img.shape[2] != 3:
        raise ValueError(f"Expected RGB image with 3 channels, got {img.shape}.")

    # bool is a subclass of int, so True would otherwise be accepted as a
    # grid size of 1; NumPy integer scalars are legitimate integers, though.
    if (
        isinstance(grid_n, bool)
        or not isinstance(grid_n, (int, np.integer))
        or grid_n <= 0
    ):
        raise ValueError("grid_n must be a positive integer.")
    grid_n = int(grid_n)

    h, w = img.shape[:2]

    # A zero-sized dimension is trivially "divisible" by grid_n, but every
    # cell would be empty and its mean NaN, so reject it explicitly.
    if h == 0 or w == 0:
        raise ValueError(
            f"Image size ({w}x{h}) is empty; cannot compute cell means."
        )

    if h % grid_n != 0 or w % grid_n != 0:
        raise ValueError(
            f"Image size ({w}x{h}) is not divisible by grid size {grid_n}. "
            "Call crop_to_multiple() first."
        )

    cell_h, cell_w = h // grid_n, w // grid_n

    # Single conversion for the full image.
    # NOTE(review): assumes fast_rgb2lab returns an (H, W, 3) array, matching
    # the input layout (the previous per-cell slicing relied on this too).
    lab = np.asarray(fast_rgb2lab(img))

    # Reshape to (gy, y, gx, x, channel): axes 1 and 3 index pixels inside a
    # cell, so averaging over them yields one mean per (gy, gx) cell.
    cells = lab.reshape(grid_n, cell_h, grid_n, cell_w, 3)

    # Accumulate in float64 to avoid precision loss on large cells, then
    # flatten row-major to (grid_n * grid_n, 3) and store as float32.
    means = (
        cells.mean(axis=(1, 3), dtype=np.float64)
        .reshape(-1, 3)
        .astype(np.float32)
    )

    return means, (w, h, cell_w, cell_h)