# app_local.py — Qwen-Image-Edit Lightning Gradio demo
# (Hugging Face upload by LPX55, commit 94c44d0 verified, 11.4 kB)
import gc
import json
import os
import random
import re

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from diffusers import QwenImageEditPipeline
from diffusers.utils import is_xformers_available
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
#############################
# Opt out of Gradio analytics and HF Hub telemetry before any network calls.
os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
os.environ.setdefault('HF_HUB_DISABLE_TELEMETRY', '1')
# Model configuration
REWRITER_MODEL = "Qwen/Qwen1.5-1.8B-Chat"  # small chat model used only for prompt rewriting
rewriter_tokenizer = None  # lazily populated by load_rewriter()
rewriter_model = None      # lazily populated by load_rewriter()
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
# Quantization configuration
# 4-bit NF4 with double quantization keeps the rewriter's VRAM footprint
# small; compute still runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True
)
def load_rewriter():
    """Load the prompt-enhancement tokenizer/model on first use.

    No-op when both are already resident; mutates the module-level
    rewriter_tokenizer / rewriter_model globals.
    """
    global rewriter_tokenizer, rewriter_model
    # Guard clause: nothing to do once both pieces are loaded.
    if rewriter_tokenizer is not None and rewriter_model is not None:
        return
    print("🔄 Loading enhancement model...")
    rewriter_tokenizer = AutoTokenizer.from_pretrained(REWRITER_MODEL)
    rewriter_model = AutoModelForCausalLM.from_pretrained(
        REWRITER_MODEL,
        torch_dtype=dtype,
        device_map="auto",
        quantization_config=bnb_config,
    )
    print("✅ Enhancement model loaded")
# System prompt steering the rewriter LLM. It instructs the model to reply
# with a JSON object of the form {"Rewritten": "..."}. NOTE(review): the
# section numbering inside jumps from "## 2" to "## 4" — left untouched
# because this text is runtime behavior, not a comment.
SYSTEM_PROMPT_EDIT = '''
# Edit Instruction Rewriter
You are a professional edit instruction rewriter. Your task is to generate a precise, concise, and visually achievable instruction based on the user's intent and the input image.
## 1. General Principles
- Keep the rewritten instruction **concise** and clear.
- Avoid contradictions, vagueness, or unachievable instructions.
- Maintain the core logic of the original instruction; only enhance clarity and feasibility.
- Ensure new added elements or modifications align with the image's original context and art style.
## 2. Task Types
### Add, Delete, Replace:
- When the input is detailed, only refine grammar and clarity.
- For vague instructions, infer minimal but sufficient details.
- For replacement, use the format: `"Replace X with Y"`.
### Text Editing (e.g., text replacement):
- Enclose text content in quotes, e.g., `Replace "abc" with "xyz"`.
- Preserving the original structure and language—**do not translate** or alter style.
### Human Editing (e.g., change a person’s face/hair):
- Preserve core visual identity (gender, ethnic features).
- Describe expressions in subtle and natural terms.
- Maintain key clothing or styling details unless explicitly replaced.
### Style Transformation:
- If a style is specified, e.g., `Disco style`, rewrite it to encapsulate the essential visual traits.
- Use a fixed template for **coloring/restoration**:
`"Restore old photograph, remove scratches, reduce noise, enhance details, high resolution, realistic, natural skin tones, clear facial features, no distortion, vintage photo restoration"`
if applicable.
## 4. Output Format
Please provide the rewritten instruction in a clean `json` format as:
{
"Rewritten": "..."
}
'''
def polish_prompt(original_prompt: str) -> str:
"""Enhanced prompt rewriting using Qwen1.5-1.8B"""
load_rewriter()
# Format as Qwen chat with system prompt
messages = [
{"role": "system", "content": SYSTEM_PROMPT_EDIT},
{"role": "user", "content": original_prompt}
]
# Generate enhanced prompt
text = rewriter_tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
model_inputs = rewriter_tokenizer(text, return_tensors="pt").to(device)
with torch.no_grad():
generated_ids = rewriter_model.generate(
**model_inputs,
max_new_tokens=120,
do_sample=True,
temperature=0.7,
top_p=0.95,
no_repeat_ngram_size=2
)
# Extract and clean response
enhanced = rewriter_tokenizer.decode(
generated_ids[0][model_inputs.input_ids.shape[1]:],
skip_special_tokens=True
)
# Clean possible artifacts
enhanced = enhanced.strip()
if enhanced.lower().startswith(("rewritten instruction:", "enhanced:", "output:")):
enhanced = re.split(r':', enhanced, 1)[-1].strip()
# Remove any quotes around the prompt if present
if enhanced.startswith('"') and enhanced.endswith('"'):
enhanced = enhanced[1:-1]
return enhanced
# Load main image editing pipeline
# Loaded eagerly at import time; requires VRAM for the full Qwen-Image-Edit
# model in bfloat16 (falls back to CPU when CUDA is unavailable).
pipe = QwenImageEditPipeline.from_pretrained(
    "Qwen/Qwen-Image-Edit",
    torch_dtype=dtype
).to(device)
# Load LoRA weights for acceleration
# Lightning LoRA distills sampling down to ~8 denoising steps.
pipe.load_lora_weights(
    "lightx2v/Qwen-Image-Lightning",
    weight_name="Qwen-Image-Lightning-8steps-V1.1.safetensors"
)
# Bake the LoRA into the base weights so inference pays no adapter overhead.
pipe.fuse_lora()
if is_xformers_available():
    pipe.enable_xformers_memory_efficient_attention()
