import math
from pathlib import Path

import numpy as np
from PIL import Image, ImageDraw, ImageFont
from sklearn.cluster import KMeans

from agents import function_tool
from models import ColorPalette, ColorItem

###
# Another option to try is to use colorthief to extract the color palette.
# colorthief is a library that can extract the color palette from an image;
# it is much faster than k-means clustering.
###

# Longest image side (in pixels) that is clustered; larger images are
# downscaled first to keep k-means fast.
MAX_CLUSTER_IMAGE_SIZE = 500

# Pixels whose alpha channel is below this value are treated as transparent
# and ignored when clustering.
ALPHA_THRESHOLD = 128

# Clusters covering this fraction of the opaque pixels or less are dropped
# from the resulting palette (2.5 percent).
MIN_COLOR_FRACTION = 0.025


#@function_tool
def extract_color_palette(input_image: Image.Image, num_colors: int) -> ColorPalette:
    """
    Extract a color palette from an image using K-means clustering.

    The image is converted to RGBA, downscaled so that its longest side is at
    most ``MAX_CLUSTER_IMAGE_SIZE`` pixels, and its non-transparent pixels are
    clustered in RGB space. Clusters that cover ``MIN_COLOR_FRACTION`` of the
    clustered pixels or less are discarded, so the palette may contain fewer
    than ``num_colors`` entries.

    Args:
        input_image (PIL.Image.Image): Input image
        num_colors (int): Number of clusters (candidate colors) to extract

    Returns:
        ColorPalette: the palette colors with their pixel counts, sorted from
        most to least frequent

    Raises:
        ValueError: if ``num_colors`` is less than 1 or the image contains no
            pixels with alpha >= ``ALPHA_THRESHOLD``
        Exception: if the image cannot be converted to RGBA or resized
    """
    if num_colors < 1:
        raise ValueError(f"num_colors must be at least 1, got {num_colors}")

    try:
        # Convert to RGBA so transparency can be inspected per pixel
        image = input_image.convert('RGBA')
    except Exception as e:
        raise Exception(f"Error loading image: {e}") from e

    try:
        # Downscale large images for faster clustering. convert() returned a
        # new image, so the in-place thumbnail() does not touch the caller's.
        if max(image.size) > MAX_CLUSTER_IMAGE_SIZE:
            image.thumbnail(
                (MAX_CLUSTER_IMAGE_SIZE, MAX_CLUSTER_IMAGE_SIZE),
                Image.Resampling.LANCZOS,
            )
    except Exception as e:
        raise Exception(f"Error resizing image: {e}") from e

    # Flatten to one row per pixel; RGBA has 4 channels
    data = np.array(image).reshape((-1, 4))

    # Keep only sufficiently opaque pixels and drop the alpha channel
    opaque = data[data[:, 3] >= ALPHA_THRESHOLD]
    rgb_data = opaque[:, :3]

    if len(rgb_data) == 0:
        raise ValueError("Image has no non-transparent pixels to extract colors from")

    # KMeans rejects n_clusters larger than the number of samples, so clamp
    # the cluster count for very small (or mostly transparent) images.
    n_clusters = min(num_colors, len(rgb_data))

    # Use K-means clustering to find dominant colors
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    kmeans.fit(rgb_data)

    # Cluster centers are the palette colors
    palette = kmeans.cluster_centers_.astype(int)

    # minlength guarantees one count per cluster even if trailing clusters
    # ended up with no assigned pixels.
    color_counts = np.bincount(kmeans.labels_, minlength=len(palette))
    total_count = color_counts.sum()

    # Drop rare colors, then order the rest by frequency (most common first)
    palette_with_counts = [
        (int(r), int(g), int(b), int(count))
        for (r, g, b), count in zip(palette, color_counts)
        if count / total_count > MIN_COLOR_FRACTION
    ]
    palette_with_counts.sort(key=lambda item: item[3], reverse=True)

    return ColorPalette(
        colors=[
            ColorItem(R=r, G=g, B=b, count=count)
            for r, g, b, count in palette_with_counts
        ]
    )


