# ai-image-editor / app.py
# official.ghost.logic
# Fix ZeroGPU: load model inside @spaces.GPU decorated function
# 891a9c9
"""
AI Image Editor - Single Page Palette Swapper
Upload/generate images, extract palettes, swap colors with dropdown mapping
Includes AI-powered editing with InstructPix2Pix
"""
import gradio as gr
from PIL import Image
from color_palette import ColorPalette, ColorTheory, PaletteVisualizer
import os
# Check if we're on HuggingFace (has GPU) or local
HF_SPACE = os.getenv("SPACE_ID") is not None
# Only import AI/GPU stuff if on HF Space
pix2pix_pipe = None
if HF_SPACE:
import spaces
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline
# Store current state
current_source_palette = []
current_source_proportions = []
def extract_palette_from_image(image, n_colors):
    """Extract the dominant color palette when an image is uploaded.

    Args:
        image: PIL image, or None when the upload is cleared.
        n_colors: number of colors to extract (slider value, may be float).

    Returns:
        (palette_image, palette_description, *dropdown_updates) — the five
        dropdown updates re-populate the source->target mapping UI.
    """
    global current_source_palette, current_source_proportions
    if image is None:
        # Reset state AND hide the mapping dropdowns so stale choices from
        # a previous image are not left visible on screen.
        current_source_palette = []
        current_source_proportions = []
        return None, "", *[gr.update(choices=[], value=None, visible=False) for _ in range(5)]
    # Extract colors together with each cluster's pixel share
    colors, proportions = ColorPalette.extract_palette(
        image, n_colors=int(n_colors), return_proportions=True
    )
    current_source_palette = colors
    current_source_proportions = proportions
    # Visual strip of the palette, bands sized by proportion
    palette_img = PaletteVisualizer.create_palette_image(
        colors, width=500, height=80, proportions=proportions
    )
    # Human-readable listing: hex code + percentage per color
    palette_desc = "\n".join(
        f"Color {i+1}: {ColorPalette.rgb_to_hex(color)} ({prop*100:.1f}%)"
        for i, (color, prop) in enumerate(zip(colors, proportions))
    )
    # Dropdown labels like "2: #FF0000 (30%)" — apply_swap parses the
    # leading index back out, so keep this format in sync with it.
    choices = [f"{i+1}: {ColorPalette.rgb_to_hex(c)} ({p*100:.0f}%)"
               for i, (c, p) in enumerate(zip(colors, proportions))]
    # Show one dropdown per extracted color, hide the rest (max 5)
    dropdown_updates = [
        gr.update(choices=choices, value=choices[i], visible=True) if i < len(colors)
        else gr.update(choices=[], value=None, visible=False)
        for i in range(5)
    ]
    return palette_img, palette_desc, *dropdown_updates
def generate_target_palette(method, base_color, ref_image, n_colors):
    """
    Build the target palette from the current source palette and a new base color.

    The chosen base color becomes the dominant color; the remaining colors
    follow the tonal relationships of the source palette under the selected
    color-harmony scheme, or come straight from a reference image.
    """
    count = int(n_colors)
    base_rgb = ColorPalette.hex_to_rgb(base_color)
    if method == "From Reference Image" and ref_image is not None:
        colors = ColorPalette.extract_palette(ref_image, n_colors=count)
    elif current_source_palette:
        # Re-map the source palette onto the new base color + harmony
        colors = ColorTheory.transform_palette(current_source_palette, base_rgb, method)
    else:
        # No source image loaded yet: simple monochromatic fallback
        colors = ColorTheory.monochromatic(base_rgb, n_colors=count)
    # Reuse the source proportions for swatch widths when the counts line up
    proportions = current_source_proportions
    if proportions and len(proportions) == len(colors):
        palette_img = PaletteVisualizer.create_palette_image(
            colors, width=500, height=80, proportions=proportions
        )
    else:
        palette_img = PaletteVisualizer.create_palette_image(colors, width=500, height=80)
    desc = "\n".join(f"Color {i+1}: {ColorPalette.rgb_to_hex(c)}" for i, c in enumerate(colors))
    return palette_img, desc
