import json
import re
import shutil
import tarfile
from os import listdir, path, remove
from urllib import request

import gradio as gr
import numpy as np
import PIL.Image as PImage
from sklearn.cluster import KMeans


def download_extract():
    """Download the flowers image archive and extract it into the working directory.

    The archive is fetched from the course's GitHub release, streamed to
    ``flowers.tar.gz``, extracted with ``tarfile`` and then deleted.

    Raises:
        OSError: if the server answers with a non-200 status. Network and most
            HTTP errors are raised by ``urllib`` itself as ``URLError`` /
            ``HTTPError``, which are also ``OSError`` subclasses.
    """
    url = "https://github.com/PSAM-5020-2025S-A/5020-utils/releases/latest/download/flowers.tar.gz"
    target_path = "flowers.tar.gz"

    with request.urlopen(request.Request(url), timeout=15.0) as response:
        # Fail loudly here instead of letting tarfile.open() trip over a
        # missing (or stale, previously downloaded) archive later on.
        if response.status != 200:
            raise OSError(f"Could not download {url}: HTTP status {response.status}")
        with open(target_path, "wb") as f:
            # Stream to disk rather than holding the whole archive in memory.
            shutil.copyfileobj(response, f)

    try:
        # The context manager closes the archive even if extraction fails.
        with tarfile.open(target_path, "r:gz") as tar:
            # The "data" filter refuses absolute paths and members/links that
            # would land outside the destination directory. It exists on
            # Python 3.12+ and on security-patched older releases.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(filter="data")
            else:
                tar.extractall()
    finally:
        # Always remove the downloaded archive, even if extraction failed.
        remove(target_path)


# Posterize image and get representative colors
def top_colors(fpath, n_clusters=8, n_colors=4):
    """Return the most common colors of an image after k-means posterization.

    Args:
        fpath: path to an image file readable by PIL.
        n_clusters: number of k-means clusters (posterization levels).
        n_colors: maximum number of colors to return.

    Returns:
        Up to ``n_colors`` ``[r, g, b]`` lists of ints (rounded cluster
        centers), ordered from the cluster with the most pixels to the one
        with the fewest. Clusters that received no pixels are never returned.
    """
    # `with` closes the underlying file once the pixels have been copied out.
    with PImage.open(fpath) as pimg:
        # (height, width, 3) -> (height * width, 3): one RGB row per pixel,
        # in the same row-major order as Image.getdata().
        pixels = np.asarray(pimg.convert("RGB")).reshape(-1, 3)

    posterizer = KMeans(n_clusters=n_clusters)
    px_clusters = posterizer.fit_predict(pixels)
    cluster_colors = posterizer.cluster_centers_

    # np.unique reports only labels that actually occur. Keep the labels so
    # cluster_colors is indexed by cluster label, not by position in the
    # unique array; the two differ whenever some cluster is empty.
    labels, counts = np.unique(px_clusters, return_counts=True)
    labels_by_count = labels[np.argsort(-counts, kind="stable")]

    return [
        [int(round(channel)) for channel in cluster_colors[label]]
        for label in labels_by_count[:n_colors]
    ]


# Cluster all images
def get_top_colors(flower_image_dir, n_clusters=8, n_colors=4):
    """Compute representative colors for every ``.png`` entry of a directory.

    Args:
        flower_image_dir: directory whose entries are scanned (not recursive).
        n_clusters: forwarded to ``top_colors``.
        n_colors: forwarded to ``top_colors``.

    Returns:
        A list of ``{"filename": str, "colors": [[r, g, b], ...]}`` dicts,
        sorted by filename.
    """
    flower_files = sorted(f for f in listdir(flower_image_dir) if f.endswith(".png"))
    return [
        {
            "filename": fname,
            "colors": top_colors(
                path.join(flower_image_dir, fname),
                n_clusters=n_clusters,
                n_colors=n_colors,
            ),
        }
        for fname in flower_files
    ]


# Euclidean distance between 2 RGB color tuples
def color_distance(c0, c1):
    """Return the Euclidean distance between the first three channels of two colors."""
    return ((c0[0] - c1[0]) ** 2 + (c0[1] - c1[1]) ** 2 + (c0[2] - c1[2]) ** 2) ** 0.5


# Function that returns minimum distance between a reference color and colors from a list
def min_color_distance(ref_color, color_list):
    """Return the smallest ``color_distance`` from ``ref_color`` to any color in ``color_list``.

    Raises:
        ValueError: if ``color_list`` is empty.
    """
    return min(color_distance(ref_color, c) for c in color_list)


