Spaces:
Sleeping
Sleeping
| from sklearn.cluster import KMeans | |
| import numpy as np | |
| from PIL import Image, ImageDraw, ImageFont | |
| from models import ColorPalette, ColorItem | |
| from agents import function_tool | |
| import math | |
| ### | |
| # another option to try is to use colorthief to extract the color palette | |
| # colorthief is a library that can extract the color palette from an image | |
| # it is much faster than k-means clustering | |
| ### | |
| #@function_tool | |
def extract_color_palette(input_image: Image.Image, num_colors: int) -> ColorPalette:
    """
    Extract a color palette from an image using K-means clustering.

    The image is converted to RGBA, downscaled so its longest side is at most
    500 pixels, and pixels with alpha below 128 are ignored. The remaining RGB
    values are clustered; each cluster center (truncated to integers) becomes
    a palette color, and the number of pixels assigned to it becomes its count.
    Colors covering 2.5% or less of the clustered pixels are dropped.

    Args:
        input_image (PIL.Image.Image): Input image
        num_colors (int): Maximum number of colors to extract

    Returns:
        ColorPalette: palette colors with their pixel counts, sorted by count
        in descending order

    Raises:
        ValueError: if num_colors is less than 1, or the image has no pixel
            with alpha >= 128
        Exception: if the image cannot be converted to RGBA or resized
    """
    if num_colors < 1:
        raise ValueError(f"num_colors must be at least 1, got {num_colors}")

    try:
        # Convert to RGBA to handle transparency (this also yields a copy, so
        # the in-place thumbnail() below never touches the caller's image)
        image = input_image.convert('RGBA')
    except Exception as e:
        raise Exception(f"Error loading image: {e}") from e

    try:
        # Resize image for faster processing
        max_size = 500
        if max(image.size) > max_size:
            image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
    except Exception as e:
        raise Exception(f"Error resizing image: {e}") from e

    # Flatten to one RGBA row per pixel
    data = np.array(image).reshape((-1, 4))

    # Keep only the RGB channels of sufficiently opaque pixels (alpha >= 128)
    rgb_data = data[data[:, 3] >= 128][:, :3]
    if len(rgb_data) == 0:
        raise ValueError("Image has no opaque pixels to extract colors from")

    # K-means cannot fit more clusters than there are samples
    n_clusters = min(num_colors, len(rgb_data))
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    kmeans.fit(rgb_data)

    palette = kmeans.cluster_centers_.astype(int)

    # minlength guarantees one count per cluster even if trailing clusters
    # ended up empty (bincount would otherwise return a shorter array)
    color_counts = np.bincount(kmeans.labels_, minlength=len(palette))
    total_count = int(color_counts.sum())

    # Drop colors that cover 2.5% or less of the clustered pixels
    min_fraction = 0.025
    kept = [
        (int(r), int(g), int(b), int(count))
        for (r, g, b), count in zip(palette, color_counts)
        if count / total_count > min_fraction
    ]

    # Most frequent color first
    kept.sort(key=lambda item: item[3], reverse=True)

    return ColorPalette(
        colors=[ColorItem(R=r, G=g, B=b, count=count) for r, g, b, count in kept]
    )
def create_color_palette_square_grid(color_data: ColorPalette, total_pixels: int = 160000):
    """
    Creates an image of a square grid where each color is represented
    according to its frequency in the color_data input.

    Colors are laid out in palette order, filling the grid row by row from
    the top-left corner. Because per-color pixel counts are truncated to
    integers and the result is trimmed to a perfect square, the output may
    contain slightly fewer than total_pixels pixels.

    Args:
        color_data: ColorPalette object containing the color palette
        total_pixels: Target number of pixels in the output image

    Returns:
        PIL Image object

    Raises:
        ValueError: if the palette's counts do not sum to a positive number
    """
    # Calculate total count
    total_count = sum(color.count for color in color_data.colors)
    if total_count <= 0:
        raise ValueError(
            "color_data must contain at least one color with a positive count"
        )

    # Calculate scaling factor to fit target pixel count
    scale_factor = total_pixels / total_count

    # Create pixel list based on frequencies
    pixel_list = []
    for color in color_data.colors:
        scaled_count = int(color.count * scale_factor)
        pixel_list.extend([(color.R, color.G, color.B)] * scaled_count)

    # Largest square that fits in the available pixels (exact integer sqrt)
    side_length = math.isqrt(len(pixel_list))

    img = Image.new('RGB', (side_length, side_length))
    if side_length:
        # putdata fills row by row from the upper-left corner, which is the
        # same order as x = i % side_length, y = i // side_length
        img.putdata(pixel_list[:side_length * side_length])
    return img
