Spaces:
Sleeping
Sleeping
File size: 7,347 Bytes
c9f5fa7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 | from sklearn.cluster import KMeans
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from models import ColorPalette, ColorItem
from agents import function_tool
import math
###
# Another option is the `colorthief` library, which extracts a color palette
# from an image directly and is typically much faster than k-means clustering.
###
#@function_tool
def extract_color_palette(input_image: "Image.Image | str", num_colors: int,
                          min_fraction: float = 0.025) -> ColorPalette:
    """
    Extract a color palette from an image using K-means clustering.

    Pixels with alpha < 128 are ignored. The image is downscaled so its longest
    side is at most 500 px. The remaining RGB values are clustered into at most
    ``num_colors`` groups, and clusters covering no more than ``min_fraction``
    of the opaque pixels are dropped.

    Args:
        input_image (PIL.Image.Image | str | os.PathLike): Image object, or a
            path / file object that ``PIL.Image.open`` can read.
        num_colors (int): Maximum number of colors to extract (must be >= 1).
        min_fraction (float): Clusters whose share of opaque pixels is not
            greater than this are discarded (default 0.025, i.e. 2.5%).

    Returns:
        ColorPalette: colors with their pixel counts, most frequent first.

    Raises:
        ValueError: if ``num_colors`` < 1 or the image has no opaque pixels.
        Exception: if the image cannot be loaded/converted or resized.
    """
    if num_colors < 1:
        raise ValueError(f"num_colors must be >= 1, got {num_colors}")

    # Load and convert to RGBA so transparent pixels can be filtered out below.
    # convert() returns a new image, so the thumbnail() below never mutates
    # the caller's image.
    try:
        if isinstance(input_image, Image.Image):
            image = input_image.convert('RGBA')
        else:
            with Image.open(input_image) as source:
                image = source.convert('RGBA')
    except Exception as e:
        raise Exception(f"Error loading image: {e}") from e

    # Downscale for speed; thumbnail() resizes in place and keeps the aspect ratio.
    max_size = 500
    try:
        if max(image.size) > max_size:
            image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
    except Exception as e:
        raise Exception(f"Error resizing image: {e}") from e

    # Flatten to an (N, 4) array of RGBA pixels and keep the RGB part of
    # pixels that are at least half opaque (alpha >= 128).
    data = np.asarray(image).reshape(-1, 4)
    rgb_data = data[data[:, 3] >= 128, :3]
    if len(rgb_data) == 0:
        raise ValueError("Image has no opaque pixels (alpha >= 128) to cluster")

    # KMeans requires n_samples >= n_clusters, so never ask for more clusters
    # than there are pixels.
    n_clusters = min(num_colors, len(rgb_data))
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    kmeans.fit(rgb_data)

    # Cluster centers are float means; round them to valid 8-bit channel values.
    centers = np.clip(np.rint(kmeans.cluster_centers_), 0, 255).astype(int)
    # minlength guarantees one count per cluster even if a trailing cluster
    # received no labels.
    counts = np.bincount(kmeans.labels_, minlength=n_clusters)
    total = counts.sum()

    # Drop rare clusters (share <= min_fraction), then sort by frequency.
    palette = [
        (int(r), int(g), int(b), int(count))
        for (r, g, b), count in zip(centers, counts)
        if count / total > min_fraction
    ]
    palette.sort(key=lambda item: item[3], reverse=True)
    return ColorPalette(colors=[ColorItem(R=r, G=g, B=b, count=count)
                                for r, g, b, count in palette])
def create_color_palette_square_grid(color_data: ColorPalette, total_pixels: int = 160000):
    """
    Render a palette as a square image in which each color fills an area
    proportional to its count.

    Colors are laid out in the order they appear in ``color_data.colors``,
    filling the square row by row (left to right, top to bottom). Each color's
    count is scaled so the total is about ``total_pixels``. Per-color pixel
    counts are truncated, and the run is then trimmed to the largest perfect
    square, so the last color(s) may be slightly under-represented.

    Args:
        color_data: ColorPalette whose ``colors`` items expose ``R``, ``G``,
            ``B`` and ``count``.
        total_pixels: Target number of pixels in the output image.

    Returns:
        PIL.Image.Image: an RGB image of size (side, side), where
        side = floor(sqrt(number of scaled pixels)).

    Raises:
        ValueError: if the palette is empty or all counts are zero.
    """
    total_count = sum(color.count for color in color_data.colors)
    if total_count <= 0:
        raise ValueError("color_data must contain at least one color with a positive count")

    # Scale counts so the expanded pixel run is roughly total_pixels long.
    scale_factor = total_pixels / total_count

    # Expand each color into a run of pixels proportional to its frequency.
    pixel_list = []
    for color in color_data.colors:
        pixel_list.extend([(color.R, color.G, color.B)] * int(color.count * scale_factor))

    # Trim to the largest perfect square that fits.
    side_length = math.isqrt(len(pixel_list))
    img = Image.new('RGB', (side_length, side_length))
    # putdata fills row-major (x varies fastest), i.e. pixel i lands at
    # (i % side_length, i // side_length).
    img.putdata(pixel_list[:side_length * side_length])
    return img
