# Spaces: Sleeping
import argparse
import os
import sys

import numpy as np
from PIL import Image, ImageDraw, ImageFont
from sklearn.cluster import KMeans
def extract_color_palette(image_path, num_colors=8, output_path=None):
    """
    Extract the dominant colors of an image and render them as a palette image.

    The image is downscaled (longest side at most 500 px) and its pixels are
    clustered with K-means. The cluster centers, ordered from most to least
    frequent, are rendered with ``create_palette_image``.

    Args:
        image_path (str): Path to the input image.
        num_colors (int): Maximum number of colors to extract (default: 8).
            If the downscaled image has fewer distinct colors, only that many
            are returned, so the palette never contains duplicate swatches.
        output_path (str): Path to save the output image (optional).

    Returns:
        PIL.Image: Image with white background showing the color palette.

    Raises:
        ValueError: If ``num_colors`` is less than 1.
        Exception: If the image cannot be opened or decoded. The underlying
            error is chained as ``__cause__``.
    """
    if num_colors < 1:
        raise ValueError(f"num_colors must be at least 1, got {num_colors}")

    pixels = _load_rgb_pixels(image_path, max_size=500)
    palette = _dominant_colors(pixels, num_colors)

    palette_image = create_palette_image(palette, image_path)

    if output_path:
        palette_image.save(output_path)
        print(f"Palette saved to: {output_path}")
    return palette_image


def _load_rgb_pixels(image_path, max_size=500):
    """Load an image as RGB, shrink it to fit max_size, and return an (N, 3) pixel array."""
    try:
        # Image.open is lazy and keeps the file open; the context manager
        # guarantees the handle is released once the RGB copy is made.
        with Image.open(image_path) as img:
            # NOTE(review): convert('RGB') discards alpha, so fully transparent
            # pixels contribute whatever RGB values they store. Consider
            # compositing onto white first for logos with transparency.
            image = img.convert('RGB')
    except Exception as e:
        raise Exception(f"Error loading image: {e}") from e

    # Downscale for speed: K-means cost grows with the number of pixels.
    if max(image.size) > max_size:
        image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

    return np.array(image).reshape((-1, 3))


def _dominant_colors(pixels, num_colors):
    """Cluster pixels with K-means; return integer RGB centers, most frequent first."""
    # K-means cannot separate more clusters than there are distinct points.
    # Asking for more yields duplicate centers (and a ConvergenceWarning),
    # so cap the cluster count.
    n_clusters = min(num_colors, len(np.unique(pixels, axis=0)))

    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    kmeans.fit(pixels)

    # Round (not truncate) to the nearest channel value and stay within 0-255.
    centers = np.clip(np.rint(kmeans.cluster_centers_), 0, 255).astype(int)

    # minlength keeps the counts aligned with the centers even if a cluster
    # ends up empty. Without it the palette would be silently truncated.
    counts = np.bincount(kmeans.labels_, minlength=n_clusters)
    order = np.argsort(-counts, kind='stable')
    return centers[order]
