# BrandSmith / src/tools.py
# Uploaded by raznis via huggingface_hub (commit c9f5fa7).
from sklearn.cluster import KMeans
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from models import ColorPalette, ColorItem
from agents import function_tool
import math
###
# another option to try is to use colorthief to extract the color palette
# colorthief is a library that can extract the color palette from an image
# it is much faster than k-means clustering
###
#@function_tool
def extract_color_palette(input_image: Image.Image, num_colors: int, min_fraction: float = 0.025) -> ColorPalette:
    """
    Extract the dominant colors of an image using K-means clustering.

    The image is converted to RGBA and downscaled so its longest side is at
    most 500 px. Pixels with alpha < 128 are ignored. The remaining RGB
    values are clustered, and clusters holding no more than ``min_fraction``
    of those pixels are dropped.

    Args:
        input_image (PIL.Image.Image | str | os.PathLike | file object):
            Input image, or anything ``PIL.Image.open`` can read.
        num_colors (int): Maximum number of colors to extract. The value is
            clamped to the number of distinct opaque colors in the image.
        min_fraction (float): Clusters whose share of opaque pixels is
            <= this value are discarded. Defaults to 0.025 (2.5%).

    Returns:
        ColorPalette: the kept colors (rounded cluster centers) with their
        pixel counts, sorted by count, most frequent first.

    Raises:
        ValueError: if ``num_colors`` < 1 or the image has no pixel with
            alpha >= 128.
        Exception: if the image cannot be loaded/converted or resized.
    """
    if num_colors < 1:
        raise ValueError(f"num_colors must be at least 1, got {num_colors}")

    # Load and convert to RGBA so transparent pixels can be filtered out.
    # convert() returns a new image, so the in-place thumbnail() below never
    # mutates the caller's image.
    try:
        if isinstance(input_image, Image.Image):
            image = input_image.convert('RGBA')
        else:
            # Accept a path / file object; the context manager closes the file.
            with Image.open(input_image) as source:
                image = source.convert('RGBA')
    except Exception as e:
        raise Exception(f"Error loading image: {e}") from e

    try:
        # Downscale for faster clustering.
        max_size = 500
        if max(image.size) > max_size:
            image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
    except Exception as e:
        raise Exception(f"Error resizing image: {e}") from e

    # Flatten to one row per pixel: (N, 4) RGBA.
    data = np.asarray(image).reshape((-1, 4))
    # Keep mostly-opaque pixels (alpha >= 128) and drop the alpha channel.
    rgb_data = data[data[:, 3] >= 128][:, :3]
    if len(rgb_data) == 0:
        raise ValueError("Image has no opaque pixels (alpha >= 128) to extract colors from")

    # KMeans needs at least as many distinct samples as clusters; asking for
    # more either errors or yields duplicate centers, so clamp.
    n_clusters = min(num_colors, len(np.unique(rgb_data, axis=0)))
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    kmeans.fit(rgb_data)

    # Round (rather than truncate) the float centers to integer channel values.
    palette = np.rint(kmeans.cluster_centers_).astype(int)
    # minlength guarantees one count per center even if a cluster is empty.
    color_counts = np.bincount(kmeans.labels_, minlength=n_clusters)
    total_count = int(color_counts.sum())

    # Drop minor clusters, then sort most frequent first (stable for ties).
    kept = [
        (int(r), int(g), int(b), int(count))
        for (r, g, b), count in zip(palette, color_counts)
        if count / total_count > min_fraction
    ]
    kept.sort(key=lambda item: item[3], reverse=True)
    return ColorPalette(colors=[ColorItem(R=r, G=g, B=b, count=count) for r, g, b, count in kept])
def create_color_palette_square_grid(color_data: ColorPalette, total_pixels: int = 160000):
    """
    Create a square image whose area is divided among the palette colors in
    proportion to their counts.

    Colors are written in palette order, filling the image row by row (left
    to right, top to bottom). Each color therefore occupies one contiguous
    run of pixels in reading order.

    Args:
        color_data: ColorPalette object whose ``colors`` each have R, G, B and count.
        total_pixels: Target number of pixels in the output image. The actual
            size is the largest square that fits the scaled counts.

    Returns:
        PIL.Image.Image: an RGB image of size side_length x side_length.

    Raises:
        ValueError: if the palette is empty or its counts sum to zero.
    """
    total_count = sum(color.count for color in color_data.colors)
    if total_count <= 0:
        raise ValueError("color_data must contain at least one color with a positive count")

    # Scale counts so their sum is (at most) total_pixels.
    scale_factor = total_pixels / total_count

    # One RGB tuple per output pixel, repeated according to each color's scaled share.
    pixel_list = []
    for color in color_data.colors:
        pixel_list.extend([(color.R, color.G, color.B)] * int(color.count * scale_factor))

    # Largest square that fits; extra pixels at the end are trimmed.
    side_length = math.isqrt(len(pixel_list))
    img = Image.new('RGB', (side_length, side_length))
    # putdata fills row-major (x varies fastest), matching the x = i % side,
    # y = i // side layout, without a per-pixel Python loop.
    img.putdata(pixel_list[:side_length * side_length])
    return img
