|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import cv2 |
|
|
from sklearn.cluster import KMeans |
|
|
from collections import Counter |
|
|
import os |
|
|
|
|
|
def _to_rgb(image):
    """Coerce ``image`` to a C-contiguous array of shape ``(H, W, 3)``.

    Grayscale ``(H, W)`` / ``(H, W, 1)`` inputs are replicated to three
    channels and an RGBA ``(H, W, 4)`` input has its alpha channel dropped.
    The dtype is left untouched.

    Raises:
        ValueError: for any other shape, or for an image with zero pixels.
    """
    if image.ndim == 3 and image.shape[2] == 1:
        image = image[..., 0]
    if image.ndim == 2:
        image = np.stack([image, image, image], axis=-1)
    elif image.ndim == 3 and image.shape[2] == 4:
        image = image[..., :3]
    if image.ndim != 3 or image.shape[2] != 3:
        raise ValueError(f"expected an RGB image of shape (H, W, 3), got {image.shape}")
    if image.size == 0:
        raise ValueError("image has no pixels")
    # Slicing off alpha yields a strided view; hand OpenCV a contiguous buffer.
    return np.ascontiguousarray(image)


def _sample_pixels(pixels, max_samples, rng):
    """Return at most ``max_samples`` rows of ``pixels``, drawn uniformly without replacement.

    Drawing pixels uniformly already weights every colour by how often it
    occurs, so no explicit colour histogram is needed. If ``pixels`` has no
    more than ``max_samples`` rows it is returned unchanged (same object).
    """
    if len(pixels) <= max_samples:
        return pixels
    indices = rng.choice(len(pixels), size=max_samples, replace=False)
    return pixels[indices]


def _order_by_frequency(centers, labels):
    """Reorder ``centers`` so the most-populated cluster comes first.

    Clusters that received no samples are kept (count 0) instead of being
    silently dropped, and ties keep their original index order.
    """
    counts = np.bincount(labels, minlength=len(centers))
    return centers[np.argsort(-counts, kind="stable")]


def extract_palette(image, num_colors=5, *, max_samples=20000, random_state=42):
    """Return the dominant colours of ``image`` as an RGB palette.

    The image is converted to CIELAB with OpenCV, clustered with
    scikit-learn's ``KMeans``, and the cluster centres are converted back to
    RGB. Clustering in LAB is presumably meant to make k-means distances
    follow perceived colour difference.

    Args:
        image: RGB image array of shape ``(H, W, 3)``; assumed ``uint8`` (as
            delivered by ``gr.Image(type="numpy")``). Grayscale ``(H, W)`` /
            ``(H, W, 1)`` and RGBA ``(H, W, 4)`` are also accepted; alpha
            is discarded.
        num_colors: Requested palette size. Coerced with ``int()`` because UI
            sliders may deliver floats. If the (sampled) image has fewer
            distinct colours, only that many are returned.
        max_samples: Upper bound on the number of pixels fed to k-means;
            larger images are uniformly subsampled. Must be positive.
        random_state: Seed for both the subsampling and ``KMeans`` so the same
            input always yields the same palette.

    Returns:
        ``uint8`` array of shape ``(k, 3)`` holding RGB colours, most frequent
        cluster first, with ``1 <= k <= num_colors``.

    Raises:
        ValueError: if ``image`` is ``None``, empty or of unsupported shape,
            or if ``num_colors < 1``.
    """
    if image is None:
        raise ValueError("no image provided")
    n_requested = int(num_colors)
    if n_requested < 1:
        raise ValueError(f"num_colors must be >= 1, got {num_colors!r}")

    rgb = _to_rgb(np.asarray(image))
    pixels = cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB).reshape(-1, 3)

    rng = np.random.default_rng(random_state)
    pixels = _sample_pixels(pixels, max_samples, rng)

    # Asking for more clusters than distinct colours only yields duplicate
    # centres (and a ConvergenceWarning), so clamp the cluster count.
    n_clusters = min(n_requested, len(np.unique(pixels, axis=0)))

    kmeans = KMeans(n_clusters=n_clusters, n_init='auto', random_state=random_state)
    labels = kmeans.fit_predict(pixels)

    # Round (not truncate) the float centres back into the uint8 LAB range.
    centers_lab = np.clip(np.rint(kmeans.cluster_centers_), 0, 255).astype(np.uint8)
    # cvtColor needs an image-shaped array, hence the (1, k, 3) round trip.
    centers_rgb = cv2.cvtColor(centers_lab.reshape(1, -1, 3), cv2.COLOR_LAB2RGB).reshape(-1, 3)

    return _order_by_frequency(centers_rgb, labels)