def apply_swap(source_image, target_method, base_color, ref_image, n_colors,
               map1, map2, map3, map4, map5, use_ai=False):
    """
    Apply the palette swap with the user's custom color mapping.

    If use_ai is True and running on a HF Space, uses the AI-enhanced swap
    (K-means supplies clusters/masks, the model recolors while preserving
    lighting/texture) and falls back to the basic pixel-level swap if the
    AI path returns None.

    Returns:
        The recolored PIL image, or None when there is no source image or
        no extracted palette yet.
    """
    if source_image is None or not current_source_palette:
        return None
    # Regenerate the target palette with the same logic as generate_target_palette
    n = int(n_colors)
    base_rgb = ColorPalette.hex_to_rgb(base_color)
    if target_method == "From Reference Image" and ref_image is not None:
        target_colors = ColorPalette.extract_palette(ref_image, n_colors=n)
    else:
        # Transform source palette using new base color + harmony type
        target_colors = ColorTheory.transform_palette(
            current_source_palette, base_rgb, target_method
        )
    # Build source->target index mapping from the dropdown strings
    mapping = {}
    dropdown_values = [map1, map2, map3, map4, map5]
    for src_idx, dropdown_val in enumerate(dropdown_values[:len(current_source_palette)]):
        if dropdown_val:
            # Dropdown values look like "2: #FF0000 (30%)" — the leading
            # number is the 1-based target color index.
            try:
                mapping[src_idx] = int(dropdown_val.split(":")[0]) - 1
            except ValueError:
                # Malformed dropdown value: default to the same position
                mapping[src_idx] = src_idx
    # Try AI-enhanced swap if enabled; fall back to basic swap on failure
    if use_ai and HF_SPACE:
        result = ai_apply_swap(
            source_image, current_source_palette, target_colors, mapping
        )
        if result is not None:
            return result
        print("AI swap failed, falling back to basic swap")
    # Basic pixel-level swap (nearest K-means cluster per pixel)
    return ColorPalette.swap_palette_mapped(
        source_image, current_source_palette, target_colors, mapping
    )
def generate_color_mask(image, source_palette, color_index):
    """
    Build a binary mask of the pixels whose nearest palette color is
    ``source_palette[color_index]`` — the 'where' context for the AI.

    Returns a PIL 'L' image: white (255) marks pixels to edit, black to keep.
    """
    import numpy as np
    rgb = np.array(image.convert('RGB')).astype(float)
    height, width, _ = rgb.shape
    flat = rgb.reshape(-1, 3)
    # Euclidean distance from every pixel to every palette color,
    # shape (n_colors, n_pixels)
    distances = np.array([
        np.linalg.norm(flat - np.array(color), axis=1)
        for color in source_palette
    ])
    nearest = np.argmin(distances, axis=0)
    mask = (nearest == color_index).reshape(height, width)
    return Image.fromarray((mask * 255).astype(np.uint8), mode='L')
def describe_color(rgb):
    """Convert an RGB triple to a human-readable color name for AI prompts."""
    import colorsys
    r, g, b = (channel / 255.0 for channel in rgb)
    h, s, v = colorsys.rgb_to_hsv(r, g, b)
    # Low saturation: treat as grayscale
    if s < 0.15:
        if v > 0.85:
            return "white"
        if v < 0.15:
            return "black"
        if v > 0.6:
            return "light gray"
        if v < 0.4:
            return "dark gray"
        return "gray"
    # Brightness prefix
    if v < 0.35:
        prefix = "dark "
    elif v > 0.75 and s < 0.5:
        prefix = "light "
    else:
        prefix = ""
    # Map the hue angle (degrees) onto a named band; 345-360 wraps to red
    degrees = h * 360
    bands = [
        (15, "red"), (45, "orange"), (70, "yellow"), (150, "green"),
        (200, "cyan"), (260, "blue"), (290, "purple"), (345, "pink"),
    ]
    hue_name = "red"
    for upper_bound, name in bands:
        if degrees < upper_bound:
            hue_name = name
            break
    return f"{prefix}{hue_name}"