def _load_fonts():
    """Return (title_font, color_font): Arial 16/12 if available, else Pillow's default, else (None, None)."""
    try:
        return ImageFont.truetype("arial.ttf", 16), ImageFont.truetype("arial.ttf", 12)
    except (OSError, ImportError):
        # arial.ttf not found (common on Linux) or Pillow built without FreeType.
        pass
    try:
        return ImageFont.load_default(), ImageFont.load_default()
    except Exception:
        # Best effort: without a font the labels are simply not drawn.
        return None, None


def _draw_centered_text(draw, text, left, width, y, font):
    """Draw ``text`` in black, horizontally centered in the span [left, left + width), at height ``y``."""
    bbox = draw.textbbox((0, 0), text, font=font)
    text_width = bbox[2] - bbox[0]
    draw.text((left + (width - text_width) // 2, y), text, fill='black', font=font)


def create_palette_image(palette: ColorPalette):
    """
    Render a palette as a grid of labeled color swatches on a white background.

    Swatches are laid out in palette order, at most 4 per row. Each swatch is
    labeled with its hex code and RGB values, and a "Color Palette" title is
    drawn at the top. Text is drawn only if a font could be loaded.

    Args:
        palette (ColorPalette): palette whose ``colors`` each have integer R, G, B fields.

    Returns:
        PIL.Image.Image: RGB image showing the palette. An empty palette gives
        a small image containing only the title.
    """
    # Layout dimensions (pixels)
    swatch_width = 120
    swatch_height = 100
    margin = 20
    text_height = 30  # room below each swatch for the hex and RGB labels
    title_height = 40

    n_colors = len(palette.colors)
    # At most 4 columns; max(1, ...) avoids division by zero for an empty palette.
    cols = max(1, min(4, n_colors))
    rows = (n_colors + cols - 1) // cols  # ceiling division
    img_width = cols * swatch_width + (cols + 1) * margin
    img_height = title_height + rows * (swatch_height + text_height) + (rows + 1) * margin

    palette_img = Image.new('RGB', (img_width, img_height), 'white')
    draw = ImageDraw.Draw(palette_img)

    title_font, color_font = _load_fonts()

    if title_font is not None:
        _draw_centered_text(draw, "Color Palette", 0, img_width, 10, title_font)

    for i, color in enumerate(palette.colors):
        row, col = divmod(i, cols)
        # Top-left corner of this swatch
        x = margin + col * (swatch_width + margin)
        y = title_height + margin + row * (swatch_height + text_height + margin)

        color_tuple = (color.R, color.G, color.B)
        draw.rectangle([x, y, x + swatch_width, y + swatch_height],
                       fill=color_tuple, outline='gray', width=1)

        if color_font is not None:
            # Hex code, then RGB values, centered under the swatch.
            text_y = y + swatch_height + 5
            hex_color = '#{:02x}{:02x}{:02x}'.format(color.R, color.G, color.B)
            _draw_centered_text(draw, hex_color, x, swatch_width, text_y, color_font)
            rgb_text = f"RGB({color.R}, {color.G}, {color.B})"
            _draw_centered_text(draw, rgb_text, x, swatch_width, text_y + 15, color_font)

    return palette_img
# Example usage:
if __name__ == "__main__":
    import os

    # Make sure the destination directory exists before saving.
    os.makedirs('output', exist_ok=True)

    # extract_color_palette calls .convert() on its input, so pass an opened
    # PIL image rather than a path string.
    with Image.open("images/playstation-logo.png") as source_image:
        colors = extract_color_palette(source_image, 6)

    # Create square grid version
    square_img = create_color_palette_square_grid(colors)
    square_img.save('output/playstation_square.png')
    square_img.show()

    # Create palette image
    palette_img = create_palette_image(colors)
    palette_img.save('output/playstation_palette.png')
    palette_img.show()