Spaces:
Running
on
Zero
Running
on
Zero
File size: 13,934 Bytes
94c44d0 2e9b71a 94c44d0 2e9b71a 94c44d0 2e9b71a 94c44d0 2e9b71a 94c44d0 2e9b71a 94c44d0 2e9b71a 94c44d0 2e9b71a 94c44d0 2e9b71a 94c44d0 2e9b71a 94c44d0 2e9b71a 94c44d0 2e9b71a 94c44d0 e8c722f 94c44d0 e8c722f 94c44d0 af36188 94c44d0 af36188 94c44d0 af36188 94c44d0 af36188 94c44d0 2e9b71a 94c44d0 b31e000 94c44d0 09d4d06 94c44d0 09d4d06 94c44d0 09d4d06 94c44d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 |
import gc
import html
import json
import os
import random
import re

import gradio as gr
import numpy as np
import torch
import spaces
from PIL import Image
from diffusers import QwenImageEditPipeline
from diffusers.utils import is_xformers_available
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# ---------------------------------------------------------------------------
# Runtime configuration
# ---------------------------------------------------------------------------
# setdefault() only fills these in when the environment has not set them.
os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "False")
os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")

# Prompt-rewriter checkpoint. The tokenizer/model globals stay None until
# load_rewriter() populates them.
REWRITER_MODEL = "Qwen/Qwen1.5-7B-Chat"
rewriter_tokenizer = None
rewriter_model = None

# Target device and the dtype shared by the rewriter and the edit pipeline.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16

# 4-bit NF4 quantization settings passed to the rewriter model on load.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=dtype,
)
def load_rewriter():
    """Populate the module-level rewriter tokenizer and model on first use.

    Does nothing when both globals are already set; otherwise (re)loads both
    from ``REWRITER_MODEL`` using the shared dtype and quantization config.
    """
    global rewriter_tokenizer, rewriter_model

    already_loaded = rewriter_tokenizer is not None and rewriter_model is not None
    if already_loaded:
        return

    print("🔄 Loading enhancement model...")
    rewriter_tokenizer = AutoTokenizer.from_pretrained(REWRITER_MODEL)
    model_options = {
        "torch_dtype": dtype,
        "device_map": "auto",
        "quantization_config": bnb_config,
    }
    rewriter_model = AutoModelForCausalLM.from_pretrained(REWRITER_MODEL, **model_options)
    print("✅ Enhancement model loaded")
