# cvTools / app.py
# Author: lalamax3d
# Commit 4cf346c: "local dev and multitab base"
import gradio as gr
import numpy as np
import cv2
from sklearn.cluster import KMeans
from collections import Counter
import os
def extract_palette(image, num_colors=5, max_samples=20000, random_state=42):
    """Return the dominant colors of an RGB image, most frequent first.

    Pixels are converted to LAB and clustered with k-means. The cluster
    centers are converted back to RGB and ordered by cluster size.

    Args:
        image: RGB image array of shape (H, W, 3). Grayscale (H, W) / (H, W, 1)
            and RGBA (H, W, 4) arrays are converted to RGB first. Assumed to be
            uint8 (what ``gr.Image(type="numpy")`` is expected to deliver);
            float input would be scaled differently by ``cv2.cvtColor``.
        num_colors: Requested number of colors. Coerced to ``int`` because UI
            sliders may deliver floats. Capped at the number of distinct pixel
            colors, so fewer colors may be returned for very flat images.
        max_samples: If the image has more pixels than this, a uniform random
            sample of this many pixels (without replacement) is clustered
            instead of all pixels.
        random_state: Seed used for both the pixel sampling and KMeans, so the
            same image and arguments always yield the same palette.

    Returns:
        uint8 array of shape (k, 3) holding RGB colors, where k <= num_colors,
        sorted from the largest to the smallest cluster.

    Raises:
        ValueError: if ``image`` is None or empty, or if ``num_colors`` or
            ``max_samples`` is < 1.
    """
    if image is None:
        raise ValueError("extract_palette() requires an image, got None")
    image = np.asarray(image)
    if image.size == 0:
        raise ValueError("extract_palette() requires a non-empty image")
    num_colors = int(num_colors)
    if num_colors < 1:
        raise ValueError(f"num_colors must be >= 1, got {num_colors}")
    if max_samples < 1:
        raise ValueError(f"max_samples must be >= 1, got {max_samples}")

    # Normalise to 3-channel RGB; COLOR_RGB2LAB rejects other channel counts.
    if image.ndim == 2 or image.shape[2] == 1:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

    # Cluster in LAB: Euclidean distance there tracks perceived color
    # difference better than in RGB.
    pixels = cv2.cvtColor(image, cv2.COLOR_RGB2LAB).reshape(-1, 3)

    if len(pixels) > max_samples:
        # A uniform sample already preserves each color's share of the image.
        # This replaces the previous frequency-weighted draw over unique colors
        # (same distribution) without a Python loop over every pixel.
        rng = np.random.default_rng(random_state)
        picks = rng.choice(len(pixels), size=max_samples, replace=False)
        pixels = pixels[picks]

    # KMeans cannot form more clusters than there are distinct points.
    n_clusters = min(num_colors, len(np.unique(pixels, axis=0)))
    kmeans = KMeans(n_clusters=n_clusters, n_init='auto', random_state=random_state)
    kmeans.fit(pixels)

    # Round (not truncate) the float centers back to valid 8-bit LAB values.
    centers_lab = np.clip(np.rint(kmeans.cluster_centers_), 0, 255).astype(np.uint8)
    centers_rgb = cv2.cvtColor(
        centers_lab.reshape(1, -1, 3), cv2.COLOR_LAB2RGB
    ).reshape(-1, 3)

    # minlength keeps counts aligned with centers even if a cluster ends empty.
    counts = np.bincount(kmeans.labels_, minlength=n_clusters)
    return centers_rgb[np.argsort(-counts, kind='stable')]
