"""Gradio app: Stable Diffusion XL image generator with preset styles.

Exposes a small web UI (prompt, style dropdown, LoRA checkbox) that renders
one image per request with ``stabilityai/stable-diffusion-xl-base-1.0`` and,
optionally, the ColoringBookRedmond LoRA. The generation function is wrapped
in ``spaces.GPU`` so it can run on Hugging Face ZeroGPU hardware.
"""

import torch
import gradio as gr
import spaces
from functools import lru_cache
from diffusers import StableDiffusionXLPipeline

# ===============================
# 🩹 FIX for Gradio bug (bool schema issue)
# ===============================
import gradio_client.utils as gu

# Monkey patch for "TypeError: argument of type 'bool' is not iterable".
# The marker attribute makes the patch idempotent, so re-importing or
# reloading this module never wraps the already-wrapped function again.
if not hasattr(gu, "_patched_json_schema_to_python_type"):
    orig_get_type = gu.get_type

    def safe_get_type(schema):
        """Return a type name for *schema*, tolerating non-dict schemas.

        Non-dict values (for example a bare ``True``/``False``) are returned
        as their string form instead of being handed to the original
        ``get_type``, which is the call that raised the TypeError above.
        """
        if not isinstance(schema, dict):
            return str(schema)
        return orig_get_type(schema)

    gu.get_type = safe_get_type
    gu._patched_json_schema_to_python_type = True

# ===============================
# 🎨 Model and Styles Configuration
# ===============================
color_book_lora_path = "artificialguybr/ColoringBookRedmond-V2"
# Appended verbatim to the user prompt when the LoRA is enabled.
color_book_trigger = ", ColoringBookAF, Coloring Book"

# Style name -> extra positive prompt text and extra negative prompt text.
# The keys double as the choices shown in the style dropdown.
styles = {
    "Neonpunk": {
        "prompt": "neonpunk style, cyberpunk, vaporwave, neon, vibrant, stunningly beautiful, crisp, "
                  "detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic",
        "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured"
    },
    "Retro Cyberpunk": {
        "prompt": "retro cyberpunk, 80's inspired, synthwave, neon, vibrant, detailed, retro futurism",
        "negative_prompt": "modern, desaturated, black and white, realism, low contrast"
    },
    "Dark Fantasy": {
        "prompt": "Dark Fantasy Art, dark, moody, dark fantasy style",
        "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, bright, sunny"
    },
    "Double Exposure": {
        "prompt": "Double Exposure Style, double image ghost effect, image combination, double exposure style",
        "negative_prompt": "ugly, deformed, noisy, blurry, low contrast"
    },
    "None": {
        "prompt": "8K",
        "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured"
    }
}


# ===============================
# 🚀 Pipeline Loader (with caching)
# ===============================
@lru_cache(maxsize=1)
def load_pipeline(use_lora: bool):
    """Load Stable Diffusion XL pipeline and optionally apply LoRA weights.

    Args:
        use_lora: When true, the coloring-book LoRA weights are loaded into
            the pipeline after the base model.

    Returns:
        The pipeline, placed on the CPU. Half precision is used when CUDA is
        available at load time, full precision otherwise.

    Note:
        The cache holds a single entry, so only the most recently requested
        variant (with or without LoRA) is kept; switching variants reloads.
    """
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        use_safetensors=True
    )
    pipe.to("cpu")
    if use_lora:
        pipe.load_lora_weights(color_book_lora_path)
    return pipe


# ===============================
# 🎨 Image Generation Function
# ===============================
@spaces.GPU  # ZeroGPU: allocate GPU only when generating
def generate_image(prompt: str, style_name: str, use_lora: bool):
    """Generate an image using Stable Diffusion XL with optional LoRA fine-tuning.

    Args:
        prompt: Free-text description of the image.
        style_name: Key into ``styles``; an unknown name adds no style text.
        use_lora: Whether to use the LoRA pipeline and append its trigger
            words to the prompt.

    Returns:
        The first image produced by the pipeline.
    """
    # Load cached pipeline
    pipeline = load_pipeline(use_lora)

    # Retrieve style info (single lookup; unknown styles fall back to "")
    style = styles.get(style_name, {})
    style_prompt = style.get("prompt", "")
    negative_prompt = style.get("negative_prompt", "")

    # Add LoRA trigger if needed
    if use_lora:
        prompt += color_book_trigger

    # Mirror the availability check used for the dtype in load_pipeline so a
    # machine without CUDA runs on the CPU instead of failing in .to("cuda").
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipeline.to(device)
    try:
        # Generate image
        result = pipeline(
            prompt=prompt + " " + style_prompt,
            negative_prompt="blurred, ugly, watermark, low resolution, " + negative_prompt,
            num_inference_steps=20,
            guidance_scale=9.0
        )
    finally:
        # Move model back to CPU to release GPU, even when generation fails;
        # the pipeline object is cached and reused by later requests.
        pipeline.to("cpu")

    return result.images[0]


# ===============================
# 🌐 Gradio Interface (for Spaces)
# ===============================
interface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(label="Enter Your Prompt", placeholder="A cute lion"),
        gr.Dropdown(label="Select a Style", choices=list(styles.keys()), value="None"),
        gr.Checkbox(label="Use Coloring Book LoRA", value=False)
    ],
    outputs=gr.Image(label="Generated Image"),
    title="🎨 AI Coloring Book & Style Generator",
    description=(
        "Generate AI-powered art using Stable Diffusion XL on Hugging Face Spaces. "
        "Choose a style or enable a LoRA fine-tuned coloring book effect. "
        "This app dynamically allocates GPU (ZeroGPU) only during generation."
    )
)

# ===============================
# 🏁 Launch App
# ===============================
if __name__ == "__main__":
    interface.launch()