official.ghost.logic committed on
Commit
3aa90e6
·
0 Parent(s):

Initial commit: Add project structure and core modules

Browse files

- Add color palette extraction with K-means clustering
- Add color theory harmonies (complementary, triadic, analogous, etc.)
- Add palette swapping with nearest neighbor matching
- Add palette visualization tools
- Set up project structure with README and requirements

Files changed (5) hide show
  1. .gitignore +51 -0
  2. README.md +35 -0
  3. color_palette.py +329 -0
  4. image_generator.py +100 -0
  5. requirements.txt +9 -0
.gitignore ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ dist/
9
+ *.egg-info/
10
+ .eggs/
11
+
12
+ # Virtual Environment
13
+ venv/
14
+ env/
15
+ ENV/
16
+ .venv
17
+
18
+ # Environment variables
19
+ .env
20
+ .env.local
21
+
22
+ # IDE
23
+ .vscode/
24
+ .idea/
25
+ *.swp
26
+ *.swo
27
+ *~
28
+
29
+ # OS
30
+ .DS_Store
31
+ .DS_Store?
32
+ ._*
33
+ Thumbs.db
34
+
35
+ # Gradio
36
+ gradio_cached_examples/
37
+ flagged/
38
+
39
+ # Generated images
40
+ generated_images/
41
+ output/
42
+ temp/
43
+
44
+ # Model cache
45
+ transformers_cache/
46
+ huggingface_cache/
47
+ diffusers_cache/
48
+
49
+ # Logs
50
+ *.log
51
+ logs/
README.md ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AI Image Editor
2
+
3
+ An AI-powered image editor with color palette extraction and swapping capabilities.
4
+
5
+ ## Features
6
+
7
+ - **AI Image Generation**: Create images from text prompts using Stable Diffusion
8
+ - **Color Palette Extraction**: Automatically extract dominant colors from images
9
+ - **Palette Swapping**: Replace colors in images with custom palettes
10
+ - **Color Theory Tools**: Apply color harmony rules (complementary, analogous, triadic)
11
+ - **Interactive UI**: Easy-to-use Gradio interface
12
+
13
+ ## Installation
14
+
15
+ ```bash
16
+ pip install -r requirements.txt
17
+ ```
18
+
19
+ ## Usage
20
+
21
+ ```bash
22
+ python app.py
23
+ ```
24
+
25
+ ## Tech Stack
26
+
27
+ - **Gradio**: Web interface
28
+ - **Diffusers**: Stable Diffusion image generation
29
+ - **Pillow**: Image processing
30
+ - **scikit-learn**: K-means clustering for palette extraction
31
+ - **NumPy**: Numerical operations
32
+
33
+ ## License
34
+
35
+ MIT
color_palette.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Color Palette Extraction and Manipulation Module
3
+ Implements patterns from Pylette, colorgram, and paletteswapper
4
+ """
5
+
6
+ import numpy as np
7
+ from PIL import Image
8
+ from sklearn.cluster import KMeans
9
+ from typing import List, Tuple
10
+ import colorsys
11
+
12
+
13
class ColorPalette:
    """
    Handle color palette extraction and manipulation.

    Based on research from top GitHub repos:
    - Pylette: KMeans extraction with metadata
    - colorgram.py: fast extraction with proportions
    - paletteswapper: clean functional palette replacement

    All palettes are lists of (r, g, b) tuples with plain Python ints in 0-255.
    """

    @staticmethod
    def extract_palette(
        image: Image.Image,
        n_colors: int = 5,
        sort_by: str = 'frequency'  # 'frequency' or 'luminance'
    ) -> List[Tuple[int, int, int]]:
        """
        Extract dominant colors using K-means clustering.

        Args:
            image: PIL Image
            n_colors: Number of colors to extract
            sort_by: How to sort results ('frequency' or 'luminance')

        Returns:
            List of RGB tuples (plain Python ints) sorted by the criterion.
        """
        # Downscale first (pattern from colorgram.py): palette quality is
        # barely affected, clustering cost drops dramatically.
        img = image.copy()
        img.thumbnail((200, 200))
        img = img.convert('RGB')

        # Flatten image to an (n_pixels, 3) array.
        pixels = np.array(img).reshape(-1, 3)

        # K-means clustering (industry standard from scikit-learn).
        kmeans = KMeans(
            n_clusters=n_colors,
            random_state=42,  # deterministic palette for the same image
            n_init=10,
            max_iter=300
        )
        kmeans.fit(pixels)

        # Cluster centers are the dominant colors.
        colors = kmeans.cluster_centers_.astype(int)
        labels = kmeans.labels_
        # minlength guards against a (rare) empty cluster: plain bincount
        # would then return fewer than n_colors counts and desync from
        # the `colors` array when indexing with sorted_indices.
        counts = np.bincount(labels, minlength=n_colors)

        # Sort based on criterion.
        if sort_by == 'luminance':
            # Rec. 601 luma approximation, brightest first.
            luminances = 0.299 * colors[:, 0] + 0.587 * colors[:, 1] + 0.114 * colors[:, 2]
            sorted_indices = np.argsort(-luminances)
        else:  # frequency
            sorted_indices = np.argsort(-counts)

        # Convert numpy scalars to plain ints so results serialize cleanly
        # (JSON, Gradio components, hex formatting).
        return [tuple(int(c) for c in color) for color in colors[sorted_indices]]

    @staticmethod
    def swap_palette(
        image: Image.Image,
        source_palette: List[Tuple[int, int, int]],
        target_palette: List[Tuple[int, int, int]],
        mode: str = 'closest'  # 'closest' or 'threshold'
    ) -> Image.Image:
        """
        Replace colors in image from source to target palette
        (nearest-neighbor matching, pattern from paletteswapper).

        Args:
            image: PIL Image to modify
            source_palette: Original colors to replace
            target_palette: New colors to use (paired positionally)
            mode: 'closest' recolors every pixel by its nearest source color;
                  'threshold' only recolors pixels very close to a source color

        Returns:
            Modified PIL Image (RGB).
        """
        img_array = np.array(image.convert('RGB')).astype(float)
        height, width, _ = img_array.shape
        pixels = img_array.reshape(-1, 3)
        result_pixels = pixels.copy()

        # Colors are paired by position; extras in the longer palette are ignored.
        min_len = min(len(source_palette), len(target_palette))

        if min_len and mode != 'threshold':
            # Hoisted out of the loop: the nearest source color per pixel is
            # the same for every iteration, so compute the distance matrix
            # once instead of once per source color (was O(k^2 * n)).
            all_distances = np.array([
                np.linalg.norm(pixels - np.array(src, dtype=float), axis=1)
                for src in source_palette[:min_len]
            ])
            closest_color_idx = np.argmin(all_distances, axis=0)

        for i in range(min_len):
            source_color = np.array(source_palette[i], dtype=float)
            target_color = np.array(target_palette[i], dtype=float)

            if mode == 'threshold':
                # Only replace pixels very close to this source color
                # (15th percentile of the distance distribution).
                distances = np.linalg.norm(pixels - source_color, axis=1)
                threshold = np.percentile(distances, 15)
                mask = distances <= threshold
            else:  # closest
                mask = closest_color_idx == i

            result_pixels[mask] = target_color

        # Reshape and convert back to image.
        result_array = result_pixels.reshape(height, width, 3).astype(np.uint8)
        return Image.fromarray(result_array)

    @staticmethod
    def rgb_to_hex(rgb: Tuple[int, int, int]) -> str:
        """Convert an RGB tuple to a '#rrggbb' hex string."""
        return '#{:02x}{:02x}{:02x}'.format(*rgb)

    @staticmethod
    def hex_to_rgb(hex_color: str) -> Tuple[int, int, int]:
        """Convert a hex string (with or without leading '#') to an RGB tuple."""
        hex_color = hex_color.lstrip('#')
        return tuple(int(hex_color[i:i + 2], 16) for i in (0, 2, 4))
142
+
143
+
144
class ColorTheory:
    """
    Color harmony generation using color theory
    (patterns from the colorharmonies library).

    All methods take an RGB base color (ints in 0-255) and return a list of
    RGB tuples. Rotations happen in HSV space, where hue is a fraction of a
    full turn (0.5 == 180 degrees).
    """

    @staticmethod
    def complementary(base_color: Tuple[int, int, int]) -> List[Tuple[int, int, int]]:
        """
        Generate a complementary palette (opposite on the color wheel):
        the base color plus its 180-degree hue rotation.
        """
        r, g, b = [c / 255.0 for c in base_color]
        h, s, v = colorsys.rgb_to_hsv(r, g, b)

        colors = [base_color]

        # Complementary: rotate hue by 180 degrees.
        comp_h = (h + 0.5) % 1.0
        comp_rgb = colorsys.hsv_to_rgb(comp_h, s, v)
        colors.append(tuple(int(c * 255) for c in comp_rgb))

        return colors

    @staticmethod
    def analogous(base_color: Tuple[int, int, int]) -> List[Tuple[int, int, int]]:
        """
        Generate an analogous palette (adjacent on the color wheel):
        -30 degrees, base, +30 degrees.

        NOTE(review): the base color goes through an HSV round-trip with int
        truncation, so the middle entry may drift from the input by 1 per channel.
        """
        r, g, b = [c / 255.0 for c in base_color]
        h, s, v = colorsys.rgb_to_hsv(r, g, b)

        colors = []

        # Three analogous colors: +/-30 degrees == +/-1/12 of the circle.
        for offset in [-1 / 12, 0, 1 / 12]:
            new_h = (h + offset) % 1.0
            new_rgb = colorsys.hsv_to_rgb(new_h, s, v)
            colors.append(tuple(int(c * 255) for c in new_rgb))

        return colors

    @staticmethod
    def triadic(base_color: Tuple[int, int, int]) -> List[Tuple[int, int, int]]:
        """
        Generate a triadic palette (evenly spaced on the color wheel):
        hues at 0, 120, and 240 degrees from the base.
        """
        r, g, b = [c / 255.0 for c in base_color]
        h, s, v = colorsys.rgb_to_hsv(r, g, b)

        colors = []

        # Three colors at 0, 120, 240 degrees.
        for offset in [0, 1 / 3, 2 / 3]:
            new_h = (h + offset) % 1.0
            new_rgb = colorsys.hsv_to_rgb(new_h, s, v)
            colors.append(tuple(int(c * 255) for c in new_rgb))

        return colors

    @staticmethod
    def split_complementary(base_color: Tuple[int, int, int]) -> List[Tuple[int, int, int]]:
        """
        Generate a split-complementary palette:
        base plus the two colors flanking its complement (+/-150 degrees).
        """
        r, g, b = [c / 255.0 for c in base_color]
        h, s, v = colorsys.rgb_to_hsv(r, g, b)

        colors = [base_color]

        # Two colors at 150 and 210 degrees == 5/12 and 7/12 of the circle.
        for offset in [5 / 12, 7 / 12]:
            new_h = (h + offset) % 1.0
            new_rgb = colorsys.hsv_to_rgb(new_h, s, v)
            colors.append(tuple(int(c * 255) for c in new_rgb))

        return colors

    @staticmethod
    def tetradic(base_color: Tuple[int, int, int]) -> List[Tuple[int, int, int]]:
        """
        Generate a tetradic (double-complementary) palette:
        two complementary pairs at 90-degree spacing.
        """
        r, g, b = [c / 255.0 for c in base_color]
        h, s, v = colorsys.rgb_to_hsv(r, g, b)

        colors = []

        # Four colors at 0, 90, 180, 270 degrees.
        for offset in [0, 0.25, 0.5, 0.75]:
            new_h = (h + offset) % 1.0
            new_rgb = colorsys.hsv_to_rgb(new_h, s, v)
            colors.append(tuple(int(c * 255) for c in new_rgb))

        return colors

    @staticmethod
    def monochromatic(base_color: Tuple[int, int, int], n_colors: int = 5) -> List[Tuple[int, int, int]]:
        """
        Generate a monochromatic palette: the base hue with varying
        saturation and brightness, ordered dark to light.

        Args:
            base_color: RGB tuple whose hue is kept fixed
            n_colors: number of variations to produce (>= 1)

        Returns:
            n_colors RGB tuples (empty list if n_colors <= 0).
        """
        r, g, b = [c / 255.0 for c in base_color]
        h, s, v = colorsys.rgb_to_hsv(r, g, b)

        colors = []

        for i in range(n_colors):
            # Vary value (brightness) from 20% up to 90%. The guard avoids a
            # ZeroDivisionError when only one color is requested.
            t = i / (n_colors - 1) if n_colors > 1 else 0.0
            new_v = 0.2 + t * 0.7
            # Slightly vary saturation for visual interest, clamped to [0.3, 1.0].
            new_s = max(0.3, min(1.0, s + (i / n_colors - 0.5) * 0.3))
            new_rgb = colorsys.hsv_to_rgb(h, new_s, new_v)
            colors.append(tuple(int(c * 255) for c in new_rgb))

        return colors
264
+
265
+
266
class PaletteVisualizer:
    """Create visual representations of color palettes."""

    @staticmethod
    def create_palette_image(
        palette: List[Tuple[int, int, int]],
        width: int = 500,
        height: int = 100,
        show_hex: bool = False
    ) -> Image.Image:
        """
        Render a palette as a horizontal strip of solid color swatches
        (visualization pattern from Pylette).

        Args:
            palette: RGB tuples drawn left to right
            width: total strip width in pixels
            height: strip height in pixels
            show_hex: reserved for hex labels; currently unused

        Returns:
            RGB image; a solid white image if the palette is empty.
        """
        n_colors = len(palette)
        if n_colors == 0:
            return Image.new('RGB', (width, height), (255, 255, 255))

        color_width = width // n_colors
        palette_img = Image.new('RGB', (width, height))

        for i, color in enumerate(palette):
            x_start = i * color_width
            # The last stripe absorbs the integer-division remainder so the
            # strip always fills the full width.
            x_end = (i + 1) * color_width if i < n_colors - 1 else width
            # paste() fills the whole box in one C-level call, replacing the
            # old per-pixel Python loop with identical output.
            palette_img.paste(tuple(int(c) for c in color), (x_start, 0, x_end, height))

        return palette_img

    @staticmethod
    def create_comparison_image(
        original: Image.Image,
        modified: Image.Image,
        original_palette: List[Tuple[int, int, int]],
        target_palette: List[Tuple[int, int, int]]
    ) -> Image.Image:
        """
        Create a side-by-side comparison (original | modified) with each
        image's palette swatch pasted directly beneath it.

        Returns:
            Combined RGB image on a white background.
        """
        # One swatch per image, matched to that image's width.
        swatch_height = 60
        original_swatch = PaletteVisualizer.create_palette_image(
            original_palette, original.width, swatch_height
        )
        target_swatch = PaletteVisualizer.create_palette_image(
            target_palette, modified.width, swatch_height
        )

        # Canvas: wide enough for both images, tall enough for the taller
        # image plus one swatch row.
        total_width = original.width + modified.width
        total_height = max(original.height, modified.height) + swatch_height

        result = Image.new('RGB', (total_width, total_height), (255, 255, 255))

        # Images side by side; each swatch sits just below its own image.
        result.paste(original, (0, 0))
        result.paste(modified, (original.width, 0))
        result.paste(original_swatch, (0, original.height))
        result.paste(target_swatch, (original.width, modified.height))

        return result
image_generator.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AI Image Generation Module
3
+ Uses Stable Diffusion to generate images from text prompts
4
+ """
5
+
6
+ import torch
7
+ from diffusers import StableDiffusionPipeline
8
+ from PIL import Image
9
+ import os
10
+
11
+
12
class ImageGenerator:
    """Lazy-loading wrapper around a Stable Diffusion text-to-image pipeline."""

    def __init__(self, model_id="stabilityai/stable-diffusion-2-1-base"):
        """Record configuration; the model itself is loaded lazily."""
        self.model_id = model_id
        self.pipe = None  # created on first generate() via load_model()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_model(self):
        """Load the Stable Diffusion model (no-op if already loaded)."""
        if self.pipe is None:
            print(f"Loading model on {self.device}...")
            self.pipe = StableDiffusionPipeline.from_pretrained(
                self.model_id,
                # fp16 halves GPU memory; CPU inference needs fp32
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                # NOTE(review): safety checker disabled -- generated images are
                # NOT filtered; confirm this is acceptable for the deployment.
                safety_checker=None
            )
            self.pipe = self.pipe.to(self.device)

            # Enable memory optimizations
            if self.device == "cuda":
                self.pipe.enable_attention_slicing()

            print("Model loaded successfully!")

    def generate(
        self,
        prompt: str,
        negative_prompt: str = "",
        num_inference_steps: int = 30,
        guidance_scale: float = 7.5,
        width: int = 512,
        height: int = 512,
        seed: Optional[int] = None
    ) -> "Image.Image":
        """
        Generate an image from a text prompt.

        Args:
            prompt: Text description of desired image
            negative_prompt: What to avoid in the image
            num_inference_steps: Number of denoising steps (higher = better quality, slower)
            guidance_scale: How closely to follow the prompt (7-10 recommended)
            width: Image width (must be multiple of 8)
            height: Image height (must be multiple of 8)
            seed: Random seed for reproducibility; None means nondeterministic

        Returns:
            PIL Image
        """
        self.load_model()

        # Seeded generator for reproducibility; None lets the pipeline pick.
        generator = None
        if seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(seed)

        # inference_mode disables autograd bookkeeping for speed and memory.
        with torch.inference_mode():
            result = self.pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                width=width,
                height=height,
                generator=generator
            )

        return result.images[0]

    def unload_model(self):
        """Free up memory by unloading the model."""
        if self.pipe is not None:
            del self.pipe
            self.pipe = None
            # Return cached CUDA allocations to the driver.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            print("Model unloaded")
90
+
91
+
92
# Example usage: quick smoke test of the generator from the command line.
if __name__ == "__main__":
    demo = ImageGenerator()
    result_image = demo.generate(
        prompt="A fantasy landscape with mountains and a castle at sunset",
        seed=42,
    )
    result_image.save("test_generated.png")
    print("Image saved as test_generated.png")
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio==4.44.0
2
+ diffusers==0.30.0
3
+ transformers==4.44.0
4
+ accelerate==0.33.0
5
+ torch==2.4.0
6
+ pillow==10.4.0
7
+ numpy==1.26.4
8
+ scikit-learn==1.5.1
9
+ scipy==1.14.0