"""Interactive color palette extractor and demo.

This script loads an image, extracts a palette with a configurable
number of colors, and showcases the palette inside a Gradio interface
using multiple views (gradient bar, typography cards, a 3D RGB scatter
plot, and the raw palette data as JSON). When an image path is given on
the command line without ``--ui``, the palette is printed to the
terminal instead.
"""
| | from __future__ import annotations |
| |
|
| | import argparse |
| | from dataclasses import dataclass |
| | from typing import List, Sequence, Tuple |
| |
|
| | import numpy as np |
| | from PIL import Image |
| |
|
| | try: |
| | import gradio as gr |
| | except ImportError as exc: |
| | raise ImportError( |
| | "Gradio is required to run the interactive demo." |
| | ) from exc |
| |
|
| | try: |
| | import plotly.graph_objects as go |
| | except ImportError as exc: |
| | raise ImportError( |
| | "Plotly is required for the 3D scatter visualization. Install it via `pip install plotly`." |
| | ) from exc |
| |
|
| |
|
@dataclass
class PaletteColor:
    """One palette entry: an RGB triple plus its share of the clustered pixels."""

    # Red, green and blue channel values.
    rgb: Tuple[int, int, int]
    # Share of the clustered pixels assigned to this colour
    # (``extract_palette`` stores it as a 0-1 fraction).
    percentage: float

    @property
    def hex(self) -> str:
        """Return the colour as an uppercase ``#RRGGBB``-style string."""
        digits = [format(channel, "02X") for channel in self.rgb]
        return "#" + "".join(digits)
| |
|
| |
|
@dataclass
class PaletteResult:
    """Bundle the extracted palette with the clustering data it was built from."""

    # Palette entries, in the order the extractor produced them.
    colors: List[PaletteColor]
    # Pixel samples that were fed to the clustering step.
    samples: np.ndarray
    # Cluster index for each row of ``samples``.
    labels: np.ndarray
    # Cluster centres, one row per palette entry.
    centroids: np.ndarray
| |
|
| |
|
def _prepare_pixels(image: Image.Image, max_sample: int = 5000) -> np.ndarray:
    """Flatten an image into an ``(N, 3)`` float32 array of RGB rows.

    Images that are neither RGB nor RGBA are converted to RGB first, and an
    alpha channel is dropped. When the image has more than ``max_sample``
    pixels, a random subset of exactly ``max_sample`` rows is returned.
    """
    needs_conversion = image.mode != "RGB" and image.mode != "RGBA"
    if needs_conversion:
        image = image.convert("RGB")

    raster = np.array(image)
    has_alpha = raster.ndim == 3 and raster.shape[2] == 4
    if has_alpha:
        raster = raster[:, :, :3]

    rows = raster.reshape(-1, 3).astype(np.float32)
    total = len(rows)
    if total <= max_sample:
        return rows

    # A fixed seed keeps the subsample identical from run to run.
    picker = np.random.default_rng(42)
    chosen = picker.choice(total, size=max_sample, replace=False)
    return rows[chosen]
