# BrandSmith / sandbox.py
# Uploaded by raznis via huggingface_hub (commit c9f5fa7, verified).
from PIL import Image, ImageDraw, ImageFont
from sklearn.cluster import KMeans
import numpy as np
import os
def extract_color_palette(image_path, num_colors=8, output_path=None):
    """
    Extract color palette from an image and create a new image showing the palette.

    The image is converted to RGB, downscaled so its longest side is at most
    500 px, and its pixels are clustered with K-means. The rounded cluster
    centres, ordered from most to least common, form the palette.

    Args:
        image_path (str): Path to the input image
        num_colors (int): Maximum number of colors to extract (default: 8).
            Fewer colors are returned when the image contains fewer distinct
            colors, so the palette never holds duplicate swatches.
        output_path (str): Path to save the output image (optional)
    Returns:
        PIL.Image: Image with white background showing the color palette
    Raises:
        ValueError: If ``num_colors`` is less than 1.
        Exception: If the image cannot be opened or decoded (the original
            error is chained as ``__cause__``).
    """
    if num_colors < 1:
        raise ValueError(f"num_colors must be at least 1, got {num_colors}")

    # Load the image. convert() reads the pixel data, so the file handle that
    # Image.open keeps open can be closed immediately by the context manager.
    try:
        with Image.open(image_path) as source:
            image = source.convert('RGB')
    except Exception as e:
        raise Exception(f"Error loading image: {e}") from e

    # Downscale large images; clustering cost grows with the pixel count.
    max_size = 500
    if max(image.size) > max_size:
        image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

    # One row per pixel, columns R, G, B.
    pixels = np.array(image).reshape((-1, 3))

    # KMeans requires n_clusters <= number of samples, and requesting more
    # clusters than there are distinct colors only yields duplicate swatches.
    n_clusters = min(num_colors, len(np.unique(pixels, axis=0)))
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    kmeans.fit(pixels)

    # Round (rather than truncate) the float centres to displayable 0-255 ints.
    palette = np.rint(kmeans.cluster_centers_).astype(int)

    # Most frequent color first. minlength keeps a cluster that ended up with
    # no pixels, and the stable sort makes tie ordering deterministic.
    counts = np.bincount(kmeans.labels_, minlength=n_clusters)
    palette = palette[np.argsort(-counts, kind='stable')]

    palette_image = create_palette_image(palette, image_path)

    if output_path:
        palette_image.save(output_path)
        print(f"Palette saved to: {output_path}")
    return palette_image
def _load_palette_fonts():
    """Return ``(title_font, color_font)``: Arial if available, else Pillow's default font, else ``(None, None)``."""
    try:
        return ImageFont.truetype("arial.ttf", 16), ImageFont.truetype("arial.ttf", 12)
    except (OSError, ImportError):
        # arial.ttf is usually absent outside Windows, or FreeType support may be missing.
        pass
    try:
        default_font = ImageFont.load_default()
        return default_font, default_font
    except (OSError, ImportError):
        return None, None


def _text_width(draw, text, font):
    """Return the rendered pixel width of *text* drawn with *font*."""
    left, _, right, _ = draw.textbbox((0, 0), text, font=font)
    return right - left