if HF_SPACE:
    @spaces.GPU(duration=120)
    def ai_apply_swap(image, source_palette, target_palette, mapping, steps=20):
        """
        AI-enhanced palette swap using K-means data as context.
        Your K-means algorithm provides:
        - The 'what': which color clusters exist
        - The 'mask': which pixels belong to each cluster
        The AI provides:
        - Intelligent recoloring that preserves lighting/texture
        - Natural transitions between colors
        - Context-aware editing (understands objects)
        Note: Model is loaded inside @spaces.GPU so it runs on allocated GPU.

        Args:
            image: source PIL image (resized to 512x512 for the model).
            source_palette: list of RGB tuples extracted from the source image.
            target_palette: list of RGB tuples to recolor toward.
            mapping: dict of source color index -> target color index.
            steps: diffusion inference steps per color edit.

        Returns:
            Recolored PIL image at the original size, or None on any failure
            (caller then falls back to the basic pixel swap).
        """
        global pix2pix_pipe
        if image is None or not source_palette or not target_palette:
            return None
        try:
            # Load model inside GPU context (ZeroGPU requirement); cached in
            # the module-level global so it is only loaded once per process.
            if pix2pix_pipe is None:
                print("Loading InstructPix2Pix model...")
                pix2pix_pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
                    "timbrooks/instruct-pix2pix",
                    torch_dtype=torch.float16,
                    safety_checker=None
                )
                pix2pix_pipe.to("cuda")
                print("Model loaded successfully")
            original_size = image.size
            working_image = image.convert("RGB").resize((512, 512))
            result = working_image.copy()
            # Process each color mapping sequentially: each edit feeds the next
            for src_idx, tgt_idx in mapping.items():
                # Skip out-of-range indices from a malformed mapping
                if src_idx >= len(source_palette) or tgt_idx >= len(target_palette):
                    continue
                src_color = source_palette[src_idx]
                tgt_color = target_palette[tgt_idx]
                # Skip if colors are very similar (Manhattan distance in RGB)
                color_diff = sum(abs(s - t) for s, t in zip(src_color, tgt_color))
                if color_diff < 30:
                    continue
                # Generate mask from K-means cluster (your algorithm's context)
                mask = generate_color_mask(result, source_palette, src_idx)
                mask = mask.resize((512, 512))
                # NOTE(review): `mask` is computed but never passed to the
                # pipeline below — InstructPix2Pix takes no mask input, so
                # localization relies solely on the text instruction. Confirm
                # whether an inpainting pipeline was intended here.
                # Build natural language instruction
                src_desc = describe_color(src_color)
                tgt_desc = describe_color(tgt_color)
                tgt_hex = ColorPalette.rgb_to_hex(tgt_color)
                instruction = f"Change the {src_desc} areas to {tgt_desc} ({tgt_hex}), preserve lighting and texture"
                print(f"AI Instruction: {instruction}")
                # Apply AI edit driven by the text instruction
                edited = pix2pix_pipe(
                    instruction,
                    image=result,
                    num_inference_steps=int(steps),
                    image_guidance_scale=1.5,
                    guidance_scale=7.5,
                ).images[0]
                result = edited
            # Resize back to original
            result = result.resize(original_size, Image.Resampling.LANCZOS)
            return result
        except Exception as e:
            # Broad catch is deliberate: any failure returns None so the
            # caller can fall back to the basic pixel-level swap.
            print(f"AI Edit error: {e}")
            import traceback
            traceback.print_exc()
            return None
else:
    def ai_apply_swap(image, source_palette, target_palette, mapping, steps=20):
        """Fallback for local - returns None so basic swap is used"""
        return None
