Spaces:
Running
Running
| # %% | |
| import cv2 | |
| from sklearn.cluster import KMeans | |
| from PIL import Image | |
| import numpy as np | |
| import gradio.components as gc | |
| import gradio as gr | |
def pixart(
    i,
    block_size=4,
    n_clusters=5,
    hsv_weights=(0, 0, 1),
    local_contrast_blur_radius=51,  # has to be odd; even values are rounded up
    upscale=True,
    seed=None,
    output_scaling=1,
    dither_amount=15,
):
    """Turn a PIL image into pixel art with a k-means colour palette.

    The image is downsampled by ``block_size`` with nearest-neighbour
    sampling (so only colours present in the source survive), the pixels are
    clustered in OpenCV HSV space with k-means, and every pixel is repainted
    with one representative source colour of its cluster.

    Args:
        i: Source ``PIL.Image``; converted to RGB before processing.
        block_size: Edge length, in source pixels, of one output "pixel".
            Values below 1 are treated as 1.
        n_clusters: Number of palette colours. Clamped to the number of
            downsampled pixels, because KMeans cannot fit more clusters
            than samples.
        hsv_weights: Three weights for the H, S and V channels. Within each
            cluster, the pixel with the largest weighted sum becomes the
            cluster's colour. The default (0, 0, 1) picks the highest-V pixel.
        local_contrast_blur_radius: Gaussian kernel size used to estimate
            each pixel's surroundings. OpenCV requires an odd size, so an
            even value is rounded up to the next odd number.
        upscale: Resize the result back to the source size.
        seed: Seed for the k-means initialisation and the dithering noise.
            ``None`` draws a random seed.
        output_scaling: Factor applied to the source size for the final
            resize. 1 and non-positive values mean "no extra scaling".
        dither_amount: Standard deviation of the Gaussian noise added to the
            HSV values before nearest-centre reassignment; values <= 0
            disable dithering.

    Returns:
        ``(image, seed)``: the resulting ``PIL.Image`` and the seed that was
        used. Passing that seed back in reproduces the same output.
    """
    # cv2.COLOR_RGB2HSV below needs exactly three 8-bit channels (RGBA/L would fail).
    i = i.convert('RGB')
    w, h = i.size
    block_size = max(1, int(block_size))
    # Never shrink to a zero-sized image when block_size exceeds a dimension.
    dw = max(1, w // block_size)
    dh = max(1, h // block_size)
    # always resize with NEAREST to keep the original colors
    small = i.resize((dw, dh), Image.Resampling.NEAREST)
    ai = np.array(small)

    if seed is None:
        seed = int(np.random.randint(0, 2**16 - 1))
    # A seeded generator for the dithering noise, so the returned seed
    # reproduces the whole result and not only the clustering.
    rng = np.random.default_rng(seed)

    hsv = cv2.cvtColor(ai, cv2.COLOR_RGB2HSV)
    hsv32 = hsv.astype(np.float32)
    pixels = hsv32.reshape(-1, hsv32.shape[-1])

    # cv2.GaussianBlur rejects even (or non-positive) kernel sizes.
    ksize = max(1, int(local_contrast_blur_radius)) | 1
    bhsv = cv2.GaussianBlur(hsv, (ksize, ksize), 0, borderType=cv2.BORDER_REPLICATE)

    # (sharp - blurred) is large where a pixel stands out from its surroundings.
    # Raising it to the 4th power makes the difference more pronounced, which
    # preserves rare specks of colour by increasing the probability of them
    # getting their own cluster.
    weights = np.linalg.norm(hsv32 - bhsv, axis=-1).reshape(-1) ** 4
    if not weights.any():
        # A perfectly flat image gives all-zero weights, which k-means++
        # initialisation cannot sample from; fall back to uniform weights.
        weights = None

    n_clusters = max(1, min(int(n_clusters), pixels.shape[0]))
    km = KMeans(n_clusters=n_clusters, random_state=seed)
    km.fit(pixels, sample_weight=weights)
    label_grid = km.labels_.reshape(hsv.shape[:2])
    centers = km.cluster_centers_  # hsv values

    weight_vec = np.asarray(hsv_weights, dtype=np.float64)

    def pick_representative_pixel(cluster):
        '''Return the member pixel of `cluster` with the largest weighted HSV sum.'''
        members = hsv[label_grid == cluster]
        if members.shape[0] == 0:
            # Coinciding centres can leave a cluster without members; fall
            # back to its (rounded) centre instead of crashing on argmax.
            return np.clip(np.rint(centers[cluster]), 0, 255).astype(np.uint8)
        return members[(members @ weight_vec).argmax()]

    cluster_colors = np.array(
        [pick_representative_pixel(c) for c in range(centers.shape[0])])

    if dither_amount > 0:
        # Add noise to the colours before selecting the nearest centre; this
        # acts as a dithering effect. Assignment uses Euclidean distance.
        noised = np.clip(hsv32 + rng.normal(0, dither_amount, hsv.shape), 0, 255)
        flat = noised.reshape(-1, 3)
        distances = np.linalg.norm(centers - flat[:, None], axis=-1)
        label_grid = np.argmin(distances, axis=1).reshape(hsv.shape[:2])

    # Assign each pixel the representative colour of its cluster.
    ki = cluster_colors[label_grid]
    out = Image.fromarray(cv2.cvtColor(ki.astype(np.uint8), cv2.COLOR_HSV2RGB))
    if upscale:
        out = out.resize((w, h), Image.Resampling.NEAREST)
    if output_scaling > 0 and output_scaling != 1:
        # round() keeps integer factors exact and lets fractional ones work too.
        out_size = (max(1, round(w * output_scaling)), max(1, round(h * output_scaling)))
        out = out.resize(out_size, Image.Resampling.NEAREST)
    return out, seed
def query(
    i: Image.Image,
    block_size: str,
    n_clusters,  # =5,
    hsv_weights,  # ='0,0,1'
    local_contrast_blur_radius,  # =51 has to be odd
    seed,  # =42,
    output_scaling,
    dither_amount,
):
    """Gradio callback: parse the UI inputs, run ``pixart`` and palettize.

    Args:
        i: Uploaded image (``None`` when nothing was uploaded).
        block_size: Text. Values < 1 are a fraction of ``min(w, h)``;
            values >= 1 are a block size in pixels.
        n_clusters: Palette size (number of colours).
        hsv_weights: Comma-separated H,S,V weights, e.g. ``"0,0,1"``.
        local_contrast_blur_radius: Blur kernel size forwarded to ``pixart``.
        seed: Seed text; blank means "pick a random seed".
        output_scaling: Output scaling factor forwarded to ``pixart``.
        dither_amount: Dithering noise strength forwarded to ``pixart``.

    Returns:
        ``(image, used_seed)``. The image is converted to palette ("P") mode
        when the requested palette fits in 256 colours.

    Raises:
        ValueError: if no image was given, or if the block size, weights
            or seed text cannot be parsed.
    """
    if i is None:
        raise ValueError('Please upload an image first.')
    w, h = i.size
    bs = float(block_size)
    # A fraction of a small image can round down to 0, which would make
    # pixart divide by zero; one pixel is the smallest meaningful block.
    blsz = max(1, int(bs * min(w, h)) if bs < 1 else int(bs))

    hw = [float(part) for part in hsv_weights.split(',')]
    if len(hw) != 3:
        raise ValueError(
            f'HSV weights need exactly three comma-separated numbers, got {hsv_weights!r}')

    # Sliders may deliver floats; KMeans and Image.convert need ints.
    n_clusters = int(n_clusters)
    seed_text = '' if seed is None else str(seed).strip()

    pxart, usedseed = pixart(
        i,
        block_size=blsz,
        n_clusters=n_clusters,
        hsv_weights=hw,
        local_contrast_blur_radius=int(local_contrast_blur_radius),
        upscale=True,
        seed=int(seed_text) if seed_text else None,
        output_scaling=output_scaling,
        dither_amount=dither_amount,
    )
    if n_clusters <= 256:
        # "P" mode images hold at most 256 palette entries.
        pxart = pxart.convert('P', palette=Image.Palette.ADAPTIVE, colors=n_clusters)
    return pxart, usedseed
# %%
# UI components; the order of the inputs list must match query()'s parameters.
searchimage = gc.Image(label="Search image", type='pil')
block_size = gc.Textbox(
    "0.01",
    label='Block Size ',
    placeholder="e.g. 8 for 8 pixels. 0.01 for 1% of min(w,h) (<1 for percentages, >= 1 for pixels)")
palette_size = gc.Slider(
    1, 1024, 32, step=1, label='Palette Size (Number of Colors)')
hsv_weights = gc.Textbox(
    "0,0,1",
    label='HSV Weights. Weights of the channels when selecting a "representative pixel"/centroid from a cluster of pixels',
    placeholder='e.g. 0,0,1 to only consider the V channel (which seems to work well)')
# Starting at 3 with step 2 keeps the kernel size odd, as cv2.GaussianBlur requires.
lcbr = gc.Slider(
    3, 512, 51, step=2, label='Blur radius to calculate local contrast')
seed = gc.Textbox(
    "",
    label='Seed for the random number generator (empty to randomize)',
    placeholder='e.g. 42')
outimage = gc.Image(label="Output", type='pil')
seedout = gc.Textbox(label='used seed')
# Minimum is 1: pixart resizes to (w * factor, h * factor), so 0 would produce
# an empty image and crash.
output_scaling = gc.Slider(
    1, 16, 1, step=1, label='Output scaling factor')
dither_amount = gc.Slider(
    0, 255, 0, step=1, label='Dithering amount')

demo = gr.Interface(
    query,
    [searchimage, block_size, palette_size, hsv_weights, lcbr, seed, output_scaling, dither_amount],
    [outimage, seedout],
    title="kmeans-Pixartifier",
    description="Turns images into pixel art using kmeans clustering",
    analytics_enabled=False,
    allow_flagging='never',
)
demo.launch()