# NOTE: the "Spaces: Sleeping" header lines were HuggingFace Space page
# residue from extraction, not part of the application source.
"""
AI Image Editor - Single Page Palette Swapper
Upload/generate images, extract palettes, swap colors with dropdown mapping
Includes AI-powered editing with InstructPix2Pix
"""
import gradio as gr
from PIL import Image
from color_palette import ColorPalette, ColorTheory, PaletteVisualizer
import os

# Check if we're on HuggingFace (has GPU) or local.
# SPACE_ID is set automatically inside any HuggingFace Space container.
HF_SPACE = os.getenv("SPACE_ID") is not None

# Only import AI/GPU stuff if on HF Space — spaces/torch/diffusers may not be
# installed locally, and the pipeline is only usable with an allocated GPU.
pix2pix_pipe = None  # lazily-loaded InstructPix2Pix pipeline (HF Space only)
if HF_SPACE:
    import spaces
    import torch
    from diffusers import StableDiffusionInstructPix2PixPipeline

# Store current state: module-level cache of the most recently extracted
# source palette, shared between the extraction callback and apply_swap.
current_source_palette = []      # list of (r, g, b) colors from K-means
current_source_proportions = []  # matching fraction of pixels per color
def extract_palette_from_image(image, n_colors):
    """Extract the dominant palette when an image is uploaded (or n_colors changes).

    Side effect: caches the colors and their pixel proportions in the
    module-level current_source_palette / current_source_proportions so
    apply_swap can reuse the same extraction.

    Returns (palette strip image, text description, 5 dropdown updates).
    """
    global current_source_palette, current_source_proportions

    if image is None:
        # Clear cached state and hide all mapping dropdowns.
        # Fix: previously the dropdowns were only emptied but stayed
        # visible after the image was cleared.
        current_source_palette = []
        current_source_proportions = []
        return None, "", *[gr.update(choices=[], value=None, visible=False)
                           for _ in range(5)]

    # Extract dominant colors together with the fraction of pixels each covers
    colors, proportions = ColorPalette.extract_palette(
        image, n_colors=int(n_colors), return_proportions=True
    )
    current_source_palette = colors
    current_source_proportions = proportions

    # Create visual palette strip (bar widths follow the proportions)
    palette_img = PaletteVisualizer.create_palette_image(
        colors, width=500, height=80, proportions=proportions
    )

    # Create text description: one "Color N: #hex (pct%)" line per color
    desc_lines = []
    for i, (color, prop) in enumerate(zip(colors, proportions)):
        hex_code = ColorPalette.rgb_to_hex(color)
        desc_lines.append(f"Color {i+1}: {hex_code} ({prop*100:.1f}%)")
    palette_desc = "\n".join(desc_lines)

    # Dropdown choices for the source->target mapping, e.g. "2: #FF0000 (30%)"
    choices = [f"{i+1}: {ColorPalette.rgb_to_hex(c)} ({p*100:.0f}%)"
               for i, (c, p) in enumerate(zip(colors, proportions))]

    # Show one dropdown per extracted color (max 5), hide the rest
    dropdown_updates = []
    for i in range(5):
        if i < len(colors):
            dropdown_updates.append(gr.update(choices=choices, value=choices[i], visible=True))
        else:
            dropdown_updates.append(gr.update(choices=[], value=None, visible=False))
    return palette_img, palette_desc, *dropdown_updates
def generate_target_palette(method, base_color, ref_image, n_colors):
    """
    Generate target palette based on source palette and new base color.

    The new base color becomes the dominant color, and the remaining colors
    are generated to match the tonal relationships of the source palette
    while applying the selected color harmony. When a reference image is
    supplied (and the method asks for one), its palette is used directly.

    Returns (palette strip image, text description).
    """
    count = int(n_colors)
    base_rgb = ColorPalette.hex_to_rgb(base_color)

    if method == "From Reference Image" and ref_image is not None:
        # Pull the target palette straight out of the reference image
        colors = ColorPalette.extract_palette(ref_image, n_colors=count)
    elif current_source_palette:
        # Transform the cached source palette with the new base + harmony
        colors = ColorTheory.transform_palette(current_source_palette, base_rgb, method)
    else:
        # No source image yet — fall back to a simple monochromatic ramp
        colors = ColorTheory.monochromatic(base_rgb, n_colors=count)

    # Reuse the source proportions for the preview when they line up,
    # so the target strip mirrors the source strip's bar widths.
    proportions_match = (
        bool(current_source_proportions)
        and len(current_source_proportions) == len(colors)
    )
    if proportions_match:
        palette_img = PaletteVisualizer.create_palette_image(
            colors, width=500, height=80, proportions=current_source_proportions
        )
    else:
        palette_img = PaletteVisualizer.create_palette_image(colors, width=500, height=80)

    # One "Color N: #hex" line per color
    lines = (f"Color {idx + 1}: {ColorPalette.rgb_to_hex(rgb)}"
             for idx, rgb in enumerate(colors))
    return palette_img, "\n".join(lines)