# System message sent to the rewriter model by polish_prompt(). It asks the
# model to answer with a JSON object whose "Rewritten" key holds the improved
# edit instruction.
# NOTE(review): the section headings jump from "## 2." to "## 4." — confirm
# whether a section 3 was dropped before renumbering, since editing this text
# changes what the model sees.
SYSTEM_PROMPT_EDIT = '''
# Edit Instruction Rewriter
You are a professional edit instruction rewriter. Your task is to generate a precise, concise, and visually achievable instruction based on the user's intent and the input image.
## 1. General Principles
- Keep the rewritten instruction **concise** and clear.
- Avoid contradictions, vagueness, or unachievable instructions.
- Maintain the core logic of the original instruction; only enhance clarity and feasibility.
- Ensure new added elements or modifications align with the image's original context and art style.
## 2. Task Types
### Add, Delete, Replace:
- When the input is detailed, only refine grammar and clarity.
- For vague instructions, infer minimal but sufficient details.
- For replacement, use the format: `"Replace X with Y"`.
### Text Editing (e.g., text replacement):
- Enclose text content in quotes, e.g., `Replace "abc" with "xyz"`.
- Preserving the original structure and language—**do not translate** or alter style.
### Human Editing (e.g., change a person’s face/hair):
- Preserve core visual identity (gender, ethnic features).
- Describe expressions in subtle and natural terms.
- Maintain key clothing or styling details unless explicitly replaced.
### Style Transformation:
- If a style is specified, e.g., `Disco style`, rewrite it to encapsulate the essential visual traits.
- Use a fixed template for **coloring/restoration**:
`"Restore old photograph, remove scratches, reduce noise, enhance details, high resolution, realistic, natural skin tones, clear facial features, no distortion, vintage photo restoration"`
if applicable.
## 4. Output Format
Please provide the rewritten instruction in a clean `json` format as:
{
"Rewritten": "..."
}
'''
def extract_json_response(model_output: str) -> "str | None":
    """Pull the rewritten instruction out of the rewriter model's reply.

    The reply is expected to contain a JSON object such as
    ``{"Rewritten": "..."}``, possibly surrounded by chatter or code fences.

    Args:
        model_output: Raw decoded text produced by the rewriter model.

    Returns:
        The stripped instruction string, or ``None`` when no usable JSON
        object / string value can be found.
    """
    if not isinstance(model_output, str) or not model_output:
        return None

    # Isolate the outermost {...} span; text around it is ignored.
    start_idx = model_output.find('{')
    end_idx = model_output.rfind('}') + 1
    if start_idx == -1 or end_idx <= start_idx:
        return None
    json_str = model_output[start_idx:end_idx]

    # Parse strictly first so well-formed JSON is never rewritten. Only when
    # that fails, retry once with bare (unquoted) object keys quoted. The
    # repair is best-effort and deliberately limited to keys that directly
    # follow "{" or ",".
    candidates = (
        json_str,
        re.sub(r'(?<=[{,])\s*([A-Za-z_]\w*)\s*:', r' "\1":', json_str),
    )
    data = None
    for candidate in candidates:
        try:
            data = json.loads(candidate)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            continue
        break
    if not isinstance(data, dict):
        return None

    # Preferred keys, in priority order. Non-string or blank values are
    # skipped instead of aborting the whole lookup.
    possible_keys = (
        "Rewritten", "rewritten", "Rewrited", "rewrited",
        "Output", "output", "Enhanced", "enhanced",
    )
    for key in possible_keys:
        value = data.get(key)
        if isinstance(value, str) and value.strip():
            return value.strip()

    # Nested form: {"Response": {"Rewritten": "..."}}
    response = data.get("Response")
    if isinstance(response, dict):
        nested = response.get("Rewritten")
        if isinstance(nested, str) and nested.strip():
            return nested.strip()

    # Last resort: any top-level string of plausible instruction length.
    for value in data.values():
        if isinstance(value, str) and 10 < len(value) < 500:
            return value.strip()

    return None
def polish_prompt(original_prompt: str) -> str:
    """Rewrite an edit instruction with the local rewriter model.

    Loads the rewriter on demand, sends ``SYSTEM_PROMPT_EDIT`` plus the user's
    instruction through the tokenizer's chat template, samples a reply and
    extracts the rewritten instruction from its JSON payload. When no JSON can
    be extracted, a cleaned-up first line of the raw reply is used instead.

    Args:
        original_prompt: The instruction exactly as typed by the user.

    Returns:
        The rewritten instruction, or ``original_prompt`` when the model reply
        yields no usable text.
    """
    load_rewriter()

    # Format as Qwen chat
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT_EDIT},
        {"role": "user", "content": original_prompt},
    ]
    text = rewriter_tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = rewriter_tokenizer(text, return_tensors="pt").to(device)

    with torch.no_grad():
        generated_ids = rewriter_model.generate(
            **model_inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            no_repeat_ngram_size=2,
            pad_token_id=rewriter_tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_length = model_inputs.input_ids.shape[1]
    enhanced = rewriter_tokenizer.decode(
        generated_ids[0][prompt_length:],
        skip_special_tokens=True,
    ).strip()

    rewritten_prompt = extract_json_response(enhanced)
    if rewritten_prompt:
        # Drop the quotes around the first operand of Replace/Change/Add and
        # unescape any remaining \" sequences.
        rewritten_prompt = re.sub(r'(Replace|Change|Add) "([^"]*)"', r'\1 \2', rewritten_prompt)
        rewritten_prompt = rewritten_prompt.replace('\\"', '"')
        return rewritten_prompt

    # Fallback cleanup if JSON extraction fails
    print(f"⚠️ JSON extraction failed, using raw output: {enhanced}")
    fallback = re.sub(r'```.*?```', '', enhanced, flags=re.DOTALL)  # Remove code blocks
    fallback = re.sub(r'[\{\}\[\]"]', '', fallback)  # Remove JSON artifacts

    # Use the first non-empty line: removing a leading code block can leave
    # the literal first line blank.
    fallback = next(
        (line.strip() for line in fallback.splitlines() if line.strip()),
        "",
    )

    # Drop a leading "Label: " prefix. maxsplit=1 keeps any later ": " that
    # belongs to the instruction itself.
    if ': ' in fallback:
        fallback = fallback.split(': ', 1)[1].strip()

    # Never hand an empty prompt to the image pipeline.
    return fallback or original_prompt
# Build the Qwen image-edit pipeline once at import time and move it to the
# selected device.
pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=dtype)
pipe = pipe.to(device)