def create_palette_image(palette, original_image_path):
    """
    Create an image showing the color palette with color swatches and hex values.

    Swatches are laid out in a grid of at most 4 columns. Each swatch is
    labelled with its hex code and RGB triple, under a title naming the source
    file. The canvas is widened (and the grid centred) when the title would not
    otherwise fit. An empty palette yields a title-only image.

    Args:
        palette (numpy.ndarray): Array of RGB color values, one color per row;
            only the first three components of each row are used.
        original_image_path (str): Path to original image (for title)
    Returns:
        PIL.Image: Image with white background showing the palette
    """
    # Layout constants, in pixels.
    swatch_width = 100
    swatch_height = 80
    margin = 20
    text_height = 30  # space reserved under each swatch for its labels
    title_height = 40

    title_font, color_font = _load_palette_fonts()
    title = f"Color Palette - {os.path.basename(original_image_path)}"

    count = len(palette)
    # At least one column, so an empty palette doesn't divide by zero.
    cols = max(1, min(4, count))  # Max 4 columns
    rows = (count + cols - 1) // cols  # ceiling division
    grid_width = cols * swatch_width + (cols + 1) * margin
    img_height = title_height + rows * (swatch_height + text_height) + (rows + 1) * margin

    img_width = grid_width
    if title_font is not None:
        # Measure on a throwaway canvas so a long file name can widen the real one.
        scratch = ImageDraw.Draw(Image.new('RGB', (1, 1)))
        img_width = max(img_width, _text_width(scratch, title, title_font) + 2 * margin)
    grid_left = (img_width - grid_width) // 2  # 0 unless the title forced a wider canvas

    palette_img = Image.new('RGB', (img_width, img_height), 'white')
    draw = ImageDraw.Draw(palette_img)

    if title_font is not None:
        title_x = (img_width - _text_width(draw, title, title_font)) // 2
        draw.text((title_x, 10), title, fill='black', font=title_font)

    for i, color in enumerate(palette):
        row, col = divmod(i, cols)
        x = grid_left + margin + col * (swatch_width + margin)
        y = title_height + margin + row * (swatch_height + text_height + margin)

        # Use plain Python ints rather than numpy scalars for PIL fills and text.
        rgb = tuple(int(c) for c in color[:3])
        draw.rectangle([x, y, x + swatch_width, y + swatch_height],
                       fill=rgb, outline='gray', width=1)

        if color_font is None:
            continue

        # Hex value, then RGB triple, each centred under the swatch.
        text_y = y + swatch_height + 5
        labels = (
            (rgb_to_hex(rgb), text_y),
            (f"RGB({rgb[0]}, {rgb[1]}, {rgb[2]})", text_y + 15),
        )
        for label, label_y in labels:
            label_x = x + (swatch_width - _text_width(draw, label, color_font)) // 2
            draw.text((label_x, label_y), label, fill='black', font=color_font)

    return palette_img
def rgb_to_hex(rgb):
    """Convert an RGB triple to a lowercase ``#rrggbb`` hex string.

    Only the first three components are read, so RGBA tuples are accepted.
    Components may be ints, floats or numpy scalars; non-integers are rounded
    to the nearest integer.

    Args:
        rgb: Indexable of at least three numeric components in the range 0-255.
    Returns:
        str: e.g. ``'#ff0010'`` for ``(255, 0, 16)``.
    Raises:
        ValueError: If a component is outside 0-255 after rounding (previously
            such values silently produced malformed strings like ``'#100...'``).
    """
    channels = [int(round(rgb[i])) for i in range(3)]
    for value in channels:
        if not 0 <= value <= 255:
            raise ValueError(f"RGB component out of range 0-255: {value}")
    return '#{:02x}{:02x}{:02x}'.format(*channels)
# Example usage: run this file directly to build a palette image for
# input_image and save it as output_image.
if __name__ == "__main__":
    import sys

    input_image = "starbucks.png"  # Replace with your image path
    output_image = "color_palette-starbucks.png"

    try:
        # Extract palette and create visualization; call .show() on the
        # returned image to preview it where a display is available.
        extract_color_palette(
            image_path=input_image,
            num_colors=8,  # Adjust number of colors as needed
            output_path=output_image,
        )
        print("Color palette extraction completed successfully!")
    except ImportError as e:
        print(f"Error: {e}", file=sys.stderr)
        print("Make sure you have the required packages installed:", file=sys.stderr)
        print("pip install pillow scikit-learn numpy", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        # The packages imported at the top of the file already loaded, so this
        # is an input problem (bad path, unreadable image, bad num_colors).
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)