def apply_swap(source_image, target_method, base_color, ref_image, n_colors,
               map1, map2, map3, map4, map5, use_ai=False):
    """
    Apply the palette swap with custom mapping.

    If use_ai is True and on HF Space, uses AI-enhanced swap where:
    - K-means provides the color clusters (the 'what')
    - AI provides intelligent recoloring (preserves lighting/texture)
    Otherwise uses basic pixel-level swap.

    Returns the recolored PIL image, or None when there is no source
    image / no extracted palette yet.
    """
    if source_image is None:
        return None
    if not current_source_palette:
        return None

    # Generate target palette using same logic as generate_target_palette
    n = int(n_colors)
    base_rgb = ColorPalette.hex_to_rgb(base_color)
    if target_method == "From Reference Image" and ref_image is not None:
        target_colors = ColorPalette.extract_palette(ref_image, n_colors=n)
    else:
        # Transform source palette using new base color + harmony type
        target_colors = ColorTheory.transform_palette(
            current_source_palette,
            base_rgb,
            target_method
        )

    # Build mapping from dropdowns. Values look like "2: #FF0000 (30%)".
    mapping = {}
    dropdown_values = [map1, map2, map3, map4, map5]
    for src_idx, dropdown_val in enumerate(dropdown_values[:len(current_source_palette)]):
        if not dropdown_val:
            continue
        try:
            tgt_idx = int(dropdown_val.split(":")[0]) - 1
        except (ValueError, IndexError):
            # Malformed dropdown text — default to same position.
            # (Was a bare `except:`, which would also swallow e.g.
            # KeyboardInterrupt; narrowed to the actual parse failures.)
            tgt_idx = src_idx
        if not 0 <= tgt_idx < len(target_colors):
            # Stale selection from a previous, larger palette — fall back
            # to the same position rather than indexing out of range.
            tgt_idx = src_idx
        mapping[src_idx] = tgt_idx

    # Try AI-enhanced swap if enabled (only available on HF Space)
    if use_ai and HF_SPACE:
        result = ai_apply_swap(
            source_image,
            current_source_palette,
            target_colors,
            mapping
        )
        if result is not None:
            return result
        # Fall back to basic swap if AI fails
        print("AI swap failed, falling back to basic swap")

    # Basic pixel-level swap (nearest-cluster reassignment via K-means distance)
    result = ColorPalette.swap_palette_mapped(
        source_image,
        current_source_palette,
        target_colors,
        mapping
    )
    return result
def generate_color_mask(image, source_palette, color_index):
    """
    Build a binary mask selecting the pixels whose nearest palette color
    (Euclidean distance in RGB) is source_palette[color_index].

    Returns a PIL 'L' image: white (255) = edit region, black (0) = keep.
    """
    import numpy as np

    rgb = np.asarray(image.convert('RGB'), dtype=float)
    h, w = rgb.shape[:2]
    flat = rgb.reshape(-1, 3)

    # Distance from every pixel to every palette color via broadcasting:
    # (n_colors, 1, 3) - (1, n_pixels, 3) -> norm over axis 2 -> (n_colors, n_pixels)
    palette = np.asarray(source_palette, dtype=float)
    distances = np.linalg.norm(palette[:, None, :] - flat[None, :, :], axis=2)
    nearest = distances.argmin(axis=0)

    # Pixels whose nearest cluster is the requested one
    selected = (nearest == color_index).reshape(h, w)
    return Image.fromarray((selected * 255).astype(np.uint8), mode='L')
def describe_color(rgb):
    """Convert an (r, g, b) tuple into a human-readable color name for AI prompts."""
    import colorsys

    r, g, b = (channel / 255.0 for channel in rgb)
    hue, sat, val = colorsys.rgb_to_hsv(r, g, b)

    # Near-zero saturation: classify purely by brightness (grayscale).
    if sat < 0.15:
        if val > 0.85:
            return "white"
        if val < 0.15:
            return "black"
        if val > 0.6:
            return "light gray"
        if val < 0.4:
            return "dark gray"
        return "gray"

    # Optional lightness prefix for chromatic colors.
    prefix = ""
    if val < 0.35:
        prefix = "dark "
    elif val > 0.75 and sat < 0.5:
        prefix = "light "

    # Red wraps around the hue wheel, so handle it before the band table.
    degrees = hue * 360
    if degrees < 15 or degrees >= 345:
        return f"{prefix}red"

    # Walk the remaining band boundaries (upper bound in degrees -> name).
    bands = (
        (45, "orange"),
        (70, "yellow"),
        (150, "green"),
        (200, "cyan"),
        (260, "blue"),
        (290, "purple"),
        (345, "pink"),
    )
    for upper, name in bands:
        if degrees < upper:
            return f"{prefix}{name}"
    return f"{prefix}red"  # defensive; unreachable given the wrap check above