else:
    print("xformers not available")
def unload_rewriter():
    """Release the prompt-enhancement model/tokenizer and reclaim GPU memory.

    Safe to call repeatedly or before load_rewriter() has ever run. The
    original guarded on `if rewriter_model:`, which leaked the tokenizer
    when only it was loaded; rebinding both to None unconditionally avoids
    that and drops our references either way.
    """
    global rewriter_tokenizer, rewriter_model
    rewriter_tokenizer = None
    rewriter_model = None
    # Collect Python garbage *before* asking CUDA to release cached blocks,
    # otherwise tensors kept alive by collectable cycles stay resident.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
@spaces.GPU(duration=60)
def infer(
    image,
    prompt,
    seed=42,
    randomize_seed=False,
    true_guidance_scale=4.0,
    num_inference_steps=8,
    rewrite_prompt=False,
    num_images_per_prompt=1,
):
    """Run one image-editing request.

    Args:
        image: PIL input image to edit.
        prompt: The user's edit instruction.
        seed: RNG seed used when randomize_seed is False.
        randomize_seed: If True, draw a fresh seed in [0, 2**31 - 1] —
            the int32 range the UI's seed slider accepts.
        true_guidance_scale: True-CFG scale passed to the pipeline.
        num_inference_steps: Denoising steps (8 matches the Lightning LoRA).
        rewrite_prompt: If True, enhance the prompt with the local rewriter.
        num_images_per_prompt: Number of edited variants to generate.

    Returns:
        (images, seed_used, html_info): gallery images (empty list on
        failure), the seed actually used, and an HTML status panel.
    """
    original_prompt = prompt
    prompt_info = ""
    # Handle prompt rewriting
    if rewrite_prompt:
        try:
            enhanced_instruction = polish_prompt(original_prompt)
            prompt_info = (
                f"<div style='margin:10px; padding:10px; border-radius:8px; border-left:4px solid #4CAF50; background: #f5f9fe'>"
                f"<h4 style='margin-top: 0;'>🚀 Prompt Enhancement</h4>"
                f"<p><strong>Original:</strong> {original_prompt}</p>"
                f"<p><strong>Enhanced:</strong> {enhanced_instruction}</p>"
                f"</div>"
            )
            prompt = enhanced_instruction
        except Exception as e:
            # Enhancement is best-effort: warn and continue with the original.
            gr.Warning(f"Prompt enhancement failed: {str(e)}")
            prompt_info = (
                f"<div style='margin:10px; padding:10px; border-radius:8px; border-left:4px solid #FF5252; background: #fef5f5'>"
                f"<h4 style='margin-top: 0;'>⚠️ Enhancement Not Applied</h4>"
                f"<p>Using original prompt. Error: {str(e)}</p>"
                f"</div>"
            )
    else:
        prompt_info = (
            f"<div style='margin:10px; padding:10px; border-radius:8px; background: #f8f9fa'>"
            f"<h4 style='margin-top: 0;'>📝 Original Prompt</h4>"
            f"<p>{original_prompt}</p>"
            f"</div>"
        )
    # Free VRAM after enhancement
    unload_rewriter()
    # Set seed for reproducibility. Draw from [0, 2**31 - 1] so the returned
    # seed always fits the UI slider's int32 maximum (the original used
    # 2**32 - 1, which could exceed the slider's range).
    seed_val = seed
    if randomize_seed:
        seed_val = random.randint(0, 2**31 - 1)
    generator = torch.Generator(device=device).manual_seed(seed_val)
    try:
        # Generate images
        edited_images = pipe(
            image=image,
            prompt=prompt,
            negative_prompt=" ",
            num_inference_steps=num_inference_steps,
            generator=generator,
            true_cfg_scale=true_guidance_scale,
            num_images_per_prompt=num_images_per_prompt
        ).images
    except Exception as e:
        # gr.Error is an exception class: *calling* it without raising is a
        # no-op, so surface the failure via gr.Warning and return gracefully
        # (preserving the original empty-gallery return contract).
        gr.Warning(f"Image generation failed: {str(e)}")
        prompt_info = (
            f"<div style='margin:10px; padding:10px; border-radius:8px; border-left:4px solid #dd2c00; background: #fef5f5'>"
            f"<h4 style='margin-top: 0;'><strong>⚠️ Error:</strong> {str(e)}</h4>"
            f"</div>"
        )
        return [], seed_val, prompt_info
    return edited_images, seed_val, prompt_info
