|
|
import warnings
from dataclasses import dataclass
from enum import Enum
from typing import Tuple, List, Dict, Optional

import cv2
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.cluster import KMeans
from sklearn.preprocessing import LabelEncoder
|
|
|
|
|
class ColorClassificationMethod(Enum):
    """Selects which per-cell feature is clustered into color categories."""

    # Cluster on each cell's dominant color.
    DOMINANT_COLOR = "dominant_color"
    # Cluster on each cell's mean color.
    AVERAGE_COLOR = "average_color"
    # No dedicated branch exists in classify_colors; it falls back to the
    # average-color features.
    HISTOGRAM_BINS = "histogram_bins"
    # Cluster on the (hue, saturation, brightness) triple of each cell.
    HSV_QUANTIZATION = "hsv_quantization"
|
|
|
|
|
@dataclass
class GridCell:
    """Record describing one grid cell: its position, color statistics and raw pixels."""

    # Position of the cell inside the grid.
    row: int
    col: int

    # Color summaries of the cell.
    average_color: np.ndarray
    dominant_color: np.ndarray

    # Scalar color descriptors.
    brightness: float
    saturation: float
    hue: float

    # Index of the color category assigned to the cell.
    color_category: int

    # Pixels belonging to the cell.
    # NOTE(review): presumably (tile_h, tile_w, channels) - confirm with callers.
    pixel_data: np.ndarray