def create_color_palette_square_grid(
    color_data: ColorPalette, total_pixels: int = 160000
) -> Image.Image:
    """
    Create a square image in which each palette color fills a share of the
    pixels proportional to its count in ``color_data``.

    Pixels are laid out row by row in the order of ``color_data.colors``.
    Because per-color pixel counts are truncated to integers and the result is
    trimmed to a perfect square, the image may contain slightly fewer than
    ``total_pixels`` pixels.

    Args:
        color_data: ColorPalette object containing the color palette
        total_pixels: Target number of pixels in the output image

    Returns:
        PIL Image object in RGB mode

    Raises:
        ValueError: if the counts in ``color_data`` do not sum to a positive
            number (for example, an empty palette)
    """
    total_count = sum(color.count for color in color_data.colors)
    if total_count <= 0:
        raise ValueError(
            "color_data must contain at least one color with a positive count"
        )

    # Scaling factor that maps raw counts onto the target pixel budget
    scale_factor = total_pixels / total_count

    # One tuple per output pixel, grouped by color
    pixel_list = []
    for color in color_data.colors:
        scaled_count = int(color.count * scale_factor)
        pixel_list.extend([(color.R, color.G, color.B)] * scaled_count)

    # Trim the pixel list so it fills a perfect square
    side_length = int(math.sqrt(len(pixel_list)))
    pixel_list = pixel_list[:side_length * side_length]

    img = Image.new('RGB', (side_length, side_length))
    if pixel_list:
        # putdata fills row by row from the top-left corner, i.e. pixel i
        # lands at (i % side_length, i // side_length).
        img.putdata(pixel_list)

    return img


def _load_palette_fonts():
    """Return (title_font, color_font), or (None, None) if no font loads."""
    try:
        return ImageFont.truetype("arial.ttf", 16), ImageFont.truetype("arial.ttf", 12)
    except (OSError, ImportError):
        # arial.ttf is not available (or FreeType support is missing);
        # fall through to Pillow's built-in font.
        pass

    try:
        default_font = ImageFont.load_default()
    except Exception:
        # Best effort: callers skip text rendering when no font is available
        return None, None
    return default_font, default_font


def _draw_centered_text(draw, text, font, left, width, top):
    """Draw black text horizontally centered in [left, left + width) at y=top."""
    if font is None:
        return
    bbox = draw.textbbox((0, 0), text, font=font)
    text_width = bbox[2] - bbox[0]
    draw.text((left + (width - text_width) // 2, top), text, fill='black', font=font)


def create_palette_image(palette: ColorPalette) -> Image.Image:
    """
    Create an image showing the color palette as labelled color swatches.

    Swatches are arranged in a grid of at most 4 columns on a white
    background under a "Color Palette" title. Each swatch is labelled with
    its hex code and its RGB values. An empty palette yields an image that
    contains only the title.

    Args:
        palette (ColorPalette): Palette whose colors are drawn, in order

    Returns:
        PIL.Image: Image with white background showing the palette
    """
    # Image dimensions
    swatch_width = 120
    swatch_height = 100
    margin = 20
    text_height = 30
    title_height = 40

    num_colors = len(palette.colors)
    # Max 4 columns; at least 1 so an empty palette does not divide by zero
    cols = max(1, min(4, num_colors))
    rows = (num_colors + cols - 1) // cols  # Ceiling division

    img_width = cols * swatch_width + (cols + 1) * margin
    img_height = title_height + rows * (swatch_height + text_height) + (rows + 1) * margin

    # Create white background image
    palette_img = Image.new('RGB', (img_width, img_height), 'white')
    draw = ImageDraw.Draw(palette_img)

    title_font, color_font = _load_palette_fonts()

    # Draw title centered across the full image width
    _draw_centered_text(draw, "Color Palette", title_font, 0, img_width, 10)

    # Draw color swatches
    for i, color in enumerate(palette.colors):
        row, col = divmod(i, cols)

        # Top-left corner of this swatch
        x = margin + col * (swatch_width + margin)
        y = title_height + margin + row * (swatch_height + text_height + margin)

        draw.rectangle(
            [x, y, x + swatch_width, y + swatch_height],
            fill=(color.R, color.G, color.B),
            outline='gray',
            width=1,
        )

        # Hex value directly below the swatch, RGB values on the next line
        text_y = y + swatch_height + 5
        hex_color = '#{:02x}{:02x}{:02x}'.format(color.R, color.G, color.B)
        _draw_centered_text(draw, hex_color, color_font, x, swatch_width, text_y)

        rgb_text = f"RGB({color.R}, {color.G}, {color.B})"
        _draw_centered_text(draw, rgb_text, color_font, x, swatch_width, text_y + 15)

    return palette_img


# Example usage:
if __name__ == "__main__":
    output_dir = Path("output")
    output_dir.mkdir(parents=True, exist_ok=True)

    # extract_color_palette expects an opened PIL image, not a file path
    with Image.open("images/playstation-logo.png") as source_image:
        colors = extract_color_palette(source_image, 6)

    # Create square grid version
    square_img = create_color_palette_square_grid(colors)
    square_img.save(output_dir / "playstation_square.png")
    square_img.show()

    # Create palette image
    palette_img = create_palette_image(colors)
    palette_img.save(output_dir / "playstation_palette.png")
    palette_img.show()