|
|
import numpy as np
|
|
|
import cv2
|
|
|
from sklearn.cluster import KMeans
|
|
|
from sklearn.preprocessing import LabelEncoder
|
|
|
import matplotlib.pyplot as plt
|
|
|
from typing import Tuple, List, Dict, Optional
|
|
|
from dataclasses import dataclass
|
|
|
from enum import Enum
|
|
|
import seaborn as sns
|
|
|
|
|
|
class ColorClassificationMethod(Enum):
    """Strategy that selects which per-cell color statistics are clustered into categories."""

    # Use each cell's dominant (most representative) color as its feature.
    DOMINANT_COLOR = "dominant_color"
    # Use each cell's mean color as its feature.
    AVERAGE_COLOR = "average_color"
    # Histogram-bin based classification.
    # NOTE(review): classify_colors has no dedicated branch for this member;
    # it falls back to the average-color features.
    HISTOGRAM_BINS = "histogram_bins"
    # Use the cells' hue / saturation / brightness components as features.
    HSV_QUANTIZATION = "hsv_quantization"
|
|
|
|
|
|
@dataclass
class GridCell:
    """
    Record describing one cell of an image grid and its color measurements.

    Constructor argument order: row, col, average_color, dominant_color,
    brightness, saturation, hue, color_category, pixel_data.
    """

    # Location of the cell within the grid.
    row: int
    col: int

    # Color summaries of the cell (presumably one value per channel — confirm with producer).
    average_color: np.ndarray
    dominant_color: np.ndarray

    # HSV-style scalar measurements; their numeric scale is set by whoever fills them in.
    brightness: float
    saturation: float
    hue: float

    # Integer category label (presumably the cluster index from classify_colors).
    color_category: int

    # The cell's raw pixel block.
    pixel_data: np.ndarray