| |
|
| |
|
def _kmeans(
    pixels: np.ndarray,
    num_colors: int,
    *,
    max_iter: int = 25,
    tol: float = 1e-2,
    seed: int | None = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Cluster RGB pixels with a small k-means implementation.

    Args:
        pixels: Array of shape ``(N, 3)``. Non-floating input is cast to
            float32 so the distance computation cannot wrap around.
        num_colors: Requested cluster count; clamped to ``[1, N]``.
        max_iter: Maximum number of centroid updates. ``0`` (or less) skips
            refinement and just assigns pixels to the initial centroids.
        tol: Stop once the total centroid movement drops below this value.
        seed: Seed for picking the initial centroids.

    Returns:
        ``(centroids, counts, labels)`` with clusters ordered from most to
        least populated; ``labels`` indexes into that ordering.

    Raises:
        ValueError: If ``pixels`` is empty.
    """
    if len(pixels) == 0:
        raise ValueError("No pixels available for clustering")

    pixels = np.asarray(pixels)
    if not np.issubdtype(pixels.dtype, np.floating):
        # Unsigned integer pixels would wrap around when subtracted.
        pixels = pixels.astype(np.float32)

    num_colors = max(1, min(num_colors, len(pixels)))

    rng = np.random.default_rng(seed)
    initial_indices = rng.choice(len(pixels), size=num_colors, replace=False)
    centroids = pixels[initial_indices]

    def _assign(current: np.ndarray) -> np.ndarray:
        """Return the index of the nearest centroid for every pixel."""
        distances = np.linalg.norm(pixels[:, None, :] - current[None, :, :], axis=2)
        return np.argmin(distances, axis=1)

    for _ in range(max_iter):
        labels = _assign(centroids)

        # An empty cluster keeps its previous centroid instead of becoming NaN.
        new_centroids = np.vstack(
            [
                pixels[labels == idx].mean(axis=0) if np.any(labels == idx) else centroids[idx]
                for idx in range(num_colors)
            ]
        )

        shift = np.linalg.norm(new_centroids - centroids)
        centroids = new_centroids

        if shift < tol:
            break

    # Assign once more against the final centroids so labels and counts match
    # what is returned; this also defines ``labels`` when the loop never ran.
    labels = _assign(centroids)

    counts = np.bincount(labels, minlength=num_colors).astype(np.int32)
    order = np.argsort(counts)[::-1]
    # remap[old_index] -> rank of that cluster in the sorted output.
    remap = np.empty_like(order)
    remap[order] = np.arange(num_colors)
    return centroids[order], counts[order], remap[labels]
| |
|
| |
|
def extract_palette(
    image_source: Image.Image | np.ndarray | str,
    num_colors: int,
    seed: int | None = None,
) -> PaletteResult:
    """Extract a color palette with ``num_colors`` entries from the image.

    Args:
        image_source: A PIL image, a NumPy array (cast to uint8), or a path
            to an image file.
        num_colors: Number of palette entries to request from the clustering.
        seed: Seed forwarded to the k-means initialisation.

    Returns:
        A ``PaletteResult`` holding the palette colours (with their share of
        the sampled pixels) and the samples, labels and centroids they were
        derived from.

    Raises:
        TypeError: If ``image_source`` is none of the supported types.
    """
    if isinstance(image_source, Image.Image):
        image = image_source
    elif isinstance(image_source, np.ndarray):
        image = Image.fromarray(image_source.astype(np.uint8))
    elif isinstance(image_source, str):
        # Image.open is lazy and keeps the file open. Decode inside a context
        # manager so the handle is released before clustering starts.
        with Image.open(image_source) as opened:
            image = opened.convert("RGB")
    else:
        raise TypeError("Unsupported image source type")

    pixels = _prepare_pixels(image)
    centroids, counts, labels = _kmeans(pixels, num_colors, seed=seed)

    total = int(counts.sum())
    palette = []
    for centroid, count in zip(centroids, counts):
        rounded = np.clip(np.round(centroid), 0, 255).astype(np.uint8)
        rgb = tuple(int(channel) for channel in rounded)
        # Store a plain Python float rather than a NumPy scalar.
        palette.append(PaletteColor(rgb, float(count) / total))

    return PaletteResult(colors=palette, samples=pixels, labels=labels, centroids=centroids)
| |
|
| |
|
def _gradient_html(palette: Sequence[PaletteColor]) -> str:
    """Create a CSS gradient bar to display the palette smoothly.

    Colours are spread evenly from 0% to 100% in palette order. An empty
    palette yields a placeholder ``<div>``.
    """
    if not palette:
        return "<div>暂无调色数据</div>"

    colors = list(palette)
    if len(colors) == 1:
        # CSS parsers that require at least two colour stops reject a
        # single-stop linear-gradient(); repeat the colour for a solid bar.
        colors = colors * 2

    last = len(colors) - 1
    stops = [f"{color.hex} {idx / last * 100:.2f}%" for idx, color in enumerate(colors)]

    gradient = ", ".join(stops)
    return f"<div style=\"height: 48px; border-radius: 8px; border: 1px solid #d1d5db; background: linear-gradient(90deg, {gradient});\"></div>"
| |
|
| |
|
def _relative_luminance(rgb: Tuple[int, int, int]) -> float:
    """Return the WCAG relative luminance of an 8-bit RGB triple."""
    red, green, blue = rgb

    linear = []
    for value in (red, green, blue):
        srgb = value / 255
        if srgb <= 0.03928:
            # Low values sit on the linear segment of the sRGB curve.
            linear.append(srgb / 12.92)
        else:
            linear.append(((srgb + 0.055) / 1.055) ** 2.4)

    lin_r, lin_g, lin_b = linear
    return 0.2126 * lin_r + 0.7152 * lin_g + 0.0722 * lin_b
| |
|
| |
|
def _typography_html(palette: Sequence[PaletteColor]) -> str:
    """Render text samples on top of up to six palette colours.

    Each card uses the palette colour as its background, picks dark or light
    text from the background's relative luminance, and includes a button
    that copies the colour's hex code. An empty palette yields a placeholder.
    """
    if not palette:
        return "<div>暂无调色数据</div>"

    phrases = ("用色彩描绘心境", "颜色是灵魂的触觉", "色彩即语言")
    cards = []

    for position, swatch in enumerate(palette[:6]):
        # Light backgrounds get dark text, dark backgrounds get light text.
        if _relative_luminance(swatch.rgb) > 0.55:
            text_hex, secondary_hex = "#101321", "#2E3143"
        else:
            text_hex, secondary_hex = "#F5F7FF", "#D8E2FF"

        phrase = phrases[position % len(phrases)]

        button_parts = [
            "<button type='button' style=\"float:right;margin-bottom:0.6rem;padding:0.3rem 0.6rem;",
            "border:1px solid rgba(17,24,39,0.2);border-radius:6px;background-color:rgba(255,255,255,0.8);",
            "font-size:0.75rem;cursor:pointer;\"",
            f" onclick=\"navigator.clipboard.writeText('{swatch.hex}');",
            "this.textContent='已复制';setTimeout(()=>this.textContent='复制 HEX',1200);\">复制 HEX</button>",
        ]
        button = "".join(button_parts)

        card_parts = [
            "<div style=\"border:1px solid #d1d5db;border-radius:8px;padding:0.8rem;",
            f"background:{swatch.hex};color:{text_hex};margin-bottom:0.6rem;position:relative;overflow:hidden;\">",
            f" {button}",
            f" <div style=\"font-weight:600;\">{phrase}</div>",
            f" <div style=\"font-size:0.85rem;color:{secondary_hex};margin-top:0.2rem;\">{swatch.hex} · {swatch.rgb}</div>",
            " <div style=\"margin-top:0.6rem;font-size:0.85rem;\">",
            " 在当前背景色上预览文字对比度。",
            " </div>",
            "</div>",
        ]
        cards.append("".join(card_parts))

    return "".join(cards)
| |
|
| |
|
def _scatter_figure(result: PaletteResult) -> go.Figure:
    """Build a 3D scatter figure of sampled pixels in RGB space.

    One trace is added per non-empty cluster, drawn in that cluster's palette
    colour, followed by a single trace marking all centroids. The axes are
    fixed to the 0-255 range of each channel.
    """
    fig = go.Figure()

    samples = result.samples
    labels = result.labels
    for idx, color in enumerate(result.colors):
        cluster_points = samples[labels == idx]
        if cluster_points.size == 0:
            continue
        fig.add_trace(
            go.Scatter3d(
                x=cluster_points[:, 0],
                y=cluster_points[:, 1],
                z=cluster_points[:, 2],
                mode="markers",
                marker=dict(size=3, color=color.hex, opacity=0.35),
                name=f"Cluster {idx + 1}",
                # The %{...} placeholders are filled in by Plotly, so only the
                # <extra> part is an f-string. Previously the whole template
                # was a plain string and "{color.hex}" was shown literally.
                hovertemplate=(
                    "R:%{x:.0f}<br>G:%{y:.0f}<br>B:%{z:.0f}"
                    f"<extra>{color.hex}</extra>"
                ),
            )
        )

    fig.add_trace(
        go.Scatter3d(
            x=result.centroids[:, 0],
            y=result.centroids[:, 1],
            z=result.centroids[:, 2],
            mode="markers",
            marker=dict(
                size=9,
                color=[color.hex for color in result.colors],
                symbol="diamond",
                line=dict(width=1.5, color="#111111"),
            ),
            name="Centroids",
            hovertemplate="R:%{x:.0f}<br>G:%{y:.0f}<br>B:%{z:.0f}<extra>Centroid</extra>",
        )
    )

    fig.update_layout(
        scene=dict(
            xaxis=dict(title="Red", range=[0, 255]),
            yaxis=dict(title="Green", range=[0, 255]),
            zaxis=dict(title="Blue", range=[0, 255]),
            aspectmode="cube",
        ),
        legend=dict(orientation="h", x=0.0, y=1.02),
        margin=dict(l=0, r=0, t=30, b=0),
    )

    return fig
| |
|
| |
|
def analyze_image(
    image: Image.Image | np.ndarray,
    num_colors: int,
    seed: int,
) -> Tuple[str, str, go.Figure | None, List[dict]]:
    """Processing function used by the Gradio interface.

    Args:
        image: The uploaded image, or ``None`` when nothing was uploaded.
        num_colors: Requested palette size (slider value).
        seed: Random seed for the clustering (slider value).

    Returns:
        ``(gradient_html, typography_html, scatter_figure, palette_json)``.
        Without an image, both HTML slots carry an upload prompt, the figure
        is ``None`` and the JSON payload is empty.
    """
    if image is None:
        empty = "<div>请上传图片</div>"
        return empty, empty, None, []

    # UI sliders may deliver floats (e.g. 5.0). The NumPy RNG rejects a float
    # seed and a float sample size, so coerce both to int.
    result = extract_palette(
        image,
        num_colors=int(num_colors),
        seed=None if seed is None else int(seed),
    )

    json_payload = [
        {
            "hex": color.hex,
            "rgb": color.rgb,
            # Plain float keeps the payload free of NumPy scalar types.
            "percentage": round(float(color.percentage), 4),
        }
        for color in result.colors
    ]

    return (
        _gradient_html(result.colors),
        _typography_html(result.colors),
        _scatter_figure(result),
        json_payload,
    )
| |
|
| |
|
# Stylesheet handed to ``gr.Blocks(css=...)`` in ``build_demo``; ``None`` means
# no custom CSS is supplied.
CUSTOM_CSS = None
| |
|
| |
|
def build_demo(default_num_colors: int = 5, default_seed: int = 42) -> "gr.Blocks":
    """Assemble the Gradio Blocks UI and return it without launching.

    The two defaults are clamped to the slider ranges (2-12 colours, seed
    0-10000) and used as the sliders' initial values.
    """
    # Keep the initial values inside the ranges the sliders accept.
    capped_colors = min(12, default_num_colors)
    default_num_colors = int(max(2, capped_colors))
    capped_seed = min(10_000, default_seed)
    default_seed = int(max(0, capped_seed))

    with gr.Blocks(css=CUSTOM_CSS, title="Palette Explorer") as demo:
        gr.Markdown(
            """
            # 🎨 Palette Explorer

            上传图片,提取主要色调,并以多种方式查看调色板数据。
            """
        )

        with gr.Row(equal_height=True):
            # Left column: image upload and clustering controls.
            with gr.Column(scale=6, min_width=420):
                image_input = gr.Image(label="上传图片", type="pil", height=420)
                with gr.Row():
                    color_slider = gr.Slider(
                        minimum=2,
                        maximum=12,
                        value=default_num_colors,
                        step=1,
                        label="调色板颜色数量",
                    )
                    seed_slider = gr.Slider(
                        minimum=0,
                        maximum=10_000,
                        value=default_seed,
                        step=1,
                        label="随机种子",
                    )

                generate_button = gr.Button("生成调色板")
                gr.Markdown("_提示:尝试调整颜色数量和随机种子,以探索不同的聚类结果。_")

            # Right column: the four result views.
            with gr.Column(scale=6, min_width=420):
                gradient_view = gr.HTML(label="渐变光带")
                typography_view = gr.HTML(label="排版预览")
                scatter_view = gr.Plot(label="RGB 三维散点")
                data_view = gr.JSON(label="调色板数据")

        result_views = [gradient_view, typography_view, scatter_view, data_view]
        generate_button.click(
            fn=analyze_image,
            inputs=[image_input, color_slider, seed_slider],
            outputs=result_views,
        )

    return demo
| |
|
| |
|
def launch_demo(
    default_num_colors: int = 5,
    default_seed: int = 42,
    *,
    share: bool = False,
    server_name: str | None = None,
    server_port: int | None = None,
) -> None:
    """Build the demo and start the Gradio server.

    ``default_num_colors`` and ``default_seed`` seed the UI controls; the
    keyword-only options are passed straight through to ``launch``.
    """
    launch_options = {
        "share": share,
        "server_name": server_name,
        "server_port": server_port,
    }
    app = build_demo(default_num_colors=default_num_colors, default_seed=default_seed)
    app.launch(**launch_options)
| |
|
| |
|
| |
|
def run_cli(image: str, num_colors: int, seed: int) -> None:
    """Extract a palette from the image at ``image`` and print it as a table.

    A header row is printed first, then one row per colour with its hex
    code, RGB tuple and share of the sampled pixels as a percentage.
    """
    palette = extract_palette(image, num_colors=num_colors, seed=seed).colors

    rows = ["Hex RGB Percentage"]
    for entry in palette:
        rows.append(f"{entry.hex} {entry.rgb} {entry.percentage * 100:.2f}%")

    print(*rows, sep="\n")
| |
|
| |
|
| |
|
def main() -> None:
    """Parse command-line arguments and dispatch to the CLI or the UI.

    The palette is printed to the terminal only when an image path is given
    and ``--ui`` is absent; otherwise the Gradio interface is launched.
    """
    parser = argparse.ArgumentParser(description="Extract a color palette or launch the Gradio UI")
    add = parser.add_argument
    add("image", nargs="?", help="Path to the input image for CLI mode")
    add("-n", "--num-colors", type=int, default=5, help="Number of colors in the palette or default for UI")
    add("-s", "--seed", type=int, default=42, help="Random seed for k-means initialisation or default for UI")
    add("--ui", action="store_true", help="Launch the Gradio interface instead of the CLI output")
    add("--share", action="store_true", help="Share the Gradio demo publicly")
    add("--server-name", type=str, default=None, help="Hostname for Gradio server")
    add("--server-port", type=int, default=None, help="Port for Gradio server")

    options = parser.parse_args()

    wants_cli = options.image is not None and not options.ui
    if wants_cli:
        run_cli(options.image, num_colors=options.num_colors, seed=options.seed)
        return

    launch_demo(
        default_num_colors=options.num_colors,
        default_seed=options.seed,
        share=options.share,
        server_name=options.server_name,
        server_port=options.server_port,
    )
| |
|
| |
|
# Run the CLI/UI entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
| |
|