# Largest value the seed slider accepts (int32 max = 2**31 - 1).
MAX_SEED = np.iinfo(np.int32).max
# Example edit instructions shown beneath the prompt box.
examples = [
    "Replace the cat with a friendly golden retriever. Make it look happier, and add more background details.",
    "Add text 'Qwen - AI for image editing' in Chinese at the bottom center with a small shadow.",
    "Change the style to 1970s vintage, add old photo effect, restore any scratches on the wall or window.",
    "Remove the blue sky and replace it with a dark night cityscape.",
    """Replace "Qwen" with "通义" in the Image. Ensure Chinese font is used and position it at top left."""
]
# Build the Gradio UI. Fixes vs. original: gr.Slider takes minimum/maximum
# (min/max are invalid kwargs and raise TypeError); gr.Gallery's `columns`
# expects an int or per-breakpoint list, not a callable; gr.Examples cannot
# cache without fn/outputs, so caching is disabled.
with gr.Blocks(title="Qwen Image Editor", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
<div style="text-align: center;">
<h1>⚡️ Qwen-Image-Edit Lightning</h1>
<p>8-step image editing with local prompt enhancement | Powered by NVIDIA H200</p>
</div>
""")
    with gr.Row():
        # Input Column
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")
            prompt = gr.Textbox(label="Edit Instruction", placeholder="e.g. Add a dog to the right side", lines=2)
            with gr.Accordion("Advanced Settings", open=False):
                gr.Markdown("### Generation Parameters")
                with gr.Row():
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
                    randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                with gr.Row():
                    true_guidance_scale = gr.Slider(
                        label="Guidance Scale", minimum=1.0, maximum=5.0, step=0.1, value=4.0
                    )
                    num_inference_steps = gr.Slider(
                        label="Inference Steps", minimum=4, maximum=16, step=1, value=8
                    )
                    num_images_per_prompt = gr.Slider(
                        label="Output Images", minimum=1, maximum=4, step=1, value=1
                    )
                rewrite_toggle = gr.Checkbox(
                    label="Enable AI Prompt Enhancement",
                    value=True,
                    info="Uses local Qwen1.5-1.8B model to improve your instructions"
                )
            run_button = gr.Button("Generate Edits", variant="primary")
        # Output Column
        with gr.Column():
            result = gr.Gallery(
                label="Output Images",
                columns=2,  # int, not a callable — a lambda here fails at render time
                object_fit="contain",
                height="auto"
            )
            prompt_info = gr.HTML(
                "<div style='margin-top:20px; padding:15px; border-radius:8px; background:#f8f9fa'>"
                "<p>Prompt details will appear here after generation</p></div>"
            )
    gr.Examples(
        examples=examples,
        inputs=[prompt],
        label="Try These Examples",
        # Caching requires fn and outputs; none are wired here.
        cache_examples=False
    )
    # Main processing handler — order must match infer()'s signature.
    inputs = [
        input_image,
        prompt,
        seed,
        randomize_seed,
        true_guidance_scale,
        num_inference_steps,
        rewrite_toggle,
        num_images_per_prompt
    ]
    outputs = [result, seed, prompt_info]
    run_button.click(
        fn=infer,
        inputs=inputs,
        outputs=outputs
    )
    prompt.submit(
        fn=infer,
        inputs=inputs,
        outputs=outputs
    )

if __name__ == "__main__":
    demo.launch()