|
|
|
|
|
|
class ImageGridAnalyzer:
    """
    Analyze an image by splitting it into a regular grid and characterizing each cell's color.

    Per-cell average color, brightness, saturation and hue are computed with
    vectorized NumPy/OpenCV operations. Per-cell dominant colors are computed
    cell by cell (K-means for large, varied cells, otherwise the most frequent
    pixel). Cells are then clustered into color categories with K-means.

    Color inputs are assumed to be on a 0-255 scale (e.g. uint8 RGB images).
    """

    # A cell's dominant color comes from K-means only when the cell has more
    # pixels than this and at least _DOMINANT_CLUSTERS distinct colors;
    # otherwise the most frequent exact pixel value is used.
    _KMEANS_MIN_PIXELS = 100
    _DOMINANT_CLUSTERS = 3

    def __init__(self, grid_size: Tuple[int, int] = (32, 32),
                 classification_method: ColorClassificationMethod = ColorClassificationMethod.DOMINANT_COLOR,
                 n_color_categories: int = 16):
        """
        Initialize the grid analyzer.

        Args:
            grid_size: (rows, cols) for the grid division; both must be >= 1.
            classification_method: Which per-cell features classify_colors clusters.
            n_color_categories: Maximum number of color categories (K-means clusters).

        Raises:
            ValueError: If grid_size does not hold two positive values.
        """
        grid_rows, grid_cols = grid_size
        if grid_rows < 1 or grid_cols < 1:
            raise ValueError(f"grid_size must contain positive integers, got {grid_size}")

        self.grid_size = grid_size
        self.classification_method = classification_method
        self.n_color_categories = n_color_categories

        # Filled in by classify_colors().
        self.color_classifier = None
        self.category_colors = None

    def divide_image_into_grid(self, image: np.ndarray) -> Tuple[np.ndarray, Tuple[int, int]]:
        """
        Split an image into equally sized tiles without copying pixel loops.

        Rows/columns that do not fill a whole tile (the remainder of H / rows
        and W / cols) are cropped from the bottom/right edge.

        Args:
            image: Input image of shape (H, W, C).

        Returns:
            grid: Array of shape (grid_rows, grid_cols, tile_h, tile_w, C).
            (tile_h, tile_w): Size of each tile in pixels.

        Raises:
            ValueError: If image is not 3-D or is too small for the grid.
        """
        if image.ndim != 3:
            raise ValueError(f"Expected an (H, W, C) image, got shape {image.shape}")

        h, w, c = image.shape
        grid_rows, grid_cols = self.grid_size

        tile_h = h // grid_rows
        tile_w = w // grid_cols
        if tile_h == 0 or tile_w == 0:
            # Zero-size tiles would produce NaN statistics and opaque errors later.
            raise ValueError(
                f"Image of size {h}x{w} is too small for a {grid_rows}x{grid_cols} grid"
            )

        cropped = image[:tile_h * grid_rows, :tile_w * grid_cols]

        # (H, W, C) -> (rows, tile_h, cols, tile_w, C) -> (rows, cols, tile_h, tile_w, C)
        grid = cropped.reshape(grid_rows, tile_h, grid_cols, tile_w, c).transpose(0, 2, 1, 3, 4)

        return grid, (tile_h, tile_w)

    def analyze_grid_colors_vectorized(self, grid: np.ndarray) -> Dict[str, np.ndarray]:
        """
        Compute color statistics for every grid cell.

        Args:
            grid: Grid of cells, shape (grid_rows, grid_cols, tile_h, tile_w, C).

        Returns:
            Dictionary with:
                'average_colors'  (grid_rows, grid_cols, C) mean color per cell
                'dominant_colors' (grid_rows, grid_cols, C) dominant color per cell
                'brightness'      (grid_rows, grid_cols) HSV value of the mean color, in [0, 1]
                'saturation'      (grid_rows, grid_cols) HSV saturation of the mean color, in [0, 1]
                'hue'             (grid_rows, grid_cols) HSV hue of the mean color, in degrees [0, 360]
                'cells_data'      the input grid
        """
        grid_rows, grid_cols, tile_h, tile_w, c = grid.shape

        # One row per cell, one entry per pixel: (cells, pixels_per_cell, C).
        cells_flat = grid.reshape(grid_rows * grid_cols, tile_h * tile_w, c)

        average_colors = np.mean(cells_flat, axis=1)
        dominant_colors = self._calculate_dominant_colors_vectorized(cells_flat)

        # HSV of each cell's *average* color.
        hsv_averages = self._rgb_to_hsv_vectorized(average_colors)
        hue = hsv_averages[:, 0]
        saturation = hsv_averages[:, 1]
        brightness = hsv_averages[:, 2]

        return {
            'average_colors': average_colors.reshape(grid_rows, grid_cols, c),
            'dominant_colors': dominant_colors.reshape(grid_rows, grid_cols, c),
            'brightness': brightness.reshape(grid_rows, grid_cols),
            'saturation': saturation.reshape(grid_rows, grid_cols),
            'hue': hue.reshape(grid_rows, grid_cols),
            'cells_data': grid,
        }

    @staticmethod
    def _fit_kmeans(data: np.ndarray, n_clusters: int, n_init: int) -> Tuple["KMeans", np.ndarray]:
        """
        Fit a seeded K-means model and return (model, labels) with sklearn warnings silenced.

        Duplicate points make sklearn warn that fewer distinct clusters were
        found than requested. The result is still usable here, so those warnings
        are suppressed.
        """
        import warnings
        from sklearn.exceptions import ConvergenceWarning

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning)
            # ConvergenceWarning is a category, not message text; filter it explicitly.
            warnings.filterwarnings("ignore", category=ConvergenceWarning)

            model = KMeans(n_clusters=n_clusters, random_state=42, n_init=n_init)
            labels = model.fit_predict(data)

        return model, labels

    def _calculate_dominant_colors_vectorized(self, cells_flat: np.ndarray) -> np.ndarray:
        """
        Calculate the dominant color of each cell.

        For cells with more than _KMEANS_MIN_PIXELS pixels and at least
        _DOMINANT_CLUSTERS distinct colors, the result is the center of the most
        populated K-means cluster. Otherwise it is the most frequent exact pixel
        value. A single-color cell therefore yields that color.

        Args:
            cells_flat: Flattened cells, shape (total_cells, pixels_per_cell, C).
                pixels_per_cell must be >= 1.

        Returns:
            Dominant colors, shape (total_cells, C), dtype float64.
        """
        total_cells, pixels_per_cell, c = cells_flat.shape
        dominant_colors = np.zeros((total_cells, c))

        for idx, cell_pixels in enumerate(cells_flat):
            unique_colors, counts = np.unique(cell_pixels, axis=0, return_counts=True)

            use_kmeans = (len(unique_colors) >= self._DOMINANT_CLUSTERS
                          and pixels_per_cell > self._KMEANS_MIN_PIXELS)
            if use_kmeans:
                model, labels = self._fit_kmeans(cell_pixels, self._DOMINANT_CLUSTERS, n_init=5)
                # Most populated cluster; ties go to the lowest label.
                dominant_colors[idx] = model.cluster_centers_[np.bincount(labels).argmax()]
            else:
                dominant_colors[idx] = unique_colors[np.argmax(counts)]

        return dominant_colors

    def _rgb_to_hsv_vectorized(self, rgb_colors: np.ndarray) -> np.ndarray:
        """
        Convert RGB colors to HSV in a single OpenCV call.

        Args:
            rgb_colors: RGB colors on a 0-255 scale, shape (N, 3).

        Returns:
            HSV colors, shape (N, 3), float32. From OpenCV's float conversion,
            H is in degrees [0, 360] and S and V are in [0, 1].
        """
        rgb_normalized = rgb_colors / 255.0

        # cvtColor works on images, so treat the colors as an N x 1 image.
        as_image = rgb_normalized.reshape(-1, 1, 3).astype(np.float32)
        hsv_image = cv2.cvtColor(as_image, cv2.COLOR_RGB2HSV)

        return hsv_image.reshape(-1, 3)

    @staticmethod
    def _hsv_features(color_data: Dict[str, np.ndarray]) -> np.ndarray:
        """
        Build HSV clustering features whose components share a comparable scale.

        Hue (degrees) is mapped onto a circle of radius 0.5. Opposite hues are
        then distance 1 apart, matching the [0, 1] span of saturation and value,
        and 0 and 360 degrees coincide.

        Returns:
            Array of shape (grid_rows, grid_cols, 4): cos(h)/2, sin(h)/2, s, v.
        """
        theta = np.deg2rad(color_data['hue'])
        return np.stack([np.cos(theta) / 2.0,
                         np.sin(theta) / 2.0,
                         color_data['saturation'],
                         color_data['brightness']], axis=-1)

    @staticmethod
    def _mean_color_per_category(labels: np.ndarray, colors: np.ndarray,
                                 n_categories: int) -> np.ndarray:
        """Average the RGB colors assigned to each category (empty categories -> mid-gray 128)."""
        sums = np.zeros((n_categories, colors.shape[1]))
        np.add.at(sums, labels, colors)
        counts = np.bincount(labels, minlength=n_categories)[:, None]
        return np.divide(sums, counts, out=np.full_like(sums, 128.0), where=counts > 0)

    def classify_colors(self, color_data: Dict[str, np.ndarray]) -> np.ndarray:
        """
        Cluster grid cells into color categories with K-means.

        The features depend on self.classification_method:
            DOMINANT_COLOR   -> per-cell dominant color
            AVERAGE_COLOR    -> per-cell average color
            HSV_QUANTIZATION -> scale-balanced hue/saturation/value features
            HISTOGRAM_BINS   -> not implemented; falls back to average color

        Side effects:
            self.color_classifier is set to the fitted KMeans model. It is left
            unchanged when fewer than 2 clusters are possible.
            self.category_colors is set to one displayable color per category,
            shape (n_categories, C). For the RGB methods these are the cluster
            centers. For HSV they are the mean average-color of each category's
            cells.

        Args:
            color_data: Output of analyze_grid_colors_vectorized.

        Returns:
            Integer category per cell, shape (grid_rows, grid_cols).
        """
        method = self.classification_method
        use_hsv = method == ColorClassificationMethod.HSV_QUANTIZATION

        if method == ColorClassificationMethod.DOMINANT_COLOR:
            features = color_data['dominant_colors']
        elif use_hsv:
            features = self._hsv_features(color_data)
        else:
            # AVERAGE_COLOR, and the not-yet-implemented HISTOGRAM_BINS.
            features = color_data['average_colors']

        grid_rows, grid_cols = features.shape[:2]
        features_flat = features.reshape(-1, features.shape[-1])

        # HSV features are not displayable colors, so the palette is built
        # from the cells' RGB average colors in that mode.
        rgb_flat = color_data['average_colors'].reshape(grid_rows * grid_cols, -1)
        palette_source = rgb_flat if use_hsv else features_flat

        unique_features = np.unique(features_flat, axis=0)
        effective_clusters = min(self.n_color_categories, len(unique_features))

        if effective_clusters < 2:
            print(f"Warning: Only {len(unique_features)} unique colors found. Using simple classification.")
            categories = np.zeros(len(features_flat), dtype=int)
            if len(palette_source) > 0:
                # The single category's color is the mean of all cells.
                # This equals the color itself when only one unique color exists.
                self.category_colors = palette_source.mean(axis=0, keepdims=True)
            else:
                self.category_colors = np.array([[128, 128, 128]])
            return categories.reshape(grid_rows, grid_cols)

        self.color_classifier, categories = self._fit_kmeans(
            features_flat, effective_clusters, n_init=10)

        if use_hsv:
            self.category_colors = self._mean_color_per_category(
                categories, rgb_flat, effective_clusters)
        else:
            self.category_colors = self.color_classifier.cluster_centers_

        return categories.reshape(grid_rows, grid_cols)

    def apply_thresholding(self, color_data: Dict[str, np.ndarray],
                           brightness_threshold: float = 0.5,
                           saturation_threshold: float = 0.3) -> Dict[str, np.ndarray]:
        """
        Build boolean per-cell masks from brightness and saturation thresholds.

        If an array's maximum exceeds 1.0 it is assumed to be on a 0-255 scale
        and is divided by 255 before comparison.

        Args:
            color_data: Output of analyze_grid_colors_vectorized.
            brightness_threshold: Cells with brightness above this are "bright".
            saturation_threshold: Cells with saturation above this are "saturated".

        Returns:
            Dictionary of boolean masks (grid_rows, grid_cols): 'bright_mask',
            'dark_mask', 'saturated_mask', 'desaturated_mask',
            'bright_saturated', 'dark_saturated'.
        """
        def to_unit_range(values: np.ndarray) -> np.ndarray:
            return values / 255.0 if values.max() > 1.0 else values

        brightness_norm = to_unit_range(color_data['brightness'])
        saturation_norm = to_unit_range(color_data['saturation'])

        bright = brightness_norm > brightness_threshold
        dark = brightness_norm <= brightness_threshold
        saturated = saturation_norm > saturation_threshold

        return {
            'bright_mask': bright,
            'dark_mask': dark,
            'saturated_mask': saturated,
            'desaturated_mask': saturation_norm <= saturation_threshold,
            'bright_saturated': bright & saturated,
            'dark_saturated': dark & saturated,
        }

    def analyze_image_complete(self, image: np.ndarray) -> Dict:
        """
        Run the full pipeline: grid split, color statistics, classification, thresholding.

        Args:
            image: Input image of shape (H, W, C), 0-255 scale.

        Returns:
            Dictionary with 'grid', 'tile_size', 'color_data',
            'color_categories', 'thresholds' and 'category_colors'.
        """
        print(f"Analyzing image with {self.grid_size[0]}x{self.grid_size[1]} grid...")

        grid, tile_size = self.divide_image_into_grid(image)
        print(f"Created grid with tile size: {tile_size}")

        color_data = self.analyze_grid_colors_vectorized(grid)
        print("Completed color analysis")

        color_categories = self.classify_colors(color_data)
        # Report the number actually used. It can be lower than
        # n_color_categories when the image has few distinct colors.
        print(f"Classified into {len(self.category_colors)} color categories")

        thresholds = self.apply_thresholding(color_data)
        print("Applied thresholding")

        return {
            'grid': grid,
            'tile_size': tile_size,
            'color_data': color_data,
            'color_categories': color_categories,
            'thresholds': thresholds,
            'category_colors': self.category_colors,
        }

    @staticmethod
    def _to_uint8(colors: np.ndarray) -> np.ndarray:
        """Round and clip 0-255 float colors to uint8 for display (a bare astype truncates)."""
        return np.clip(np.rint(colors), 0, 255).astype(np.uint8)

    def visualize_analysis(self, results: Dict, original_image: np.ndarray):
        """
        Show a 2x4 summary figure, then the category palette.

        Top row: original image, average colors, dominant colors, category map.
        Bottom row: brightness, saturation, bright mask, saturated mask.

        Args:
            results: Output of analyze_image_complete.
            original_image: The analyzed image, displayed as-is.
        """
        color_data = results['color_data']
        thresholds = results['thresholds']
        palette = results['category_colors']
        n_categories = len(palette) if palette is not None else self.n_color_categories

        _, axes = plt.subplots(2, 4, figsize=(20, 10))

        def show(ax, data, title, cmap=None, colorbar=False):
            artist = ax.imshow(data, cmap=cmap)
            ax.set_title(title)
            ax.axis('off')
            if colorbar:
                plt.colorbar(artist, ax=ax)

        show(axes[0, 0], original_image, 'Original Image')
        show(axes[0, 1], self._to_uint8(color_data['average_colors']), 'Average Colors per Cell')
        show(axes[0, 2], self._to_uint8(color_data['dominant_colors']), 'Dominant Colors per Cell')
        show(axes[0, 3], results['color_categories'], f'Color Categories ({n_categories} classes)',
             cmap='tab20', colorbar=True)
        show(axes[1, 0], color_data['brightness'], 'Brightness Values', cmap='gray', colorbar=True)
        show(axes[1, 1], color_data['saturation'], 'Saturation Values', cmap='viridis', colorbar=True)
        show(axes[1, 2], thresholds['bright_mask'], 'Bright Areas (Threshold)', cmap='gray')
        show(axes[1, 3], thresholds['saturated_mask'], 'Saturated Areas (Threshold)', cmap='gray')

        plt.tight_layout()
        plt.show()

        self._visualize_color_palette(palette)

    def _visualize_color_palette(self, category_colors: np.ndarray):
        """
        Show one swatch per color category.

        Args:
            category_colors: Palette of shape (n_categories, 3). It is treated as
                0-255 if any value exceeds 1.0, otherwise as [0, 1]. None or an
                empty palette shows nothing.
        """
        if category_colors is None or len(category_colors) == 0:
            return

        colors = np.asarray(category_colors, dtype=float)
        if colors.max() > 1.0:
            colors = colors / 255.0
        colors = np.clip(colors, 0.0, 1.0)
        n = len(colors)

        _, ax = plt.subplots(1, 1, figsize=(12, 2))

        # extent puts swatch i on [i, i+1] x [0, 1]. Each swatch is fully
        # visible and the ticks at i + 0.5 label swatch centers.
        ax.imshow(colors.reshape(1, -1, 3), aspect='auto', extent=(0, n, 0, 1))
        ax.set_xticks(np.arange(n) + 0.5)
        ax.set_xticklabels([f'Cat {i}' for i in range(n)])
        ax.set_yticks([])
        ax.set_title(f'Color Category Palette ({n} categories)')
        ax.set_ylabel('Color Categories')

        plt.tight_layout()
        plt.show()

    def get_performance_stats(self, results: Dict) -> Dict:
        """
        Summarize an analysis result.

        Args:
            results: Output of analyze_image_complete.

        Returns:
            Dictionary with grid size, cell counts, category usage, mean/std of
            brightness and saturation, and the percentage of bright and
            saturated cells.
        """
        categories = results['color_categories']
        grid_shape = categories.shape
        total_cells = np.prod(grid_shape)
        unique_categories = len(np.unique(categories))

        brightness = results['color_data']['brightness']
        saturation = results['color_data']['saturation']
        thresholds = results['thresholds']

        # Guard against n_color_categories == 0, which classify_colors accepts.
        utilization = (unique_categories / self.n_color_categories
                       if self.n_color_categories else 0.0)

        return {
            'grid_size': f"{grid_shape[0]}x{grid_shape[1]}",
            'total_cells': total_cells,
            'unique_color_categories': unique_categories,
            'category_utilization': utilization,
            'avg_brightness': np.mean(brightness),
            'brightness_std': np.std(brightness),
            'avg_saturation': np.mean(saturation),
            'saturation_std': np.std(saturation),
            'bright_cells_percent': np.mean(thresholds['bright_mask']) * 100,
            'saturated_cells_percent': np.mean(thresholds['saturated_mask']) * 100,
        }