|
|
|
|
|
class ImageGridAnalyzer:
    """
    Analyze an image by splitting it into a regular grid and classifying each cell's color.

    Per-cell averages, the HSV conversion and the threshold masks are computed
    with array-wide NumPy/OpenCV calls. Dominant colors are computed cell by
    cell (a small KMeans fit or a frequency count per cell), so that step
    scales with the number of grid cells.
    """

    def __init__(self, grid_size: Tuple[int, int] = (32, 32),
                 classification_method: ColorClassificationMethod = ColorClassificationMethod.DOMINANT_COLOR,
                 n_color_categories: int = 16):
        """
        Initialize the grid analyzer.

        Args:
            grid_size: (rows, cols) for the grid division
            classification_method: Method to classify cell colors
            n_color_categories: Number of color categories for classification
        """
        self.grid_size = grid_size
        self.classification_method = classification_method
        self.n_color_categories = n_color_categories
        # Both are populated by classify_colors().
        self.color_classifier = None
        self.category_colors = None

    def divide_image_into_grid(self, image: np.ndarray) -> Tuple[np.ndarray, Tuple[int, int]]:
        """
        Divide an image into equally sized tiles.

        Rows and columns that do not fit into a whole number of tiles are
        cropped from the bottom and right edges.

        Args:
            image: Input image (H, W, C)

        Returns:
            Grid of cells (grid_rows, grid_cols, tile_height, tile_width, channels)
            Tuple of (tile_height, tile_width)

        Raises:
            ValueError: If the image is not 3-dimensional, if the grid size is
                not positive, or if the image has fewer rows/columns than the
                grid (which would yield empty tiles).
        """
        if image.ndim != 3:
            raise ValueError(
                f"Expected an image of shape (H, W, C), got shape {image.shape}"
            )

        h, w, c = image.shape
        grid_rows, grid_cols = self.grid_size
        if grid_rows <= 0 or grid_cols <= 0:
            raise ValueError(f"grid_size must be positive, got {self.grid_size}")

        tile_h = h // grid_rows
        tile_w = w // grid_cols
        if tile_h == 0 or tile_w == 0:
            # Empty tiles would only surface later as NaN means and an
            # obscure clustering failure, so fail here with a clear message.
            raise ValueError(
                f"Image of size {h}x{w} is too small for a "
                f"{grid_rows}x{grid_cols} grid"
            )

        # Crop so both dimensions are exact multiples of the tile size.
        image = image[:tile_h * grid_rows, :tile_w * grid_cols]

        # Split each spatial axis into (grid index, offset within tile), then
        # bring the two grid axes to the front.
        grid = image.reshape(grid_rows, tile_h, grid_cols, tile_w, c)
        grid = grid.transpose(0, 2, 1, 3, 4)

        return grid, (tile_h, tile_w)

    def analyze_grid_colors_vectorized(self, grid: np.ndarray) -> Dict[str, np.ndarray]:
        """
        Compute color properties for every grid cell.

        Args:
            grid: Grid of cells (grid_rows, grid_cols, tile_h, tile_w, c)

        Returns:
            Dictionary with 'average_colors' and 'dominant_colors'
            (grid_rows, grid_cols, c), 'brightness', 'saturation' and 'hue'
            (grid_rows, grid_cols), and 'cells_data' (the input grid).
        """
        grid_rows, grid_cols, tile_h, tile_w, c = grid.shape

        # One row of pixels per cell: (total_cells, pixels_per_cell, channels).
        cells_flat = grid.reshape(grid_rows * grid_cols, tile_h * tile_w, c)

        average_colors = np.mean(cells_flat, axis=1)
        dominant_colors = self._calculate_dominant_colors_vectorized(cells_flat)

        # HSV descriptors are derived from the average color of each cell.
        hsv_averages = self._rgb_to_hsv_vectorized(average_colors)
        hue = hsv_averages[:, 0]
        saturation = hsv_averages[:, 1]
        brightness = hsv_averages[:, 2]

        return {
            'average_colors': average_colors.reshape(grid_rows, grid_cols, c),
            'dominant_colors': dominant_colors.reshape(grid_rows, grid_cols, c),
            'brightness': brightness.reshape(grid_rows, grid_cols),
            'saturation': saturation.reshape(grid_rows, grid_cols),
            'hue': hue.reshape(grid_rows, grid_cols),
            'cells_data': grid
        }

    def _calculate_dominant_colors_vectorized(self, cells_flat: np.ndarray) -> np.ndarray:
        """
        Calculate the dominant color of each cell.

        Cells with more than 100 pixels and at least three distinct colors are
        clustered with a 3-cluster KMeans and the center of the largest
        cluster is used. Other cells with at least two distinct colors use
        their most frequent exact color, and the remaining cells use the mean.

        Args:
            cells_flat: Flattened cells (total_cells, pixels_per_cell, channels)

        Returns:
            Dominant colors for all cells (total_cells, channels)
        """
        total_cells, pixels_per_cell, c = cells_flat.shape
        dominant_colors = np.zeros((total_cells, c))

        for idx, cell_pixels in enumerate(cells_flat):
            # A single pass yields both the distinct colors and their counts.
            unique_colors, color_counts = np.unique(
                cell_pixels, axis=0, return_counts=True
            )
            n_unique = len(unique_colors)

            if n_unique >= 3 and pixels_per_cell > 100:
                with warnings.catch_warnings():
                    # scikit-learn's ConvergenceWarning derives from
                    # UserWarning, so this one filter silences it as well.
                    warnings.filterwarnings("ignore", category=UserWarning)

                    # n_unique >= 3 here, so three clusters are always available.
                    kmeans = KMeans(n_clusters=3, random_state=42, n_init=5)
                    labels = kmeans.fit_predict(cell_pixels)

                unique_labels, label_counts = np.unique(labels, return_counts=True)
                dominant_idx = unique_labels[np.argmax(label_counts)]
                dominant_colors[idx] = kmeans.cluster_centers_[dominant_idx]
            elif n_unique >= 2:
                dominant_colors[idx] = unique_colors[np.argmax(color_counts)]
            else:
                dominant_colors[idx] = np.mean(cell_pixels, axis=0)

        return dominant_colors

    def _rgb_to_hsv_vectorized(self, rgb_colors: np.ndarray) -> np.ndarray:
        """
        Convert RGB colors to HSV with a single OpenCV call.

        Args:
            rgb_colors: RGB colors (N, 3), divided by 255 before conversion

        Returns:
            HSV colors (N, 3) as float32. For float32 input OpenCV reports hue
            in degrees (0-360) and saturation/value in the 0-1 range.
        """
        rgb_normalized = rgb_colors / 255.0

        # cv2.cvtColor works on images, so present the colors as an N x 1 image.
        dummy_img = rgb_normalized.reshape(-1, 1, 3).astype(np.float32)
        hsv_img = cv2.cvtColor(dummy_img, cv2.COLOR_RGB2HSV)

        return hsv_img.reshape(-1, 3)

    def classify_colors(self, color_data: Dict[str, np.ndarray]) -> np.ndarray:
        """
        Classify each grid cell into a color category with KMeans.

        Stores the fitted model in ``self.color_classifier`` and the cluster
        centers in ``self.category_colors``. The centers live in the feature
        space of the selected method (HSV triples for HSV_QUANTIZATION, RGB
        otherwise).

        Args:
            color_data: Dictionary containing color analysis results

        Returns:
            Color categories for each grid cell (grid_rows, grid_cols)
        """
        method = self.classification_method
        if method == ColorClassificationMethod.DOMINANT_COLOR:
            features = color_data['dominant_colors']
        elif method == ColorClassificationMethod.HSV_QUANTIZATION:
            # NOTE(review): hue spans 0-360 while saturation/brightness span
            # 0-1, so hue dominates the Euclidean distance - confirm intended.
            features = np.stack(
                [color_data['hue'], color_data['saturation'], color_data['brightness']],
                axis=-1,
            )
        else:
            # AVERAGE_COLOR, plus HISTOGRAM_BINS which has no dedicated
            # implementation and falls back to average colors.
            features = color_data['average_colors']

        grid_rows, grid_cols = features.shape[:2]
        features_flat = features.reshape(-1, features.shape[-1])

        # KMeans cannot produce more clusters than there are distinct points.
        unique_features = np.unique(features_flat, axis=0)
        effective_clusters = min(self.n_color_categories, len(unique_features))

        if effective_clusters < 2:
            print(f"Warning: Only {len(unique_features)} unique colors found. Using simple classification.")
            categories = np.zeros(len(features_flat), dtype=int)
            self.category_colors = unique_features[:1] if len(unique_features) > 0 else np.array([[128, 128, 128]])
            return categories.reshape(grid_rows, grid_cols)

        with warnings.catch_warnings():
            # scikit-learn's ConvergenceWarning derives from UserWarning, so
            # this one filter silences it as well.
            warnings.filterwarnings("ignore", category=UserWarning)

            self.color_classifier = KMeans(n_clusters=effective_clusters,
                                           random_state=42, n_init=10)
            categories = self.color_classifier.fit_predict(features_flat)

        self.category_colors = self.color_classifier.cluster_centers_

        return categories.reshape(grid_rows, grid_cols)

    def apply_thresholding(self, color_data: Dict[str, np.ndarray],
                           brightness_threshold: float = 0.5,
                           saturation_threshold: float = 0.3) -> Dict[str, np.ndarray]:
        """
        Build boolean masks from brightness and saturation thresholds.

        Args:
            color_data: Color analysis results
            brightness_threshold: Threshold for bright/dark classification
            saturation_threshold: Threshold for saturated/desaturated classification

        Returns:
            Dictionary with 'bright_mask', 'dark_mask', 'saturated_mask',
            'desaturated_mask', 'bright_saturated' and 'dark_saturated'
        """
        brightness = color_data['brightness']
        saturation = color_data['saturation']

        # Values above 1.0 are taken to be on a 0-255 scale and are rescaled.
        brightness_norm = brightness / 255.0 if brightness.max() > 1.0 else brightness
        saturation_norm = saturation / 255.0 if saturation.max() > 1.0 else saturation

        bright = brightness_norm > brightness_threshold
        dark = brightness_norm <= brightness_threshold
        saturated = saturation_norm > saturation_threshold
        desaturated = saturation_norm <= saturation_threshold

        return {
            'bright_mask': bright,
            'dark_mask': dark,
            'saturated_mask': saturated,
            'desaturated_mask': desaturated,
            'bright_saturated': bright & saturated,
            'dark_saturated': dark & saturated
        }

    def analyze_image_complete(self, image: np.ndarray) -> Dict:
        """
        Run the complete analysis pipeline on an image.

        Args:
            image: Input image (H, W, C)

        Returns:
            Dictionary with 'grid', 'tile_size', 'color_data',
            'color_categories', 'thresholds' and 'category_colors'
        """
        print(f"Analyzing image with {self.grid_size[0]}x{self.grid_size[1]} grid...")

        grid, tile_size = self.divide_image_into_grid(image)
        print(f"Created grid with tile size: {tile_size}")

        color_data = self.analyze_grid_colors_vectorized(grid)
        print("Completed color analysis")

        color_categories = self.classify_colors(color_data)
        # Report the number of categories actually produced, which can be
        # lower than the configured number when the image has few colors.
        print(f"Classified into {len(self.category_colors)} color categories")

        thresholds = self.apply_thresholding(color_data)
        print("Applied thresholding")

        return {
            'grid': grid,
            'tile_size': tile_size,
            'color_data': color_data,
            'color_categories': color_categories,
            'thresholds': thresholds,
            'category_colors': self.category_colors
        }

    def visualize_analysis(self, results: Dict, original_image: np.ndarray):
        """
        Show a 2x4 figure summarizing the analysis, followed by the category palette.

        Args:
            results: Analysis results from analyze_image_complete
            original_image: Original input image
        """
        fig, axes = plt.subplots(2, 4, figsize=(20, 10))

        color_data = results['color_data']
        thresholds = results['thresholds']
        categories = results['color_categories']
        # Count the categories present rather than the configured maximum.
        n_classes = len(np.unique(categories))

        # (axis, data, title, colormap, draw a colorbar?)
        panels = [
            (axes[0, 0], original_image, 'Original Image', None, False),
            (axes[0, 1], color_data['average_colors'].astype(np.uint8),
             'Average Colors per Cell', None, False),
            (axes[0, 2], color_data['dominant_colors'].astype(np.uint8),
             'Dominant Colors per Cell', None, False),
            (axes[0, 3], categories,
             f'Color Categories ({n_classes} classes)', 'tab20', True),
            (axes[1, 0], color_data['brightness'], 'Brightness Values', 'gray', True),
            (axes[1, 1], color_data['saturation'], 'Saturation Values', 'viridis', True),
            (axes[1, 2], thresholds['bright_mask'], 'Bright Areas (Threshold)', 'gray', False),
            (axes[1, 3], thresholds['saturated_mask'], 'Saturated Areas (Threshold)', 'gray', False),
        ]
        for ax, data, title, cmap, with_colorbar in panels:
            im = ax.imshow(data, cmap=cmap)
            ax.set_title(title)
            ax.axis('off')
            if with_colorbar:
                fig.colorbar(im, ax=ax)

        plt.tight_layout()
        plt.show()

        self._visualize_color_palette(results['category_colors'])

    def _visualize_color_palette(self, category_colors: np.ndarray):
        """
        Show the color category palette as a strip of swatches.

        Args:
            category_colors: Color palette (n_categories, channels). Treated
                as HSV triples when the analyzer uses HSV_QUANTIZATION and as
                RGB otherwise. Nothing is drawn when it is None.
        """
        if category_colors is None:
            return

        colors = np.asarray(category_colors, dtype=np.float64)
        if self.classification_method == ColorClassificationMethod.HSV_QUANTIZATION:
            # The cluster centers are HSV features; convert them back to RGB
            # so the swatches show the actual colors.
            hsv = colors.reshape(1, -1, 3).astype(np.float32)
            colors = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB).reshape(-1, 3)
        elif colors.max() > 1.0:
            # Values above 1.0 are taken to be on a 0-255 scale.
            colors = colors / 255.0
        colors = np.clip(colors, 0.0, 1.0)
        n_colors = len(colors)

        fig, ax = plt.subplots(1, 1, figsize=(12, 2))

        # An explicit extent makes swatch i span [i, i + 1] horizontally and
        # the full [0, 1] vertically, so no swatch is clipped by the limits.
        ax.imshow(colors.reshape(1, -1, 3), aspect='auto',
                  extent=(0, n_colors, 0, 1))
        ax.set_xticks(np.arange(n_colors) + 0.5)
        ax.set_xticklabels([f'Cat {i}' for i in range(n_colors)])
        ax.set_title(f'Color Category Palette ({n_colors} categories)')
        ax.set_ylabel('Color Categories')

        plt.tight_layout()
        plt.show()

    def get_performance_stats(self, results: Dict) -> Dict:
        """
        Summarize an analysis result as a dictionary of statistics.

        Args:
            results: Analysis results

        Returns:
            Dictionary containing the grid size, cell count, category usage,
            brightness/saturation mean and standard deviation, and the
            percentage of bright and saturated cells
        """
        categories = results['color_categories']
        grid_shape = categories.shape
        unique_categories = len(np.unique(categories))

        brightness = results['color_data']['brightness']
        saturation = results['color_data']['saturation']
        thresholds = results['thresholds']

        return {
            'grid_size': f"{grid_shape[0]}x{grid_shape[1]}",
            'total_cells': np.prod(grid_shape),
            'unique_color_categories': unique_categories,
            # Share of the configured categories that ended up being used.
            'category_utilization': unique_categories / self.n_color_categories,
            'avg_brightness': np.mean(brightness),
            'brightness_std': np.std(brightness),
            'avg_saturation': np.mean(saturation),
            'saturation_std': np.std(saturation),
            # The mean of a boolean mask is the fraction of True cells.
            'bright_cells_percent': np.mean(thresholds['bright_mask']) * 100,
            'saturated_cells_percent': np.mean(thresholds['saturated_mask']) * 100
        }