# ---------------------------------------------------------------------------
# UI: single-page palette swapper.
# Left column: source image upload (+ optional AI generation on HF Spaces)
# and its extracted palette. Right column: target palette generation,
# source->target color mapping (max 5 dropdowns, matching
# extract_palette_from_image), and the swapped result.
# Fix: repaired mis-encoded emoji/arrow/bullet characters in UI strings.
# ---------------------------------------------------------------------------
with gr.Blocks(title="AI Image Editor", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎨 AI Image Editor - Palette Swapper")
    gr.Markdown("Upload an image, extract its color palette, then swap colors using color theory or a reference image.")
    with gr.Row():
        # LEFT COLUMN - Source Image
        with gr.Column(scale=1):
            gr.Markdown("### 📷 Source Image")
            source_image = gr.Image(label="Upload Image", type="pil", height=300)
            # AI Generation (collapsible, only shown on HF)
            if HF_SPACE:
                with gr.Accordion("✨ Or Generate with AI", open=False):
                    ai_prompt = gr.Textbox(label="Prompt", placeholder="A fantasy landscape...")
                    ai_negative = gr.Textbox(label="Negative Prompt", value="blurry, bad quality")
                    with gr.Row():
                        ai_steps = gr.Slider(10, 50, value=25, step=1, label="Steps")
                        ai_guidance = gr.Slider(1, 15, value=7.5, step=0.5, label="Guidance")
                    ai_seed = gr.Number(label="Seed (-1 = random)", value=-1)
                    ai_generate_btn = gr.Button("🎨 Generate", variant="secondary")
            n_colors = gr.Slider(2, 8, value=5, step=1, label="Number of Colors to Extract")
            gr.Markdown("### 🎨 Extracted Palette")
            source_palette_img = gr.Image(label="Source Palette", height=80, interactive=False)
            source_palette_desc = gr.Textbox(label="Colors", lines=5, interactive=False)
        # RIGHT COLUMN - Target & Result
        with gr.Column(scale=1):
            gr.Markdown("### 🎯 Target Palette")
            target_method = gr.Dropdown(
                choices=[
                    "Color Harmony - Complementary",
                    "Color Harmony - Analogous",
                    "Color Harmony - Triadic",
                    "Color Harmony - Split-Complementary",
                    "Color Harmony - Tetradic",
                    "Color Harmony - Monochromatic",
                    "From Reference Image"
                ],
                value="Color Harmony - Complementary",
                label="Generate Target From"
            )
            with gr.Row():
                base_color = gr.ColorPicker(label="Base Color", value="#3498db")
                ref_image = gr.Image(label="Reference Image", type="pil", height=100)
            generate_palette_btn = gr.Button("Generate Target Palette", variant="secondary")
            target_palette_img = gr.Image(label="Target Palette", height=80, interactive=False)
            target_palette_desc = gr.Textbox(label="Target Colors", lines=3, interactive=False)
            gr.Markdown("### 🔄 Color Mapping")
            gr.Markdown("*Map each source color to a target color:*")
            with gr.Row():
                map_dropdown_1 = gr.Dropdown(label="Source 1 →", visible=False)
                map_dropdown_2 = gr.Dropdown(label="Source 2 →", visible=False)
            with gr.Row():
                map_dropdown_3 = gr.Dropdown(label="Source 3 →", visible=False)
                map_dropdown_4 = gr.Dropdown(label="Source 4 →", visible=False)
            with gr.Row():
                map_dropdown_5 = gr.Dropdown(label="Source 5 →", visible=False)
            # AI toggle (only shown on HF Space)
            if HF_SPACE:
                use_ai_checkbox = gr.Checkbox(
                    label="🤖 Use AI-Enhanced Swap",
                    value=True,
                    info="AI uses your K-means palette as context to recolor naturally, preserving lighting and textures"
                )
            else:
                # Hidden placeholder so apply_btn.click has a consistent input list
                use_ai_checkbox = gr.Checkbox(label="AI (HF Space only)", value=False, visible=False)
            apply_btn = gr.Button("🔄 Apply Palette Swap", variant="primary", size="lg")
            gr.Markdown("### ✅ Result")
            result_image = gr.Image(label="Result", height=300)
    # Event handlers
    # Re-extract the palette whenever the image or the color count changes
    source_image.change(
        fn=extract_palette_from_image,
        inputs=[source_image, n_colors],
        outputs=[source_palette_img, source_palette_desc,
                 map_dropdown_1, map_dropdown_2, map_dropdown_3,
                 map_dropdown_4, map_dropdown_5]
    )
    n_colors.change(
        fn=extract_palette_from_image,
        inputs=[source_image, n_colors],
        outputs=[source_palette_img, source_palette_desc,
                 map_dropdown_1, map_dropdown_2, map_dropdown_3,
                 map_dropdown_4, map_dropdown_5]
    )
    generate_palette_btn.click(
        fn=generate_target_palette,
        inputs=[target_method, base_color, ref_image, n_colors],
        outputs=[target_palette_img, target_palette_desc]
    )
    apply_btn.click(
        fn=apply_swap,
        inputs=[source_image, target_method, base_color, ref_image, n_colors,
                map_dropdown_1, map_dropdown_2, map_dropdown_3,
                map_dropdown_4, map_dropdown_5, use_ai_checkbox],
        outputs=[result_image]
    )
    gr.Markdown("---")
    if HF_SPACE:
        gr.Markdown("*Built with Gradio • K-means palette extraction • AI-enhanced recoloring with InstructPix2Pix*")
    else:
        gr.Markdown("*Built with Gradio • Color extraction via K-means clustering • Color theory harmonies*")

if __name__ == "__main__":
    demo.launch()