if HF_SPACE:
    @spaces.GPU
    def ai_apply_swap(image, source_palette, target_palette, mapping, steps=20):
        """
        AI-enhanced palette swap driven by the K-means palette data.

        The K-means extraction provides the 'what' (which color clusters
        exist and how they map); each mapped cluster pair is turned into a
        natural-language instruction for InstructPix2Pix, which recolors
        while preserving lighting and texture.

        Fixes vs. previous revision:
        - Added the @spaces.GPU decorator. The docstring claimed the model
          loads "inside @spaces.GPU", but the decorator was missing, so on
          ZeroGPU no GPU would be allocated for this call.
        - Removed the per-cluster mask computation (generate_color_mask +
          resize): InstructPix2Pix takes no mask input, so the result was
          computed and discarded on every iteration.

        Returns the edited PIL image, or None on any failure so callers
        can fall back to the basic pixel swap.
        """
        global pix2pix_pipe
        if image is None or not source_palette or not target_palette:
            return None
        try:
            # Lazily load the model inside the GPU context (ZeroGPU requirement)
            if pix2pix_pipe is None:
                print("Loading InstructPix2Pix model...")
                pix2pix_pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
                    "timbrooks/instruct-pix2pix",
                    torch_dtype=torch.float16,
                    safety_checker=None
                )
                pix2pix_pipe.to("cuda")
                print("Model loaded successfully")

            original_size = image.size
            # InstructPix2Pix expects 512x512 input
            working_image = image.convert("RGB").resize((512, 512))
            result = working_image.copy()

            # Apply one sequential edit per mapped color pair
            for src_idx, tgt_idx in mapping.items():
                if src_idx >= len(source_palette) or tgt_idx >= len(target_palette):
                    continue
                src_color = source_palette[src_idx]
                tgt_color = target_palette[tgt_idx]

                # Skip near-identical colors (sum of per-channel deltas)
                color_diff = sum(abs(s - t) for s, t in zip(src_color, tgt_color))
                if color_diff < 30:
                    continue

                # Build natural language instruction from the color pair
                src_desc = describe_color(src_color)
                tgt_desc = describe_color(tgt_color)
                tgt_hex = ColorPalette.rgb_to_hex(tgt_color)
                instruction = f"Change the {src_desc} areas to {tgt_desc} ({tgt_hex}), preserve lighting and texture"
                print(f"AI Instruction: {instruction}")

                # Apply AI edit on the running result so edits accumulate
                edited = pix2pix_pipe(
                    instruction,
                    image=result,
                    num_inference_steps=int(steps),
                    image_guidance_scale=1.5,
                    guidance_scale=7.5,
                ).images[0]
                result = edited

            # Resize back to the original image dimensions
            result = result.resize(original_size, Image.Resampling.LANCZOS)
            return result
        except Exception as e:
            # Broad catch is deliberate: any failure (CUDA OOM, download
            # error, bad image) should degrade to the basic swap in the
            # caller rather than crash the UI.
            print(f"AI Edit error: {e}")
            import traceback
            traceback.print_exc()
            return None
else:
    def ai_apply_swap(image, source_palette, target_palette, mapping, steps=20):
        """Fallback for local runs - returns None so the basic swap is used."""
        return None