def create_palette_image(palette: ColorPalette):
    """
    Create an image showing the color palette as labeled swatches.

    Swatches are laid out in up to 4 columns under a "Color Palette" title.
    Each swatch's hex code and RGB triple are printed below it.

    Args:
        palette (ColorPalette): palette whose ``colors`` items expose ``R``,
            ``G`` and ``B`` channel values.

    Returns:
        PIL.Image.Image: RGB image with a white background. An empty palette
        yields an image containing only the title.
    """
    # Layout dimensions (pixels).
    swatch_width = 120
    swatch_height = 100
    margin = 20
    text_height = 30
    title_height = 40

    num_colors = len(palette.colors)
    # At most 4 columns; at least 1 so an empty palette doesn't divide by zero.
    cols = max(1, min(4, num_colors))
    rows = (num_colors + cols - 1) // cols  # ceiling division
    img_width = cols * swatch_width + (cols + 1) * margin
    img_height = title_height + rows * (swatch_height + text_height) + (rows + 1) * margin

    palette_img = Image.new('RGB', (img_width, img_height), 'white')
    draw = ImageDraw.Draw(palette_img)

    # Prefer Arial. truetype raises OSError when the font file can't be found,
    # and ImportError when FreeType support is unavailable.
    try:
        title_font = ImageFont.truetype("arial.ttf", 16)
        color_font = ImageFont.truetype("arial.ttf", 12)
    except (OSError, ImportError):
        try:
            title_font = ImageFont.load_default()
            color_font = ImageFont.load_default()
        except Exception:
            title_font = None
            color_font = None

    def draw_centered(text, left, width, top, font):
        """Draw ``text`` horizontally centered within [left, left + width] at y=top."""
        if font is None:
            return  # no usable font: skip the label rather than fail
        bbox = draw.textbbox((0, 0), text, font=font)
        text_x = left + (width - (bbox[2] - bbox[0])) // 2
        draw.text((text_x, top), text, fill='black', font=font)

    draw_centered("Color Palette", 0, img_width, 10, title_font)

    for i, color in enumerate(palette.colors):
        row, col = divmod(i, cols)
        # Top-left corner of this swatch.
        x = margin + col * (swatch_width + margin)
        y = title_height + margin + row * (swatch_height + text_height + margin)

        draw.rectangle([x, y, x + swatch_width, y + swatch_height],
                       fill=(color.R, color.G, color.B), outline='gray', width=1)

        # Hex code, then the RGB triple 15 px lower, both centered under the swatch.
        text_y = y + swatch_height + 5
        draw_centered(f"#{color.R:02x}{color.G:02x}{color.B:02x}",
                      x, swatch_width, text_y, color_font)
        draw_centered(f"RGB({color.R}, {color.G}, {color.B})",
                      x, swatch_width, text_y + 15, color_font)

    return palette_img
# Example usage:
if __name__ == "__main__":
    from pathlib import Path

    output_dir = Path("output")
    # Image.save() fails if the target directory does not exist.
    output_dir.mkdir(parents=True, exist_ok=True)

    # extract_color_palette calls .convert() on its input, so pass an opened
    # PIL image rather than a path string.
    with Image.open("images/playstation-logo.png") as source_image:
        colors = extract_color_palette(source_image, 6)

    # Square grid: each color's area is proportional to its frequency.
    square_img = create_color_palette_square_grid(colors)
    square_img.save(output_dir / 'playstation_square.png')
    square_img.show()

    # Swatch sheet with hex and RGB labels.
    palette_img = create_palette_image(colors)
    palette_img.save(output_dir / 'playstation_palette.png')
    palette_img.show()
|