def create_palette_image(palette, original_image_path):
    """
    Render a palette as a grid of color swatches labelled with hex and RGB values.

    Swatches are laid out in the given order, at most 4 per row, below a
    centered title naming the source image. The canvas is widened if needed
    so the title is never clipped.

    Args:
        palette (numpy.ndarray): Sequence of RGB triples with 0-255 channel values.
        original_image_path (str): Path to the original image. Only its base
            name is used, in the title.

    Returns:
        PIL.Image: RGB image with a white background showing the palette.

    Raises:
        ValueError: If ``palette`` is empty.
    """
    # Plain Python ints: Pillow fills and the text labels should not depend
    # on numpy scalar types.
    colors = [(int(c[0]), int(c[1]), int(c[2])) for c in palette]
    if not colors:
        raise ValueError("palette must contain at least one color")

    # Layout, in pixels
    swatch_width = 100
    swatch_height = 80
    margin = 20
    text_height = 30   # room below each swatch for the two label lines
    title_height = 40

    cols = min(4, len(colors))                # max 4 columns
    rows = (len(colors) + cols - 1) // cols   # ceiling division

    title_font, label_font = _load_palette_fonts()
    title = f"Color Palette - {os.path.basename(original_image_path)}"

    # With only one or two swatches the grid can be narrower than the title.
    # In that case, widen the canvas and center the grid instead of clipping.
    scratch = ImageDraw.Draw(Image.new('RGB', (1, 1)))
    left, _, right, _ = scratch.textbbox((0, 0), title, font=title_font)
    grid_width = cols * swatch_width + (cols + 1) * margin
    img_width = max(grid_width, (right - left) + 2 * margin)
    grid_left = (img_width - grid_width) // 2
    img_height = title_height + rows * (swatch_height + text_height) + (rows + 1) * margin

    palette_img = Image.new('RGB', (img_width, img_height), 'white')
    draw = ImageDraw.Draw(palette_img)

    _draw_centered_text(draw, title, img_width // 2, 10, title_font)

    for i, color in enumerate(colors):
        row, col = divmod(i, cols)
        x = grid_left + margin + col * (swatch_width + margin)
        y = title_height + margin + row * (swatch_height + text_height + margin)

        draw.rectangle([x, y, x + swatch_width, y + swatch_height],
                       fill=color, outline='gray', width=1)

        # Two label lines under the swatch: hex value, then the RGB triple.
        center_x = x + swatch_width // 2
        text_y = y + swatch_height + 5
        _draw_centered_text(draw, rgb_to_hex(color), center_x, text_y, label_font)
        _draw_centered_text(draw, f"RGB({color[0]}, {color[1]}, {color[2]})",
                            center_x, text_y + 15, label_font)

    return palette_img


def _load_palette_fonts():
    """Return (title_font, label_font): Arial if available, else Pillow's built-in font."""
    try:
        return ImageFont.truetype("arial.ttf", 16), ImageFont.truetype("arial.ttf", 12)
    except (OSError, ImportError):
        # arial.ttf is not present on every system (OSError), and truetype
        # support may be missing from the Pillow build (ImportError).
        default = ImageFont.load_default()
        return default, default


def _draw_centered_text(draw, text, center_x, y, font):
    """Draw black text horizontally centered on center_x, anchored at height y."""
    left, _, right, _ = draw.textbbox((0, 0), text, font=font)
    # Subtract the bbox midpoint (not just half the width) so glyph side
    # bearings do not shift the text off-center.
    draw.text((center_x - (left + right) // 2, y), text, fill='black', font=font)
def rgb_to_hex(rgb):
    """
    Convert an RGB color to a lowercase ``#rrggbb`` hex string.

    Args:
        rgb: Sequence whose first three items are the red, green and blue
            channels. Ints and numpy integers are used as-is. Floats (e.g. raw
            K-means centers) are rounded to the nearest integer, ties to even.
            Any further items, such as alpha, are ignored.

    Returns:
        str: Hex color string, e.g. ``'#ff0080'`` for ``(255, 0, 128)``.

    Raises:
        ValueError: If a channel lies outside 0-255 after rounding. Without
            this check the result would be a malformed string such as ``'#100…'``.
    """
    channels = []
    for value in (rgb[0], rgb[1], rgb[2]):
        channel = round(float(value))
        if not 0 <= channel <= 255:
            raise ValueError(f"RGB channel out of range 0-255: {value!r}")
        channels.append(channel)
    return '#{:02x}{:02x}{:02x}'.format(*channels)
# Example usage / command-line entry point
if __name__ == "__main__":
    # Defaults reproduce the original hard-coded example. Override them on the
    # command line, e.g.:  python <this script> photo.jpg out.png --colors 6
    parser = argparse.ArgumentParser(
        description="Extract a color palette from an image and save it as a swatch image.")
    parser.add_argument("input_image", nargs="?", default="starbucks.png",
                        help="input image path (default: %(default)s)")
    parser.add_argument("output_image", nargs="?", default="color_palette-starbucks.png",
                        help="where to save the palette image (default: %(default)s)")
    parser.add_argument("-n", "--colors", type=int, default=8,
                        help="number of colors to extract (default: %(default)s)")
    parser.add_argument("--show", action="store_true",
                        help="also open the palette image in the default viewer")
    args = parser.parse_args()

    try:
        # Extract the palette and create the visualization.
        palette_img = extract_color_palette(
            image_path=args.input_image,
            num_colors=args.colors,
            output_path=args.output_image,
        )
    except Exception as e:
        # Report on stderr and exit non-zero so shells and CI see the failure.
        # (Missing packages fail at the top-level imports, before this point,
        # so no install hint is printed here.)
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)

    if args.show:
        palette_img.show()
    print("Color palette extraction completed successfully!")