# Build the UI.
# NOTE: several label strings below contain mis-encoded emoji (e.g. "π¨");
# they are runtime strings and are preserved byte-for-byte here.
with gr.Blocks(title="AI Image Editor", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# π¨ AI Image Editor - Palette Swapper")
    gr.Markdown("Upload an image, extract its color palette, then swap colors using color theory or a reference image.")
    with gr.Row():
        # LEFT COLUMN - Source Image
        with gr.Column(scale=1):
            gr.Markdown("### π· Source Image")
            source_image = gr.Image(label="Upload Image", type="pil", height=300)
            # AI Generation (collapsible, only shown on HF)
            if HF_SPACE:
                with gr.Accordion("β¨ Or Generate with AI", open=False):
                    ai_prompt = gr.Textbox(label="Prompt", placeholder="A fantasy landscape...")
                    ai_negative = gr.Textbox(label="Negative Prompt", value="blurry, bad quality")
                    with gr.Row():
                        ai_steps = gr.Slider(10, 50, value=25, step=1, label="Steps")
                        ai_guidance = gr.Slider(1, 15, value=7.5, step=0.5, label="Guidance")
                    ai_seed = gr.Number(label="Seed (-1 = random)", value=-1)
                    # NOTE(review): ai_generate_btn has no .click() handler
                    # wired below, so this button currently does nothing —
                    # confirm whether the generation wiring was lost or is
                    # still pending.
                    ai_generate_btn = gr.Button("π¨ Generate", variant="secondary")
            n_colors = gr.Slider(2, 8, value=5, step=1, label="Number of Colors to Extract")
            gr.Markdown("### π¨ Extracted Palette")
            source_palette_img = gr.Image(label="Source Palette", height=80, interactive=False)
            source_palette_desc = gr.Textbox(label="Colors", lines=5, interactive=False)
        # RIGHT COLUMN - Target & Result
        with gr.Column(scale=1):
            gr.Markdown("### π― Target Palette")
            target_method = gr.Dropdown(
                choices=[
                    "Color Harmony - Complementary",
                    "Color Harmony - Analogous",
                    "Color Harmony - Triadic",
                    "Color Harmony - Split-Complementary",
                    "Color Harmony - Tetradic",
                    "Color Harmony - Monochromatic",
                    "From Reference Image"
                ],
                value="Color Harmony - Complementary",
                label="Generate Target From"
            )
            with gr.Row():
                base_color = gr.ColorPicker(label="Base Color", value="#3498db")
                ref_image = gr.Image(label="Reference Image", type="pil", height=100)
            generate_palette_btn = gr.Button("Generate Target Palette", variant="secondary")
            target_palette_img = gr.Image(label="Target Palette", height=80, interactive=False)
            target_palette_desc = gr.Textbox(label="Target Colors", lines=3, interactive=False)
            gr.Markdown("### π Color Mapping")
            gr.Markdown("*Map each source color to a target color:*")
            # Up to five source->target dropdowns; extract_palette_from_image
            # toggles their visibility to match the number of extracted colors.
            with gr.Row():
                map_dropdown_1 = gr.Dropdown(label="Source 1 β", visible=False)
                map_dropdown_2 = gr.Dropdown(label="Source 2 β", visible=False)
            with gr.Row():
                map_dropdown_3 = gr.Dropdown(label="Source 3 β", visible=False)
                map_dropdown_4 = gr.Dropdown(label="Source 4 β", visible=False)
            with gr.Row():
                map_dropdown_5 = gr.Dropdown(label="Source 5 β", visible=False)
            # AI toggle (only shown on HF Space); the hidden local checkbox
            # keeps apply_swap's input list identical in both environments.
            if HF_SPACE:
                use_ai_checkbox = gr.Checkbox(
                    label="π€ Use AI-Enhanced Swap",
                    value=True,
                    info="AI uses your K-means palette as context to recolor naturally, preserving lighting and textures"
                )
            else:
                use_ai_checkbox = gr.Checkbox(label="AI (HF Space only)", value=False, visible=False)
            apply_btn = gr.Button("π Apply Palette Swap", variant="primary", size="lg")
            gr.Markdown("### β Result")
            result_image = gr.Image(label="Result", height=300)

    # Event handlers
    # Re-extract the palette whenever the image or the color count changes.
    source_image.change(
        fn=extract_palette_from_image,
        inputs=[source_image, n_colors],
        outputs=[source_palette_img, source_palette_desc,
                 map_dropdown_1, map_dropdown_2, map_dropdown_3,
                 map_dropdown_4, map_dropdown_5]
    )
    n_colors.change(
        fn=extract_palette_from_image,
        inputs=[source_image, n_colors],
        outputs=[source_palette_img, source_palette_desc,
                 map_dropdown_1, map_dropdown_2, map_dropdown_3,
                 map_dropdown_4, map_dropdown_5]
    )
    generate_palette_btn.click(
        fn=generate_target_palette,
        inputs=[target_method, base_color, ref_image, n_colors],
        outputs=[target_palette_img, target_palette_desc]
    )
    apply_btn.click(
        fn=apply_swap,
        inputs=[source_image, target_method, base_color, ref_image, n_colors,
                map_dropdown_1, map_dropdown_2, map_dropdown_3,
                map_dropdown_4, map_dropdown_5, use_ai_checkbox],
        outputs=[result_image]
    )

    gr.Markdown("---")
    if HF_SPACE:
        gr.Markdown("*Built with Gradio β’ K-means palette extraction β’ AI-enhanced recoloring with InstructPix2Pix*")
    else:
        gr.Markdown("*Built with Gradio β’ Color extraction via K-means clustering β’ Color theory harmonies*")

if __name__ == "__main__":
    demo.launch()