File size: 7,347 Bytes
c9f5fa7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
from sklearn.cluster import KMeans
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from models import ColorPalette, ColorItem
from agents import function_tool
import math

###
# Alternative worth trying: the colorthief library can also extract a color
# palette from an image, and it is much faster than the k-means clustering
# used below.
###

#@function_tool
def extract_color_palette(input_image: Image.Image, num_colors: int) -> ColorPalette:
    """
    Extract color palette from an image using K-means clustering.

    The image is converted to RGBA, downscaled so its longest side is at most
    500 pixels, and only pixels with alpha >= 128 are clustered. Clusters that
    cover 2.5 percent or less of the clustered pixels are dropped, and the
    remaining colors are returned most frequent first.

    Args:
        input_image (PIL.Image.Image): Input image. It is not modified; all
            work happens on a converted copy.
        num_colors (int): Maximum number of colors to extract. The number of
            clusters is capped at the number of distinct opaque colors in the
            (downscaled) image, so fewer colors may be returned.

    Returns:
        ColorPalette: object representing the color palette with the colors and their counts

    Raises:
        ValueError: If num_colors is less than 1, or if the image has no
            pixel with alpha >= 128.
        Exception: If the image cannot be converted or resized.
    """
    if num_colors < 1:
        raise ValueError(f"num_colors must be at least 1, got {num_colors}")

    # Load and process the image
    try:
        # Convert to RGBA to handle transparency
        image = input_image.convert('RGBA')
    except Exception as e:
        raise Exception(f"Error loading image: {e}") from e

    try:
        # Resize image for faster processing
        max_size = 500
        if max(image.size) > max_size:
            image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
    except Exception as e:
        raise Exception(f"Error resizing image: {e}") from e

    # Flatten to one row per pixel; RGBA has 4 channels
    pixels = np.array(image).reshape((-1, 4))

    # Keep the RGB channels of non-transparent pixels only (alpha >= 128)
    opaque_rgb = pixels[pixels[:, 3] >= 128, :3]
    if opaque_rgb.shape[0] == 0:
        raise ValueError("Image has no non-transparent pixels to extract colors from")

    # K-means cannot form more distinct clusters than there are distinct
    # colors, so cap the cluster count instead of failing or warning.
    distinct_colors = np.unique(opaque_rgb, axis=0).shape[0]
    n_clusters = min(num_colors, distinct_colors)

    # Use K-means clustering to find dominant colors; the fixed random_state
    # keeps the result reproducible for the same input.
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    kmeans.fit(opaque_rgb)

    # Cluster centers are floats; truncate them to integer channel values
    centers = kmeans.cluster_centers_.astype(int)

    # minlength guarantees one count per cluster even if trailing clusters
    # received no labels, so counts always lines up with centers.
    counts = np.bincount(kmeans.labels_, minlength=n_clusters)
    total_count = int(counts.sum())

    # Drop colors that cover 2.5 percent or less of the clustered pixels
    min_share = 0.025
    kept = [
        (int(r), int(g), int(b), int(count))
        for (r, g, b), count in zip(centers, counts)
        if count / total_count > min_share
    ]

    # Most frequent color first
    kept.sort(key=lambda entry: entry[3], reverse=True)

    return ColorPalette(colors=[ColorItem(R=r, G=g, B=b, count=count) for r, g, b, count in kept])




def create_color_palette_square_grid(color_data: ColorPalette, total_pixels: int = 160000):
    """
    Creates an image of a square grid where each color is represented
    according to its frequency in the color_data input.

    Pixels are laid out row by row from the top-left corner, in the order the
    colors appear in color_data. Each color's share is truncated to a whole
    number of pixels and the total is then trimmed to a perfect square, so the
    image can contain slightly fewer than total_pixels pixels.

    Args:
        color_data: ColorPalette object containing the color palette
        total_pixels: Target number of pixels in the output image

    Returns:
        PIL Image object

    Raises:
        ValueError: If the counts in color_data do not add up to a positive
            number (for example, when the palette is empty).
    """
    # Calculate total count
    total_count = sum(color.count for color in color_data.colors)
    if total_count <= 0:
        raise ValueError("color_data must contain at least one color with a positive count")

    # Calculate scaling factor to fit target pixel count
    scale_factor = total_pixels / total_count

    # Create pixel list based on frequencies
    pixel_list = []
    for color in color_data.colors:
        scaled_count = int(color.count * scale_factor)
        pixel_list.extend([(color.R, color.G, color.B)] * scaled_count)

    # Calculate square dimensions
    side_length = int(math.sqrt(len(pixel_list)))

    # Create image and fill it in one call; putdata writes the sequence
    # row by row starting at the top-left corner. The list is trimmed so it
    # holds exactly side_length * side_length pixels.
    img = Image.new('RGB', (side_length, side_length))
    img.putdata(pixel_list[:side_length * side_length])

    return img