# Attach the 8-step Lightning LoRA and fuse it into the pipeline weights.
pipe.load_lora_weights(
    "lightx2v/Qwen-Image-Lightning",
    weight_name="Qwen-Image-Lightning-8steps-V1.1.safetensors",
)
pipe.fuse_lora()

# Memory-efficient attention is optional; report when it cannot be enabled.
if not is_xformers_available():
    print("xformers not available")
else:
    pipe.enable_xformers_memory_efficient_attention()
def unload_rewriter():
    """Drop the rewriter tokenizer and model and release their memory.

    Both module-level globals are reset to ``None`` regardless of which of
    them was populated, so a partially completed ``load_rewriter()`` (tokenizer
    loaded, model load failed) is cleaned up too. Safe to call repeatedly.
    """
    global rewriter_tokenizer, rewriter_model

    rewriter_tokenizer = None
    rewriter_model = None

    # Collect first so objects that were only reachable through reference
    # cycles are actually freed, then ask the CUDA caching allocator to release
    # the now-unused blocks.
    gc.collect()
    torch.cuda.empty_cache()
@spaces.GPU(duration=60)
def infer(
    image,
    prompt,
    seed=42,
    randomize_seed=False,
    true_guidance_scale=4.0,
    num_inference_steps=8,
    rewrite_prompt=False,
    num_images_per_prompt=1,
):
    """Edit ``image`` according to ``prompt`` and report which prompt was used.

    Args:
        image: Input image passed straight to the edit pipeline.
        prompt: Edit instruction typed by the user.
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]`` instead.
        true_guidance_scale: Forwarded to the pipeline as ``true_cfg_scale``.
        num_inference_steps: Number of denoising steps.
        rewrite_prompt: Run the instruction through ``polish_prompt`` first.
        num_images_per_prompt: Number of images to generate.

    Returns:
        A ``(images, seed_used, prompt_info_html)`` tuple. ``images`` is an
        empty list when generation fails; the HTML then carries the error.
    """
    original_prompt = prompt
    # User text and exception messages are interpolated into HTML below, so
    # they are escaped to keep markup in the prompt from being rendered.
    safe_original = html.escape(str(original_prompt))

    # Handle prompt rewriting
    if rewrite_prompt:
        try:
            enhanced_instruction = polish_prompt(original_prompt)
            prompt_info = (
                f"<div style='margin:10px; padding:10px; border-radius:8px; border-left:4px solid #4CAF50; background: #f5f9fe'>"
                f"<h4 style='margin-top: 0;'>🚀 Prompt Enhancement</h4>"
                f"<p><strong>Original:</strong> {safe_original}</p>"
                f"<p><strong>Enhanced:</strong> {html.escape(str(enhanced_instruction))}</p>"
                f"</div>"
            )
            prompt = enhanced_instruction
        except Exception as e:
            # Best effort: fall back to the user's own prompt.
            gr.Warning(f"Prompt enhancement failed: {str(e)}")
            prompt_info = (
                f"<div style='margin:10px; padding:10px; border-radius:8px; border-left:4px solid #FF5252; background: #fef5f5'>"
                f"<h4 style='margin-top: 0;'>⚠️ Enhancement Not Applied</h4>"
                f"<p>Using original prompt. Error: {html.escape(str(e))}</p>"
                f"</div>"
            )
    else:
        prompt_info = (
            f"<div style='margin:10px; padding:10px; border-radius:8px; background: #f8f9fa'>"
            f"<h4 style='margin-top: 0;'>📝 Original Prompt</h4>"
            f"<p>{safe_original}</p>"
            f"</div>"
        )

    # Free VRAM after enhancement
    unload_rewriter()

    # Set seed for reproducibility. The random range is capped at MAX_SEED so
    # the value written back to the seed slider stays within its bounds.
    seed_val = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    generator = torch.Generator(device=device).manual_seed(seed_val)

    try:
        # Generate images; slider values are coerced to int for the pipeline.
        edited_images = pipe(
            image=image,
            prompt=prompt,
            negative_prompt=" ",
            num_inference_steps=int(num_inference_steps),
            generator=generator,
            true_cfg_scale=true_guidance_scale,
            num_images_per_prompt=int(num_images_per_prompt),
        ).images
    except Exception as e:
        # gr.Error is an exception type and shows nothing unless raised. This
        # path returns an error panel instead of raising, so surface the
        # message with gr.Warning, which displays when called.
        gr.Warning(f"Image generation failed: {str(e)}")
        prompt_info = (
            f"<div style='margin:10px; padding:10px; border-radius:8px; border-left:4px solid #dd2c00; background: #fef5f5'>"
            f"<h4 style='margin-top: 0;'><strong>⚠️ Error:</strong> {html.escape(str(e))}</h4>"
            f"</div>"
        )
        return [], seed_val, prompt_info

    return edited_images, seed_val, prompt_info