def _load_palette_fonts():
    """Return (title_font, color_font), or (None, None) if no font can be loaded."""
    try:
        return ImageFont.truetype("arial.ttf", 16), ImageFont.truetype("arial.ttf", 12)
    except (OSError, ImportError):
        # arial.ttf is not available (or FreeType support is missing);
        # fall back to Pillow's built-in default font
        pass
    try:
        return ImageFont.load_default(), ImageFont.load_default()
    except (OSError, ImportError):
        return None, None


def _draw_centered_text(draw, text, left, width, top, font):
    """Draw black text horizontally centered in the span starting at left with the given width."""
    bbox = draw.textbbox((0, 0), text, font=font)
    text_width = bbox[2] - bbox[0]
    draw.text((left + (width - text_width) // 2, top), text, fill='black', font=font)


def create_palette_image(palette: ColorPalette):
    """
    Create an image showing the color palette with color swatches and hex values.

    Swatches are arranged in a grid of at most 4 columns under a
    "Color Palette" title; each swatch is labelled with its hex code and its
    RGB triple. An empty palette yields an image containing only the title.
    If no font can be loaded, the title and labels are omitted.

    Args:
        palette (ColorPalette): Palette whose colors (R, G, B) are drawn

    Returns:
        PIL.Image: Image with white background showing the palette
    """
    # Image dimensions
    swatch_width = 120
    swatch_height = 100
    margin = 20
    text_height = 30
    title_height = 40

    color_count = len(palette.colors)
    # Max 4 columns; at least 1 so an empty palette does not divide by zero
    cols = max(1, min(4, color_count))
    rows = (color_count + cols - 1) // cols  # Ceiling division: rows needed
    img_width = cols * swatch_width + (cols + 1) * margin
    img_height = title_height + rows * (swatch_height + text_height) + (rows + 1) * margin

    # Create white background image
    palette_img = Image.new('RGB', (img_width, img_height), 'white')
    draw = ImageDraw.Draw(palette_img)

    title_font, color_font = _load_palette_fonts()

    # Draw title centered across the full image width
    if title_font:
        _draw_centered_text(draw, "Color Palette", 0, img_width, 10, title_font)

    # Draw color swatches
    for i, color in enumerate(palette.colors):
        row, col = divmod(i, cols)

        # Top-left corner of this swatch
        x = margin + col * (swatch_width + margin)
        y = title_height + margin + row * (swatch_height + text_height + margin)

        draw.rectangle([x, y, x + swatch_width, y + swatch_height],
                       fill=(color.R, color.G, color.B), outline='gray', width=1)

        if color_font:
            text_y = y + swatch_height + 5
            # Hex value directly below the swatch, RGB triple on the next line
            hex_color = '#{:02x}{:02x}{:02x}'.format(color.R, color.G, color.B)
            _draw_centered_text(draw, hex_color, x, swatch_width, text_y, color_font)
            rgb_text = f"RGB({color.R}, {color.G}, {color.B})"
            _draw_centered_text(draw, rgb_text, x, swatch_width, text_y + 15, color_font)

    return palette_img
# Example usage:
if __name__ == "__main__":
    import os

    # extract_color_palette expects a PIL image, not a file path
    with Image.open("images/playstation-logo.png") as source_image:
        colors = extract_color_palette(source_image, 6)

    # Make sure the destination directory exists before saving
    os.makedirs('output', exist_ok=True)

    # Create square grid version
    square_img = create_color_palette_square_grid(colors)
    square_img.save('output/playstation_square.png')
    square_img.show()

    # Create palette image
    palette_img = create_palette_image(colors)
    palette_img.save('output/playstation_palette.png')
    palette_img.show()