# Turns a css color string in the form `#12AB56` or
# `rgb(18, 171, 87)` or
# `rgba(18, 171, 87, 1)`
# into an RGB list [18, 171, 87]
#
# Needed because of this bug:
# https://github.com/gradio-app/gradio/issues/12071
def css_to_rgb(css_str):
    """Convert a CSS color string into an ``[r, g, b]`` list of ints in 0-255.

    Accepted inputs:
      * hex: ``#RGB``, ``#RGBA``, ``#RRGGBB``, ``#RRGGBBAA`` (alpha is ignored)
      * ``rgb(r, g, b)`` / ``rgba(r, g, b, a)``; channels may be numbers
        (fractions are rounded, out-of-range values clamped) or percentages
      * ``hsl(h, s%, l%)`` / ``hsla(h, s%, l%, a)``; hue in degrees

    Anything else, including ``None`` or an empty string, yields ``[0, 0, 0]``.
    """
    import colorsys

    black = [0, 0, 0]
    if not css_str:
        return black
    css = css_str.strip().lower()

    if css.startswith("#"):
        digits = css[1:]
        if len(digits) in (3, 4):
            # Shorthand: every hex digit is doubled ("#f80" == "#ff8800").
            digits = "".join(d * 2 for d in digits[:3])
        elif len(digits) not in (6, 8):
            return black
        if not re.fullmatch(r"[0-9a-f]+", digits):
            return black
        return [int(digits[i:i + 2], 16) for i in (0, 2, 4)]

    COLOR_PATTERN = (
        r"([a-z]+)\s*\(\s*([0-9.]+%?)\s*,\s*([0-9.]+%?)\s*,\s*([0-9.]+%?)"
        r"\s*(?:,\s*([0-9.]+%?)\s*)?\)"
    )
    match = re.match(COLOR_PATTERN, css)
    if not match:
        return black
    func = match.group(1)
    components = match.group(2, 3, 4)

    def to_byte(value):
        # Round to the nearest int and clamp to the valid 0-255 channel range.
        return min(255, max(0, round(value)))

    def rgb_channel(token):
        # CSS allows `rgb(100%, 50%, 0%)`: a percentage is a fraction of 255.
        if token.endswith("%"):
            return to_byte(float(token[:-1]) * 255 / 100)
        return to_byte(float(token))

    def fraction(token):
        # Saturation/lightness as a 0-1 fraction; the `%` sign is optional.
        return min(1.0, max(0.0, float(token.rstrip("%")) / 100))

    try:
        if func in ("rgb", "rgba"):
            return [rgb_channel(token) for token in components]
        if func in ("hsl", "hsla"):
            hue_token, sat_token, light_token = components
            hue = float(hue_token) / 360 % 1.0
            # Note colorsys uses HLS argument order: (hue, lightness, saturation).
            r, g, b = colorsys.hls_to_rgb(hue, fraction(light_token), fraction(sat_token))
            return [to_byte(c * 255) for c in (r, g, b)]
    except ValueError:
        # The pattern admits tokens like "1.2.3" that float() rejects.
        return black
    return black


def order_by_color(center_color_str):
    """Rank all flower images by how close their palette comes to a color.

    Args:
        center_color_str: CSS color string from the color picker (parsed by
            ``css_to_rgb``).

    Returns:
        A JSON string ``{"color": [r, g, b], "files": [filename, ...]}`` with
        ``files`` ordered from closest to farthest. An image's distance is the
        distance of its nearest representative color in ``FILE_COLORS``.
    """
    center_color = css_to_rgb(center_color_str)

    # Function that returns how close an image is to a given color
    def by_color_dist(A):
        return min_color_distance(center_color, A["colors"])

    file_colors_sorted = sorted(FILE_COLORS, key=by_color_dist)
    files_sorted = [A["filename"] for A in file_colors_sorted]

    file_order = {
        "color": center_color,
        "files": files_sorted,
    }
    return json.dumps(file_order)


my_inputs = [
    gr.ColorPicker(value="#ffdf00", label="center_color", interactive=True)
]

my_outputs = [
    gr.JSON(show_label=False, show_indices=False, height=200, container=False)
]

my_examples = [
    ["#FFFFFF"],
    ["#FFD700"],
    ["#7814BE"]
]


def setup():
    """Ensure the flower images are on disk and precompute their palettes.

    Calls ``download_extract`` when ``./data/image/flowers`` is not a
    directory, then stores the result of ``get_top_colors`` for that
    directory in the module-level ``FILE_COLORS`` read by ``order_by_color``.
    """
    global FILE_COLORS
    FLOWER_IMG_DIR = "./data/image/flowers"
    if not path.isdir(FLOWER_IMG_DIR):
        download_extract()
    FILE_COLORS = get_top_colors(FLOWER_IMG_DIR)


# Runs at import time: palettes must exist before cached examples are computed.
setup()

with gr.Blocks() as demo:
    gr.Interface(
        fn=order_by_color,
        inputs=my_inputs,
        outputs=my_outputs,
        flagging_mode="never",
        cache_examples=True,
        examples=my_examples,
        fill_width=True
    )

if __name__ == "__main__":
    demo.launch(ssr_mode=False)