def create_palette_image(palette: ColorPalette):
    """
    Create an image showing the color palette with color swatches, hex values
    and RGB values.

    Swatches are laid out in a grid of at most 4 columns, in the order the
    colors appear in the palette, under a centered "Color Palette" title. An
    empty palette produces an image containing only the title.

    Args:
        palette (ColorPalette): Palette whose colors provide R, G and B values

    Returns:
        PIL.Image: Image with white background showing the palette
    """

    # Image dimensions
    swatch_width = 120
    swatch_height = 100
    margin = 20
    text_height = 30
    title_height = 40

    color_count = len(palette.colors)
    # Max 4 columns; at least 1 so an empty palette does not divide by zero
    cols = max(1, min(4, color_count))
    rows = (color_count + cols - 1) // cols  # Calculate rows needed (ceiling division)

    img_width = cols * swatch_width + (cols + 1) * margin
    img_height = title_height + rows * (swatch_height + text_height) + (rows + 1) * margin

    # Create white background image
    palette_img = Image.new('RGB', (img_width, img_height), 'white')
    draw = ImageDraw.Draw(palette_img)

    # Try to load a font, fall back to default if not available
    try:
        title_font = ImageFont.truetype("arial.ttf", 16)
        color_font = ImageFont.truetype("arial.ttf", 12)
    except (OSError, ImportError):
        # OSError: font file not found/readable; ImportError: no FreeType support
        try:
            title_font = ImageFont.load_default()
            color_font = ImageFont.load_default()
        except Exception:
            # Best effort: without any font the labels are simply skipped
            title_font = None
            color_font = None

    def draw_centered(text, left, span, top, font):
        """Draw text in black, horizontally centered within [left, left + span]."""
        if font is None:
            return
        bbox = draw.textbbox((0, 0), text, font=font)
        text_width = bbox[2] - bbox[0]
        draw.text((left + (span - text_width) // 2, top), text, fill='black', font=font)

    # Draw title centered across the whole image
    draw_centered("Color Palette", 0, img_width, 10, title_font)

    # Draw color swatches
    for i, color in enumerate(palette.colors):
        row, col = divmod(i, cols)

        # Calculate position of the swatch's top-left corner
        x = margin + col * (swatch_width + margin)
        y = title_height + margin + row * (swatch_height + text_height + margin)

        # Draw color swatch
        color_tuple = (color.R, color.G, color.B)
        draw.rectangle([x, y, x + swatch_width, y + swatch_height],
                       fill=color_tuple, outline='gray', width=1)

        # Draw hex value below swatch, then the RGB values on the next line
        text_y = y + swatch_height + 5
        hex_color = '#{:02x}{:02x}{:02x}'.format(color.R, color.G, color.B)
        draw_centered(hex_color, x, swatch_width, text_y, color_font)

        rgb_text = f"RGB({color.R}, {color.G}, {color.B})"
        draw_centered(rgb_text, x, swatch_width, text_y + 15, color_font)

    return palette_img




# Example usage:
if __name__ == "__main__":
    from pathlib import Path

    # extract_color_palette expects an opened PIL image, not a file path
    with Image.open("images/playstation-logo.png") as source_image:
        colors = extract_color_palette(source_image, 6)

    # Make sure the output directory exists before saving into it
    output_dir = Path("output")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Create square grid version
    square_img = create_color_palette_square_grid(colors)
    square_img.save(output_dir / 'playstation_square.png')
    square_img.show()

    # Create palette image
    palette_img = create_palette_image(colors)
    palette_img.save(output_dir / 'playstation_palette.png')
    palette_img.show()