def visualize_palette(image, num_colors, cell_width=100, cell_height=100):
    """Render the dominant colors of ``image`` as a row of swatches in a PNG.

    Used as the Gradio callback of the "Extract Palette" button.

    Args:
        image: Input image array passed straight to ``extract_palette``.
        num_colors: Number of palette colors to request.
        cell_width: Width in pixels of each color swatch.
        cell_height: Height in pixels of each color swatch.

    Returns:
        Path of a newly created PNG in the system temp directory. The file is
        created with ``delete=False`` and is not removed here. Presumably
        Gradio copies it into its own cache; TODO confirm and clean up if not.

    Raises:
        gr.Error: if no image was provided (shown to the user in the UI).
        OSError: if OpenCV fails to write the PNG.
    """
    import tempfile

    if image is None:
        raise gr.Error("Please upload an image first.")

    palette = extract_palette(image, num_colors)

    # Each color becomes a cell_width-wide run of pixels; stacking that row
    # cell_height times gives an (H, k*W, 3) strip of swatches.
    swatch_row = np.repeat(palette, cell_width, axis=0)
    palette_image = np.repeat(swatch_row[np.newaxis], cell_height, axis=0)

    # Only reserve a unique name here. The PNG is written after the handle is
    # closed, because reopening an open NamedTemporaryFile by name is not
    # portable (it can fail on Windows).
    with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp_file:
        output_path = tmp_file.name

    # OpenCV writes BGR, and imwrite signals failure by returning False.
    if not cv2.imwrite(output_path, cv2.cvtColor(palette_image, cv2.COLOR_RGB2BGR)):
        os.remove(output_path)
        raise OSError(f"Could not write palette image to {output_path}")
    return output_path
def _build_palette_tab():
    """Add the "Color Palette" tab and wire its button to ``visualize_palette``.

    Intended to be called while the ``gr.Tabs()`` context of a ``gr.Blocks``
    app is open, as ``create_demo`` does.
    """
    with gr.Tab("Color Palette"):
        gr.Markdown("## Extract dominant colors from your image")
        with gr.Row():
            with gr.Column():
                source_image = gr.Image(type="numpy", label="Input Image")
                color_count = gr.Slider(
                    minimum=2,
                    maximum=50,
                    step=1,
                    value=5,
                    label="Number of Colors",
                )
                run_button = gr.Button("Extract Palette", variant="primary")
            with gr.Column():
                swatch_output = gr.Image(type="filepath", label="Color Palette")
        run_button.click(
            fn=visualize_palette,
            inputs=[source_image, color_count],
            outputs=swatch_output,
        )


# (tab title, heading markdown) for tabs whose tools are not implemented yet.
# When adding a real tool, drop its entry here and call a dedicated builder
# (like ``_build_palette_tab``) from ``create_demo`` instead.
_PLACEHOLDER_TABS = (
    ("Image Filters", "## Image Filters (Coming Soon)"),
    ("More Features", "## More Features Coming Soon"),
)


def create_demo():
    """Build the multi-tab "CV Tools" Gradio app.

    Returns:
        The ``gr.Blocks`` app, not yet queued or launched; that is left to
        the caller.
    """
    with gr.Blocks(title="CV Tools", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Computer Vision Tools")
        with gr.Tabs():
            _build_palette_tab()
            for tab_title, heading in _PLACEHOLDER_TABS:
                with gr.Tab(tab_title):
                    gr.Markdown(heading)
    return demo
# Script entry point: build and serve the app when run as ``python app.py``.
# NOTE: running it this way does not auto-reload on code changes; Gradio's
# ``gradio app.py`` CLI provides a reload mode for development.
# The defaults below match the original hard-coded launch settings. Each can
# be overridden via the GRADIO_SHARE / GRADIO_SERVER_PORT / GRADIO_SERVER_NAME
# environment variables without editing the code.
if __name__ == "__main__":
    demo = create_demo()
    demo.queue()  # Enable queuing for better handling of simultaneous requests
    demo.launch(
        # share=True publishes a temporary public URL for this app; set
        # GRADIO_SHARE=false to keep it reachable only on the local network.
        share=os.environ.get("GRADIO_SHARE", "true").strip().lower() in ("1", "true", "yes"),
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7860")),
        # 0.0.0.0 listens on all network interfaces, not just localhost.
        server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
    )