import gradio as gr
import numpy as np
import random
import torch
import spaces
from PIL import Image
from diffusers import QwenImageEditPipeline
from diffusers.utils import is_xformers_available
import os
import re
import gc
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
#############################
os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
os.environ.setdefault('HF_HUB_DISABLE_TELEMETRY', '1')

# Model configuration
REWRITER_MODEL = "Qwen/Qwen1.5-1.8B-Chat"
rewriter_tokenizer = None
rewriter_model = None
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Quantization configuration
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True
)
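
# Note: the 4-bit NF4 config above (with double quantization and bf16 compute)
# is what load_rewriter() passes to from_pretrained, so the 1.8B rewriter is
# held in a fraction of its bf16 weight footprint and leaves most of the GPU
# memory to the image-editing pipeline.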
def load_rewriter():
    """Lazily load the prompt enhancement model"""
    global rewriter_tokenizer, rewriter_model
    if rewriter_tokenizer is None or rewriter_model is None:
        print("🔄 Loading enhancement model...")
        rewriter_tokenizer = AutoTokenizer.from_pretrained(REWRITER_MODEL)
        rewriter_model = AutoModelForCausalLM.from_pretrained(
            REWRITER_MODEL,
            torch_dtype=dtype,
            device_map="auto",
            quantization_config=bnb_config
        )
        print("✅ Enhancement model loaded")
SYSTEM_PROMPT_EDIT = '''
# Edit Instruction Rewriter
You are a professional edit instruction rewriter. Your task is to generate a precise, concise, and visually achievable instruction based on the user's intent and the input image.

## 1. General Principles
- Keep the rewritten instruction **concise** and clear.
- Avoid contradictions, vagueness, or unachievable instructions.
- Maintain the core logic of the original instruction; only enhance clarity and feasibility.
- Ensure newly added elements or modifications align with the image's original context and art style.

## 2. Task Types
### Add, Delete, Replace:
- When the input is detailed, only refine grammar and clarity.
- For vague instructions, infer minimal but sufficient details.
- For replacement, use the format: `"Replace X with Y"`.

### Text Editing (e.g., text replacement):
- Enclose text content in quotes, e.g., `Replace "abc" with "xyz"`.
- Preserve the original structure and language—**do not translate** or alter style.

### Human Editing (e.g., change a person’s face/hair):
- Preserve core visual identity (gender, ethnic features).
- Describe expressions in subtle and natural terms.
- Maintain key clothing or styling details unless explicitly replaced.

### Style Transformation:
- If a style is specified, e.g., `Disco style`, rewrite it to encapsulate the essential visual traits.
- Use a fixed template for **coloring/restoration**:
  `"Restore old photograph, remove scratches, reduce noise, enhance details, high resolution, realistic, natural skin tones, clear facial features, no distortion, vintage photo restoration"`
  if applicable.

## 3. Output Format
Please provide the rewritten instruction in a clean `json` format as:
{
    "Rewritten": "..."
}
'''
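
# The rewriter is instructed to answer with a single JSON object of the form
# {"Rewritten": "..."}; polish_prompt() below does a best-effort extraction of
# that field before falling back to simpler text cleanup.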
def polish_prompt(original_prompt: str) -> str:
    """Enhanced prompt rewriting using Qwen1.5-1.8B"""
    load_rewriter()

    # Format as Qwen chat with system prompt
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT_EDIT},
        {"role": "user", "content": original_prompt}
    ]

    # Generate enhanced prompt
    text = rewriter_tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = rewriter_tokenizer(text, return_tensors="pt").to(device)

    with torch.no_grad():
        generated_ids = rewriter_model.generate(
            **model_inputs,
            max_new_tokens=120,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            no_repeat_ngram_size=2
        )

    # Decode only the newly generated portion of the response
    enhanced = rewriter_tokenizer.decode(
        generated_ids[0][model_inputs.input_ids.shape[1]:],
        skip_special_tokens=True
    ).strip()

    # Best-effort extraction of the "Rewritten" field requested by the system prompt
    match = re.search(r'"Rewritten"\s*:\s*"(.*)"', enhanced, re.S)
    if match:
        enhanced = match.group(1).strip()

    # Clean possible remaining artifacts
    if enhanced.lower().startswith(("rewritten instruction:", "enhanced:", "output:")):
        enhanced = enhanced.split(":", 1)[-1].strip()

    # Remove any quotes around the prompt if present
    if enhanced.startswith('"') and enhanced.endswith('"'):
        enhanced = enhanced[1:-1]
    return enhanced
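
# Illustrative only: polish_prompt("make the sky nicer") might return something
# like 'Replace the overcast sky with a clear blue sky at sunset'; the actual
# output varies between calls because sampling (do_sample=True) is enabled.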
# Load main image editing pipeline
pipe = QwenImageEditPipeline.from_pretrained(
    "Qwen/Qwen-Image-Edit",
    torch_dtype=dtype
).to(device)

# Load LoRA weights for acceleration
pipe.load_lora_weights(
    "lightx2v/Qwen-Image-Lightning",
    weight_name="Qwen-Image-Lightning-8steps-V1.1.safetensors"
)
pipe.fuse_lora()

if is_xformers_available():
    pipe.enable_xformers_memory_efficient_attention()
else:
    print("xformers not available")
def unload_rewriter():
    """Clear enhancement model from memory"""
    global rewriter_tokenizer, rewriter_model
    if rewriter_model is not None:
        del rewriter_tokenizer, rewriter_model
        rewriter_tokenizer = None
        rewriter_model = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
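
# Memory strategy: the rewriter is loaded lazily inside polish_prompt() and
# dropped again right after enhancement (see the unload_rewriter() call in
# infer), so the diffusion pipeline has the GPU largely to itself.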
# ZeroGPU: request a GPU for the duration of each call (this Space runs on
# ZeroGPU hardware, which is why `spaces` is imported above).
@spaces.GPU
def infer(
    image,
    prompt,
    seed=42,
    randomize_seed=False,
    true_guidance_scale=4.0,
    num_inference_steps=8,
    rewrite_prompt=False,
    num_images_per_prompt=1,
):
| """Image editing endpoint with optimized prompt handling""" | |
| original_prompt = prompt | |
| prompt_info = "" | |
| # Handle prompt rewriting | |
| if rewrite_prompt: | |
| try: | |
| enhanced_instruction = polish_prompt(original_prompt) | |
| prompt_info = ( | |
| f"<div style='margin:10px; padding:10px; border-radius:8px; border-left:4px solid #4CAF50; background: #f5f9fe'>" | |
| f"<h4 style='margin-top: 0;'>🚀 Prompt Enhancement</h4>" | |
| f"<p><strong>Original:</strong> {original_prompt}</p>" | |
| f"<p><strong>Enhanced:</strong> {enhanced_instruction}</p>" | |
| f"</div>" | |
| ) | |
| prompt = enhanced_instruction | |
| except Exception as e: | |
| gr.Warning(f"Prompt enhancement failed: {str(e)}") | |
| prompt_info = ( | |
| f"<div style='margin:10px; padding:10px; border-radius:8px; border-left:4px solid #FF5252; background: #fef5f5'>" | |
| f"<h4 style='margin-top: 0;'>⚠️ Enhancement Not Applied</h4>" | |
| f"<p>Using original prompt. Error: {str(e)}</p>" | |
| f"</div>" | |
| ) | |
| else: | |
| prompt_info = ( | |
| f"<div style='margin:10px; padding:10px; border-radius:8px; background: #f8f9fa'>" | |
| f"<h4 style='margin-top: 0;'>📝 Original Prompt</h4>" | |
| f"<p>{original_prompt}</p>" | |
| f"</div>" | |
| ) | |
| # Free VRAM after enhancement | |
| unload_rewriter() | |
| # Set seed for reproducibility | |
| seed_val = seed | |
| if randomize_seed: | |
| seed_val = random.randint(0, 2**32 - 1) | |
| generator = torch.Generator(device=device).manual_seed(seed_val) | |
    try:
        # Generate images
        edited_images = pipe(
            image=image,
            prompt=prompt,
            negative_prompt=" ",
            num_inference_steps=num_inference_steps,
            generator=generator,
            true_cfg_scale=true_guidance_scale,
            num_images_per_prompt=num_images_per_prompt
        ).images
    except Exception as e:
        # gr.Error is an exception class; show a warning toast instead of
        # constructing an exception that is never raised
        gr.Warning(f"Image generation failed: {str(e)}")
        prompt_info = (
            f"<div style='margin:10px; padding:10px; border-radius:8px; border-left:4px solid #dd2c00; background: #fef5f5'>"
            f"<h4 style='margin-top: 0;'><strong>⚠️ Error:</strong> {str(e)}</h4>"
            f"</div>"
        )
        return [], seed_val, prompt_info

    return edited_images, seed_val, prompt_info
MAX_SEED = np.iinfo(np.int32).max

examples = [
    "Replace the cat with a friendly golden retriever. Make it look happier, and add more background details.",
    "Add text 'Qwen - AI for image editing' in Chinese at the bottom center with a small shadow.",
    "Change the style to 1970s vintage, add old photo effect, restore any scratches on the wall or window.",
    "Remove the blue sky and replace it with a dark night cityscape.",
    """Replace "Qwen" with "通义" in the Image. Ensure Chinese font is used and position it at top left.""",
]
with gr.Blocks(title="Qwen Image Editor", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    <div style="text-align: center;">
        <h1>⚡️ Qwen-Image-Edit Lightning</h1>
        <p>8-step image editing with local prompt enhancement | Powered by NVIDIA H200</p>
    </div>
    """)

    with gr.Row():
        # Input Column
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")
            prompt = gr.Textbox(
                label="Edit Instruction",
                placeholder="e.g. Add a dog to the right side",
                lines=2
            )
            with gr.Accordion("Advanced Settings", open=False):
                gr.Markdown("### Generation Parameters")
                with gr.Row():
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
                    randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                with gr.Row():
                    true_guidance_scale = gr.Slider(
                        label="Guidance Scale", minimum=1.0, maximum=5.0, step=0.1, value=4.0
                    )
                    num_inference_steps = gr.Slider(
                        label="Inference Steps", minimum=4, maximum=16, step=1, value=8
                    )
                    num_images_per_prompt = gr.Slider(
                        label="Output Images", minimum=1, maximum=4, step=1, value=1
                    )
                rewrite_toggle = gr.Checkbox(
                    label="Enable AI Prompt Enhancement",
                    value=True,
                    info="Uses local Qwen1.5-1.8B model to improve your instructions"
                )

            run_button = gr.Button("Generate Edits", variant="primary")
        # Output Column
        with gr.Column():
            result = gr.Gallery(
                label="Output Images",
                columns=2,  # gr.Gallery expects an int (or list of ints), not a callable
                object_fit="contain",
                height="auto"
            )
            prompt_info = gr.HTML(
                "<div style='margin-top:20px; padding:15px; border-radius:8px; background:#f8f9fa'>"
                "<p>Prompt details will appear here after generation</p></div>"
            )
    gr.Examples(
        examples=examples,
        inputs=[prompt],
        label="Try These Examples",
        cache_examples=False  # caching would require fn/outputs and a GPU run at startup
    )
    # Main processing handler
    inputs = [
        input_image,
        prompt,
        seed,
        randomize_seed,
        true_guidance_scale,
        num_inference_steps,
        rewrite_toggle,
        num_images_per_prompt
    ]
    outputs = [result, seed, prompt_info]

    run_button.click(
        fn=infer,
        inputs=inputs,
        outputs=outputs
    )
    prompt.submit(
        fn=infer,
        inputs=inputs,
        outputs=outputs
    )

if __name__ == "__main__":
    demo.launch()