|
|
|
|
|
|
|
|
|
def main(image_path: str = 'processed_quantized.jpg'):
    """
    Example usage of the ImageGridAnalyzer: analyze one image and print statistics.

    The image is loaded with OpenCV (BGR) and converted to RGB. cv2.imread
    returns None when the file is missing or unreadable. In that case a
    synthetic 256x256 gradient image is used so the example still runs,
    instead of crashing inside cvtColor.

    Args:
        image_path: Path of the image to analyze.

    Returns:
        Tuple of (analysis results dict, the ImageGridAnalyzer that produced it).
    """

    def create_test_image():
        """Synthetic 256x256 RGB gradient: R = row, G = column, B = (row + col) % 255."""
        rows, cols = np.indices((256, 256))
        return np.stack([rows, cols, (rows + cols) % 255], axis=-1).astype(np.uint8)

    analyzer = ImageGridAnalyzer(
        grid_size=(32, 32),
        classification_method=ColorClassificationMethod.DOMINANT_COLOR,
        n_color_categories=16,
    )

    bgr_image = cv2.imread(image_path)
    if bgr_image is None:
        print(f"Could not read '{image_path}'; using a synthetic test image instead.")
        test_image = create_test_image()
    else:
        test_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

    print("Starting analysis...")

    results = analyzer.analyze_image_complete(test_image)
    stats = analyzer.get_performance_stats(results)

    print("\n=== Analysis Statistics ===")
    for key, value in stats.items():
        print(f"{key}: {value}")

    return results, analyzer


if __name__ == "__main__":
    # Module-level names keep the outputs inspectable, e.g. under `python -i`.
    results, analyzer = main()