# Largest value offered by the seed slider: the maximum of a signed 32-bit int.
MAX_SEED = np.iinfo(np.int32).max

# Sample edit instructions, kept for the (currently commented-out) gr.Examples
# widget in the UI section.
examples = [
    "Replace the cat with a friendly golden retriever. Make it look happier, and add more background details.",
    "Add text 'Qwen - AI for image editing' in Chinese at the bottom center with a small shadow.",
    "Change the style to 1970s vintage, add old photo effect, restore any scratches on the wall or window.",
    "Remove the blue sky and replace it with a dark night cityscape.",
    'Replace "Qwen" with "通义" in the Image. Ensure Chinese font is used and position it at top left.',
]
# Gradio UI: inputs and settings on the left, gallery and prompt details on
# the right, with the button click and textbox submit both routed to infer().
# NOTE(review): the original indentation was lost, so the nesting of the
# `with` blocks below is reconstructed — confirm which controls belong inside
# the "Advanced Settings" accordion.
with gr.Blocks(title="Qwen Image Editor Fast") as demo:
    gr.Markdown("""
<div style="text-align: center;">
<h1>⚡️ Qwen-Image-Edit Lightning Fast 8-STEP</h1>
<p>8-step image editing with lightx2v's LoRA and local prompt enhancement</p>
<p>🚧 Work in progress, further improvements coming soon.</p>
</div>
""")

    with gr.Row():
        # Input Column
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")
            prompt = gr.Textbox(label="Edit Instruction", placeholder="e.g. Add a dog to the right side", lines=2)

            with gr.Accordion("Advanced Settings", open=False):
                gr.Markdown("### Generation Parameters")
                with gr.Row():
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
                    randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                with gr.Row():
                    true_guidance_scale = gr.Slider(
                        label="Guidance Scale", minimum=1.0, maximum=5.0, step=0.1, value=4.0
                    )
                    num_inference_steps = gr.Slider(
                        label="Inference Steps", minimum=4, maximum=16, step=1, value=8
                    )
                num_images_per_prompt = gr.Slider(
                    label="Output Images", minimum=1, maximum=4, step=1, value=2
                )

            rewrite_toggle = gr.Checkbox(
                label="Enable AI Prompt Enhancement",
                value=True
            )
            run_button = gr.Button("Generate Edits", variant="primary")

        # Output Column
        with gr.Column():
            # `columns` takes a number, not a callable; a fixed two-column grid
            # replaces the lambda that was passed here before.
            result = gr.Gallery(
                label="Output Images",
                columns=2,
                object_fit="contain",
                height="auto"
            )
            prompt_info = gr.HTML(
                "<div style='margin-top:20px; padding:15px; border-radius:8px; background:#f8f9fa'>"
                "<p>Prompt details will appear here after generation</p></div>"
            )

    # gr.Examples(
    #     examples=examples,
    #     inputs=[prompt],
    #     label="Try These Examples",
    #     cache_examples=True
    # )

    # One shared definition of infer()'s inputs/outputs so the button and the
    # textbox submit cannot drift apart.
    infer_inputs = [
        input_image,
        prompt,
        seed,
        randomize_seed,
        true_guidance_scale,
        num_inference_steps,
        rewrite_toggle,
        num_images_per_prompt,
    ]
    infer_outputs = [result, seed, prompt_info]

    # Main processing
    run_event = run_button.click(fn=infer, inputs=infer_inputs, outputs=infer_outputs)
    prompt.submit(fn=infer, inputs=infer_inputs, outputs=infer_outputs)

    # After a button-triggered run, make sure the prompt details panel is shown.
    run_event.then(
        fn=lambda: gr.update(visible=True),
        inputs=None,
        outputs=[prompt_info],
        queue=False
    )
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()