""" AI Image Editor - Single Page Palette Swapper Upload/generate images, extract palettes, swap colors with dropdown mapping Includes AI-powered editing with InstructPix2Pix """ import gradio as gr from PIL import Image from color_palette import ColorPalette, ColorTheory, PaletteVisualizer import os # Check if we're on HuggingFace (has GPU) or local HF_SPACE = os.getenv("SPACE_ID") is not None # Only import AI/GPU stuff if on HF Space pix2pix_pipe = None if HF_SPACE: import spaces import torch from diffusers import StableDiffusionInstructPix2PixPipeline # Store current state current_source_palette = [] current_source_proportions = [] def extract_palette_from_image(image, n_colors): """Extract palette when image is uploaded""" global current_source_palette, current_source_proportions if image is None: current_source_palette = [] current_source_proportions = [] return None, "", *[gr.update(choices=[], value=None) for _ in range(5)] # Extract with proportions colors, proportions = ColorPalette.extract_palette( image, n_colors=int(n_colors), return_proportions=True ) current_source_palette = colors current_source_proportions = proportions # Create visual palette palette_img = PaletteVisualizer.create_palette_image( colors, width=500, height=80, proportions=proportions ) # Create text description desc_lines = [] for i, (color, prop) in enumerate(zip(colors, proportions)): hex_code = ColorPalette.rgb_to_hex(color) desc_lines.append(f"Color {i+1}: {hex_code} ({prop*100:.1f}%)") palette_desc = "\n".join(desc_lines) # Create dropdown choices for mapping choices = [f"{i+1}: {ColorPalette.rgb_to_hex(c)} ({p*100:.0f}%)" for i, (c, p) in enumerate(zip(colors, proportions))] # Update dropdowns (show as many as we have colors) dropdown_updates = [] for i in range(5): if i < len(colors): dropdown_updates.append(gr.update(choices=choices, value=choices[i], visible=True)) else: dropdown_updates.append(gr.update(choices=[], value=None, visible=False)) return palette_img, palette_desc, *dropdown_updates def generate_target_palette(method, base_color, ref_image, n_colors): """ Generate target palette based on source palette and new base color. The new base color becomes the dominant color, and other colors are generated to match the tonal relationships of the source palette while applying the selected color harmony. """ n = int(n_colors) base_rgb = ColorPalette.hex_to_rgb(base_color) if method == "From Reference Image" and ref_image is not None: colors = ColorPalette.extract_palette(ref_image, n_colors=n) elif not current_source_palette: # No source image yet - fall back to simple generation colors = ColorTheory.monochromatic(base_rgb, n_colors=n) else: # Transform source palette using new base color + harmony type colors = ColorTheory.transform_palette( current_source_palette, base_rgb, method ) # Create visual (use source proportions if available for consistency) if current_source_proportions and len(current_source_proportions) == len(colors): palette_img = PaletteVisualizer.create_palette_image( colors, width=500, height=80, proportions=current_source_proportions ) else: palette_img = PaletteVisualizer.create_palette_image(colors, width=500, height=80) # Description desc = "\n".join([f"Color {i+1}: {ColorPalette.rgb_to_hex(c)}" for i, c in enumerate(colors)]) return palette_img, desc def apply_swap(source_image, target_method, base_color, ref_image, n_colors, map1, map2, map3, map4, map5, use_ai=False): """ Apply the palette swap with custom mapping. If use_ai is True and on HF Space, uses AI-enhanced swap where: - K-means provides color clusters and masks (the 'what' and 'where') - AI provides intelligent recoloring (preserves lighting/texture) Otherwise uses basic pixel-level swap. """ if source_image is None: return None if not current_source_palette: return None # Generate target palette using same logic as generate_target_palette n = int(n_colors) base_rgb = ColorPalette.hex_to_rgb(base_color) if target_method == "From Reference Image" and ref_image is not None: target_colors = ColorPalette.extract_palette(ref_image, n_colors=n) else: # Transform source palette using new base color + harmony type target_colors = ColorTheory.transform_palette( current_source_palette, base_rgb, target_method ) # Build mapping from dropdowns mapping = {} dropdown_values = [map1, map2, map3, map4, map5] for src_idx, dropdown_val in enumerate(dropdown_values[:len(current_source_palette)]): if dropdown_val: # Extract target index from dropdown value like "2: #FF0000 (30%)" try: tgt_idx = int(dropdown_val.split(":")[0]) - 1 mapping[src_idx] = tgt_idx except: mapping[src_idx] = src_idx # Default to same position # Try AI-enhanced swap if enabled if use_ai and HF_SPACE: result = ai_apply_swap( source_image, current_source_palette, target_colors, mapping ) if result is not None: return result # Fall back to basic swap if AI fails print("AI swap failed, falling back to basic swap") # Basic pixel-level swap (uses K-means distance) result = ColorPalette.swap_palette_mapped( source_image, current_source_palette, target_colors, mapping ) return result def generate_color_mask(image, source_palette, color_index): """ Generate a mask for pixels belonging to a specific K-means color cluster. This provides the 'where' context for the AI - which regions to recolor. """ import numpy as np img_array = np.array(image.convert('RGB')).astype(float) height, width, _ = img_array.shape pixels = img_array.reshape(-1, 3) # Calculate distance to all source colors all_distances = np.array([ np.linalg.norm(pixels - np.array(src), axis=1) for src in source_palette ]) closest_color_idx = np.argmin(all_distances, axis=0) # Create mask where this color is dominant mask = (closest_color_idx == color_index).reshape(height, width) # Convert to PIL Image (white = edit, black = keep) mask_img = Image.fromarray((mask * 255).astype(np.uint8), mode='L') return mask_img def describe_color(rgb): """Convert RGB to human-readable color name for AI prompts""" import colorsys r, g, b = [c / 255.0 for c in rgb] h, s, v = colorsys.rgb_to_hsv(r, g, b) # Handle grayscale if s < 0.15: if v > 0.85: return "white" elif v < 0.15: return "black" elif v > 0.6: return "light gray" elif v < 0.4: return "dark gray" return "gray" # Lightness modifier if v < 0.35: lightness = "dark " elif v > 0.75 and s < 0.5: lightness = "light " else: lightness = "" # Hue name hue_deg = h * 360 if hue_deg < 15 or hue_deg >= 345: hue_name = "red" elif hue_deg < 45: hue_name = "orange" elif hue_deg < 70: hue_name = "yellow" elif hue_deg < 150: hue_name = "green" elif hue_deg < 200: hue_name = "cyan" elif hue_deg < 260: hue_name = "blue" elif hue_deg < 290: hue_name = "purple" else: hue_name = "pink" return f"{lightness}{hue_name}" if HF_SPACE: @spaces.GPU(duration=120) def ai_apply_swap(image, source_palette, target_palette, mapping, steps=20): """ AI-enhanced palette swap using K-means data as context. Your K-means algorithm provides: - The 'what': which color clusters exist - The 'mask': which pixels belong to each cluster The AI provides: - Intelligent recoloring that preserves lighting/texture - Natural transitions between colors - Context-aware editing (understands objects) Note: Model is loaded inside @spaces.GPU so it runs on allocated GPU. """ global pix2pix_pipe if image is None or not source_palette or not target_palette: return None try: # Load model inside GPU context (ZeroGPU requirement) if pix2pix_pipe is None: print("Loading InstructPix2Pix model...") pix2pix_pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( "timbrooks/instruct-pix2pix", torch_dtype=torch.float16, safety_checker=None ) pix2pix_pipe.to("cuda") print("Model loaded successfully") original_size = image.size working_image = image.convert("RGB").resize((512, 512)) result = working_image.copy() # Process each color mapping for src_idx, tgt_idx in mapping.items(): if src_idx >= len(source_palette) or tgt_idx >= len(target_palette): continue src_color = source_palette[src_idx] tgt_color = target_palette[tgt_idx] # Skip if colors are very similar color_diff = sum(abs(s - t) for s, t in zip(src_color, tgt_color)) if color_diff < 30: continue # Generate mask from K-means cluster (your algorithm's context) mask = generate_color_mask(result, source_palette, src_idx) mask = mask.resize((512, 512)) # Build natural language instruction src_desc = describe_color(src_color) tgt_desc = describe_color(tgt_color) tgt_hex = ColorPalette.rgb_to_hex(tgt_color) instruction = f"Change the {src_desc} areas to {tgt_desc} ({tgt_hex}), preserve lighting and texture" print(f"AI Instruction: {instruction}") # Apply AI edit (the AI uses your K-means mask as context) edited = pix2pix_pipe( instruction, image=result, num_inference_steps=int(steps), image_guidance_scale=1.5, guidance_scale=7.5, ).images[0] result = edited # Resize back to original result = result.resize(original_size, Image.Resampling.LANCZOS) return result except Exception as e: print(f"AI Edit error: {e}") import traceback traceback.print_exc() return None else: def ai_apply_swap(image, source_palette, target_palette, mapping, steps=20): """Fallback for local - returns None so basic swap is used""" return None # Build the UI with gr.Blocks(title="AI Image Editor", theme=gr.themes.Soft()) as demo: gr.Markdown("# 🎨 AI Image Editor - Palette Swapper") gr.Markdown("Upload an image, extract its color palette, then swap colors using color theory or a reference image.") with gr.Row(): # LEFT COLUMN - Source Image with gr.Column(scale=1): gr.Markdown("### 📷 Source Image") source_image = gr.Image(label="Upload Image", type="pil", height=300) # AI Generation (collapsible, only shown on HF) if HF_SPACE: with gr.Accordion("✨ Or Generate with AI", open=False): ai_prompt = gr.Textbox(label="Prompt", placeholder="A fantasy landscape...") ai_negative = gr.Textbox(label="Negative Prompt", value="blurry, bad quality") with gr.Row(): ai_steps = gr.Slider(10, 50, value=25, step=1, label="Steps") ai_guidance = gr.Slider(1, 15, value=7.5, step=0.5, label="Guidance") ai_seed = gr.Number(label="Seed (-1 = random)", value=-1) ai_generate_btn = gr.Button("🎨 Generate", variant="secondary") n_colors = gr.Slider(2, 8, value=5, step=1, label="Number of Colors to Extract") gr.Markdown("### 🎨 Extracted Palette") source_palette_img = gr.Image(label="Source Palette", height=80, interactive=False) source_palette_desc = gr.Textbox(label="Colors", lines=5, interactive=False) # RIGHT COLUMN - Target & Result with gr.Column(scale=1): gr.Markdown("### 🎯 Target Palette") target_method = gr.Dropdown( choices=[ "Color Harmony - Complementary", "Color Harmony - Analogous", "Color Harmony - Triadic", "Color Harmony - Split-Complementary", "Color Harmony - Tetradic", "Color Harmony - Monochromatic", "From Reference Image" ], value="Color Harmony - Complementary", label="Generate Target From" ) with gr.Row(): base_color = gr.ColorPicker(label="Base Color", value="#3498db") ref_image = gr.Image(label="Reference Image", type="pil", height=100) generate_palette_btn = gr.Button("Generate Target Palette", variant="secondary") target_palette_img = gr.Image(label="Target Palette", height=80, interactive=False) target_palette_desc = gr.Textbox(label="Target Colors", lines=3, interactive=False) gr.Markdown("### 🔄 Color Mapping") gr.Markdown("*Map each source color to a target color:*") with gr.Row(): map_dropdown_1 = gr.Dropdown(label="Source 1 →", visible=False) map_dropdown_2 = gr.Dropdown(label="Source 2 →", visible=False) with gr.Row(): map_dropdown_3 = gr.Dropdown(label="Source 3 →", visible=False) map_dropdown_4 = gr.Dropdown(label="Source 4 →", visible=False) with gr.Row(): map_dropdown_5 = gr.Dropdown(label="Source 5 →", visible=False) # AI toggle (only shown on HF Space) if HF_SPACE: use_ai_checkbox = gr.Checkbox( label="🤖 Use AI-Enhanced Swap", value=True, info="AI uses your K-means palette as context to recolor naturally, preserving lighting and textures" ) else: use_ai_checkbox = gr.Checkbox(label="AI (HF Space only)", value=False, visible=False) apply_btn = gr.Button("🔄 Apply Palette Swap", variant="primary", size="lg") gr.Markdown("### ✅ Result") result_image = gr.Image(label="Result", height=300) # Event handlers source_image.change( fn=extract_palette_from_image, inputs=[source_image, n_colors], outputs=[source_palette_img, source_palette_desc, map_dropdown_1, map_dropdown_2, map_dropdown_3, map_dropdown_4, map_dropdown_5] ) n_colors.change( fn=extract_palette_from_image, inputs=[source_image, n_colors], outputs=[source_palette_img, source_palette_desc, map_dropdown_1, map_dropdown_2, map_dropdown_3, map_dropdown_4, map_dropdown_5] ) generate_palette_btn.click( fn=generate_target_palette, inputs=[target_method, base_color, ref_image, n_colors], outputs=[target_palette_img, target_palette_desc] ) apply_btn.click( fn=apply_swap, inputs=[source_image, target_method, base_color, ref_image, n_colors, map_dropdown_1, map_dropdown_2, map_dropdown_3, map_dropdown_4, map_dropdown_5, use_ai_checkbox], outputs=[result_image] ) gr.Markdown("---") if HF_SPACE: gr.Markdown("*Built with Gradio • K-means palette extraction • AI-enhanced recoloring with InstructPix2Pix*") else: gr.Markdown("*Built with Gradio • Color extraction via K-means clustering • Color theory harmonies*") if __name__ == "__main__": demo.launch()