|
|
|
|
|
|
|
|
def main():
    """
    Example usage of the ImageGridAnalyzer.

    Loads 'processed_quantized.jpg' from the working directory, falling back
    to a synthetic gradient image when the file cannot be read, runs the full
    analysis and prints the resulting statistics.

    Returns:
        Tuple of (analysis results, analyzer instance)
    """

    def create_test_image():
        """Build a 256x256 RGB gradient: red follows the row, green the column."""
        ramp = np.arange(256)
        rows = ramp[:, None]
        cols = ramp[None, :]

        img = np.zeros((256, 256, 3), dtype=np.uint8)
        img[..., 0] = rows
        img[..., 1] = cols
        img[..., 2] = (rows + cols) % 255
        return img

    analyzer = ImageGridAnalyzer(
        grid_size=(32, 32),
        classification_method=ColorClassificationMethod.DOMINANT_COLOR,
        n_color_categories=16
    )

    image_path = 'processed_quantized.jpg'
    test_image = cv2.imread(image_path)
    if test_image is None:
        # cv2.imread reports a missing or unreadable file by returning None
        # instead of raising, which would otherwise fail later in cvtColor.
        print(f"Could not read '{image_path}'; using a synthetic test image instead.")
        test_image = create_test_image()
    else:
        # OpenCV loads images as BGR; the analyzer works on RGB.
        test_image = cv2.cvtColor(test_image, cv2.COLOR_BGR2RGB)

    print("Starting analysis...")

    results = analyzer.analyze_image_complete(test_image)
    stats = analyzer.get_performance_stats(results)

    print("\n=== Analysis Statistics ===")
    for key, value in stats.items():
        print(f"{key}: {value}")

    return results, analyzer
|
|
|
|
|
if __name__ == "__main__":
    # Run the demo when executed as a script; the results and the analyzer
    # stay bound at module level so they can be inspected afterwards.
    results, analyzer = main()