|
|
|
|
|
def _render_palette(palette, cell_width=100, cell_height=100):
    """Return an RGB ``uint8`` strip with one ``cell_height x cell_width`` swatch per palette colour, left to right."""
    colors = np.asarray(palette, dtype=np.uint8).reshape(-1, 3)
    # (1, k, 3) -> (cell_height, k, 3) -> (cell_height, k * cell_width, 3)
    strip = np.repeat(colors[np.newaxis], cell_height, axis=0)
    return np.repeat(strip, cell_width, axis=1)


def visualize_palette(image, num_colors, *, cell_width=100, cell_height=100):
    """Extract a palette from ``image`` and save it as a PNG swatch strip.

    Used as the click handler of the "Extract Palette" button.

    Args:
        image: Uploaded image array, or ``None`` when nothing was uploaded.
        num_colors: Requested palette size, forwarded to ``extract_palette``.
        cell_width: Width in pixels of each colour swatch.
        cell_height: Height in pixels of the strip.

    Returns:
        Path of a PNG temp file. It is created with ``delete=False`` so the
        caller can read it after this function returns; it is not removed here.

    Raises:
        gr.Error: if no image was uploaded (shown as a message in the UI).
        OSError: if OpenCV reports that the PNG could not be written.
    """
    import tempfile

    if image is None:
        # Give the user a readable message instead of an OpenCV traceback.
        raise gr.Error("Please upload an image first.")

    palette = extract_palette(image, num_colors)
    strip = _render_palette(palette, cell_width, cell_height)

    with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp_file:
        output_path = tmp_file.name
    # Write only after the handle is closed: reopening a still-open
    # NamedTemporaryFile by name is platform-dependent.
    # OpenCV's imwrite expects BGR channel order.
    if not cv2.imwrite(output_path, cv2.cvtColor(strip, cv2.COLOR_RGB2BGR)):
        raise OSError(f"failed to write palette image to {output_path}")

    return output_path
|
|
|
|
|
|
|
|
def _build_palette_tab():
    """Lay out the colour-palette controls inside the currently open tab and wire the button."""
    gr.Markdown("## Extract dominant colors from your image")
    with gr.Row():
        with gr.Column():
            image_in = gr.Image(type="numpy", label="Input Image")
            color_count = gr.Slider(
                minimum=2,
                maximum=50,
                step=1,
                value=5,
                label="Number of Colors",
            )
            run_button = gr.Button("Extract Palette", variant="primary")
        with gr.Column():
            palette_out = gr.Image(type="filepath", label="Color Palette")
    run_button.click(
        fn=visualize_palette,
        inputs=[image_in, color_count],
        outputs=palette_out,
    )


def create_demo():
    """Build and return the "CV Tools" Gradio Blocks app (it is not launched here).

    The app has a working "Color Palette" tab plus two placeholder tabs that
    only display a heading.
    """
    placeholder_tabs = (
        ("Image Filters", "## Image Filters (Coming Soon)"),
        ("More Features", "## More Features Coming Soon"),
    )
    with gr.Blocks(title="CV Tools", theme=gr.themes.Soft()) as app:
        gr.Markdown("# Computer Vision Tools")
        with gr.Tabs():
            with gr.Tab("Color Palette"):
                _build_palette_tab()
            for tab_title, heading in placeholder_tabs:
                with gr.Tab(tab_title):
                    gr.Markdown(heading)
    return app
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    demo = create_demo()
    demo.queue()
    # Defaults match the previous hard-coded values (public share link,
    # all interfaces, port 7860). Gradio's standard environment variables
    # can override them; explicit launch() arguments would otherwise always
    # take precedence over those variables.
    share_env = os.environ.get("GRADIO_SHARE", "true").strip().lower()
    demo.launch(
        share=share_env in {"1", "true", "yes", "on"},
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7860")),
        server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
    )