|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import random |
|
|
import torch |
|
|
import spaces |
|
|
from PIL import Image |
|
|
from diffusers import QwenImageEditPipeline, FlowMatchEulerDiscreteScheduler |
|
|
from diffusers.utils import is_xformers_available |
|
|
from presets import PRESETS, get_preset_choices, get_preset_info |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import re |
|
|
import gc |
|
|
import math |
|
|
import json |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig |
|
|
import logging |
|
|
|
|
|
# Disable telemetry/analytics before any network-capable libraries read these flags.
os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
os.environ.setdefault('HF_HUB_DISABLE_TELEMETRY', '1')

# Log to stdout so messages surface in hosted environments (e.g. HF Spaces logs).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

# Small chat model used only to rewrite/enhance user edit instructions.
REWRITER_MODEL = "Qwen/Qwen1.5-4B-Chat"
# Shared dtype/device for both the rewriter model and the edit pipeline.
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEED = np.iinfo(np.int32).max
# Path or repo id of the Qwen-Image-Edit pipeline, taken from the QIE env var.
# NOTE(review): if QIE is unset this is None and the pipeline load below fails —
# confirm the deployment always sets it.
LOC = os.getenv("QIE")

# 4-bit NF4 quantization keeps the rewriter model's GPU memory footprint small.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True
)

# Load the quantized rewriter LLM; device_map="auto" lets accelerate place it.
rewriter_model = AutoModelForCausalLM.from_pretrained(
    REWRITER_MODEL,
    torch_dtype=dtype,
    device_map="auto",
    quantization_config=bnb_config,
)

print("🔄 Loading prompt enhancement model...")
rewriter_tokenizer = AutoTokenizer.from_pretrained(REWRITER_MODEL)

print("✅ Enhancement model loaded and ready!")
|
|
|
|
|
# System prompt for the rewriter chat model: constrains it to return one concise,
# visually achievable edit instruction as JSON of the form {"Rewritten": "..."}.
# (Parsed downstream by extract_json_response; do not change the output format
# section without updating the parser.)
SYSTEM_PROMPT_EDIT = '''
# Edit Instruction Rewriter
You are a professional edit instruction rewriter. Your task is to generate a precise, concise, and visually achievable instruction based on the user's intent and the input image.
## 1. General Principles
- Keep the rewritten instruction **concise** and clear.
- Avoid contradictions, vagueness, or unachievable instructions.
- Maintain the core logic of the original instruction; only enhance clarity and feasibility.
- Ensure new added elements or modifications align with the image's original context and art style.
## 2. Task Types
### Add, Delete, Replace:
- When the input is detailed, only refine grammar and clarity.
- For vague instructions, infer minimal but sufficient details.
- For replacement, use the format: `"Replace X with Y"`.
### Text Editing (e.g., text replacement):
- Enclose text content in quotes, e.g., `Replace "abc" with "xyz"`.
- Preserving the original structure and language—**do not translate** or alter style.
### Human Editing (e.g., change a person’s face/hair):
- Preserve core visual identity (gender, ethnic features).
- Describe expressions in subtle and natural terms.
- Maintain key clothing or styling details unless explicitly replaced.
### Style Transformation:
- If a style is specified, e.g., `Disco style`, rewrite it to encapsulate the essential visual traits.
- Use a fixed template for **coloring/restoration**:
`"Restore old photograph, remove scratches, reduce noise, enhance details, high resolution, realistic, natural skin tones, clear facial features, no distortion, vintage photo restoration"`
if applicable.
## 4. Output Format
Please provide the rewritten instruction in a clean `json` format as:
{
"Rewritten": "..."
}
'''
|
|
|
|
|
def extract_json_response(model_output: str) -> "str | None":
    """Extract the rewritten instruction from potentially messy model JSON output.

    Strips markdown code fences, locates the outermost ``{...}`` span, parses it
    (repairing unquoted keys and trailing commas when direct parsing fails), and
    searches known/likely key names for the rewritten prompt string.

    Args:
        model_output: Raw decoded text produced by the rewriter model.

    Returns:
        The extracted instruction string, or ``None`` when nothing usable could
        be parsed out of the output.
    """
    # Remove markdown code fences (``` or ```json) that chat models often add.
    model_output = re.sub(r'```(?:json)?\s*', '', model_output)
    try:
        # Locate the outermost JSON object boundaries.
        start_idx = model_output.find('{')
        end_idx = model_output.rfind('}')

        if start_idx == -1 or end_idx == -1 or start_idx >= end_idx:
            print(f"No valid JSON structure found in output. Start: {start_idx}, End: {end_idx}")
            return None

        json_str = model_output[start_idx:end_idx + 1].strip()

        try:
            data = json.loads(json_str)
        except json.JSONDecodeError as e:
            print(f"Direct JSON parsing failed: {e}")
            # Repair pass 1: quote bare keys, e.g. {Rewritten: ...} -> {"Rewritten": ...}.
            json_str = re.sub(r'([^{}[\],\s"]+)(?=\s*:)', r'"\1"', json_str)
            # Repair pass 2: drop trailing commas before closing brackets.
            json_str = re.sub(r',(\s*[}\]])', r'\1', json_str)
            data = json.loads(json_str)

        # FIX: the model occasionally returns a bare string/list; previously a
        # non-dict crashed into the broad except below. Bail out explicitly.
        if not isinstance(data, dict):
            return None

        # Known key spellings (including common model typos), in priority order.
        possible_keys = [
            "Rewritten", "rewritten", "Rewrited", "rewrited", "Rewrittent",
            "Output", "output", "Enhanced", "enhanced"
        ]
        for key in possible_keys:
            # FIX: only accept string values — a nested object under one of
            # these keys used to raise AttributeError and abort all fallbacks.
            value = data.get(key)
            if isinstance(value, str):
                return value.strip()

        # One level of nesting: {"Response": {"Rewritten": "..."}}.
        # FIX: the old `"Rewritten" in data["Response"]` did substring matching
        # when Response was a string, then crashed on indexing.
        response = data.get("Response")
        if isinstance(response, dict) and isinstance(response.get("Rewritten"), str):
            return response["Rewritten"].strip()

        # Generic fallback: any nested dict carrying a "Rewritten" string.
        for value in data.values():
            if isinstance(value, dict) and isinstance(value.get("Rewritten"), str):
                return value["Rewritten"].strip()

        # Last resort: a reasonably-sized string value is assumed to be the prompt.
        str_values = [v for v in data.values() if isinstance(v, str) and 10 < len(v) < 500]
        if str_values:
            return str_values[0].strip()
    except Exception as e:
        print(f"JSON parse error: {str(e)}")
        print(f"Model output was: {model_output}")
    return None
|
|
|
|
|
def polish_prompt(original_prompt: str) -> str:
    """Rewrite an edit instruction with the local LLM, parsing its JSON reply.

    Falls back to cleaned raw model text — and ultimately to the user's own
    prompt — when JSON extraction fails.
    """
    chat = [
        {"role": "system", "content": SYSTEM_PROMPT_EDIT},
        {"role": "user", "content": original_prompt}
    ]
    chat_text = rewriter_tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True
    )
    encoded = rewriter_tokenizer(chat_text, return_tensors="pt").to(device)

    with torch.no_grad():
        output_ids = rewriter_model.generate(
            **encoded,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_p=0.8,
            repetition_penalty=1.1,
            no_repeat_ngram_size=3,
            pad_token_id=rewriter_tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens (skip the echoed chat prompt).
    prompt_len = encoded.input_ids.shape[1]
    raw_output = rewriter_tokenizer.decode(
        output_ids[0][prompt_len:],
        skip_special_tokens=True
    ).strip()
    print(f"Model raw output: {raw_output}")

    extracted = extract_json_response(raw_output)
    if extracted:
        # Unwrap quotes the model may have put around edit targets, then
        # unescape stray JSON artifacts left in the string.
        extracted = re.sub(r'(Replace|Change|Add) "(.*?)"', r'\1 \2', extracted)
        return extracted.replace('\\"', '"').replace('\\n', ' ')

    # JSON extraction failed: salvage plain text, preferring a fenced section.
    fallback = raw_output
    if '```' in raw_output:
        fenced = raw_output.split('```')
        if len(fenced) >= 2:
            fallback = fenced[1].strip()

    # Collapse whitespace and drop any leading "Label: " prefix.
    fallback = re.sub(r'\s\s+', ' ', fallback).strip()
    if ': ' in fallback:
        fallback = fallback.split(': ', 1)[-1].strip()
    # Cap length; fall back to the user's own prompt if nothing is left.
    return fallback[:200] if fallback else original_prompt
|
|
|
|
|
|
|
|
# Flow-match scheduler settings tuned for few-step (Lightning) inference:
# dynamic exponential time-shifting with base/max shift of ln(3).
scheduler_config = {
    "base_image_seq_len": 256,
    "base_shift": math.log(3),
    "invert_sigmas": False,
    "max_image_seq_len": 8192,
    "max_shift": math.log(3),
    "num_train_timesteps": 1000,
    "shift": 1.0,
    "shift_terminal": None,
    "stochastic_sampling": False,
    "time_shift_type": "exponential",
    "use_beta_sigmas": False,
    "use_dynamic_shifting": True,
    "use_exponential_sigmas": False,
    "use_karras_sigmas": False,
}

scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)

# Load the Qwen-Image-Edit pipeline from the location given by the QIE env var.
# NOTE(review): LOC is None when QIE is unset — from_pretrained would fail here;
# confirm the deployment always provides it.
pipe = QwenImageEditPipeline.from_pretrained(
    LOC,
    scheduler=scheduler,
    torch_dtype=dtype
).to(device)

# Fuse the 8-step Lightning LoRA into the base weights so few-step inference
# runs without per-call LoRA overhead.
pipe.load_lora_weights(
    "lightx2v/Qwen-Image-Lightning",
    weight_name="Qwen-Image-Lightning-8steps-V1.1.safetensors"
)
pipe.fuse_lora()

# Use memory-efficient attention when xformers is installed.
if is_xformers_available():
    pipe.enable_xformers_memory_efficient_attention()
else:
    print("xformers not available")
|
|
|
|
|
|
|
|
def update_prompt_preview(preset_type, base_prompt):
    """Render a markdown preview of the batch prompts a preset would produce."""
    # No (valid) preset selected: show the instructional placeholder.
    if not preset_type or preset_type not in PRESETS:
        return "Select a preset above to see how your base prompt will be modified for batch generation."

    preset = PRESETS[preset_type]
    sections = [
        f"**Preset: {preset_type}**\n",
        f"*{preset['description']}*\n",
        "**Generated Prompts:**",
    ]
    # One numbered line per preset variation, appended to the base prompt.
    for idx, suffix in enumerate(preset["prompts"], 1):
        sections.append(f"{idx}. {base_prompt}, {suffix}")
    return "\n".join(sections) + "\n"
|
|
|
|
|
|
|
|
@spaces.GPU()
def infer(
    image,
    prompt,
    seed=42,
    randomize_seed=False,
    true_guidance_scale=4.0,
    num_inference_steps=8,
    rewrite_prompt=True,
    num_images_per_prompt=1,
    preset_type=None,
    progress=gr.Progress(track_tqdm=True),
):
    """Image editing endpoint with optimized prompt handling.

    Args:
        image: Source PIL image to edit.
        prompt: Edit instruction, or the base prompt when a preset is active.
        seed: Base RNG seed; batch item ``i`` uses ``seed + i*1000``.
        randomize_seed: When True, replaces ``seed`` with a random one.
        true_guidance_scale: Base guidance; jittered slightly per image.
        num_inference_steps: Diffusion steps (Lightning LoRA targets ~8).
        rewrite_prompt: When True (and no preset is active), enhance the prompt
            with the local rewriter model.
        num_images_per_prompt: Manual output count used when no preset is
            selected.
        preset_type: Optional key into PRESETS enabling batch variations.
        progress: Gradio progress tracker (tqdm-linked).

    Returns:
        Tuple of (list of edited PIL images, base seed used, HTML prompt summary).
    """

    def resize_image(pil_image, max_size=1024):
        """Downscale so the longest side is <= max_size, keeping aspect ratio."""
        try:
            if pil_image is None:
                return pil_image
            width, height = pil_image.size
            max_dimension = max(width, height)
            if max_dimension <= max_size:
                return pil_image

            scale = max_size / max_dimension
            new_width = int(width * scale)
            new_height = int(height * scale)

            resized_image = pil_image.resize((new_width, new_height), Image.LANCZOS)
            print(f"📝 Image resized from {width}x{height} to {new_width}x{new_height}")
            return resized_image
        except Exception as e:
            # Best-effort: fall back to the original image on any PIL failure.
            print(f"⚠️ Image resize failed: {e}")
            return pil_image

    def add_noise_to_image(pil_image, noise_level=0.05):
        """Add slight Gaussian noise so repeated edits of one image diverge."""
        try:
            if pil_image is None:
                return pil_image
            img_array = np.array(pil_image).astype(np.float32) / 255.0
            noise = np.random.normal(0, noise_level, img_array.shape)
            noisy_array = np.clip(img_array + noise, 0, 1)
            return Image.fromarray((noisy_array * 255).astype(np.uint8))
        except Exception as e:
            print(f"Warning: Could not add noise to image: {e}")
            return pil_image

    image = resize_image(image, max_size=1024)
    original_prompt = prompt
    prompt_info = ""

    using_preset = bool(preset_type) and preset_type in PRESETS
    if using_preset:
        preset = PRESETS[preset_type]
        batch_prompts = [f"{original_prompt}, {preset_prompt}" for preset_prompt in preset["prompts"]]
        prompt_info = (
            f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #2196F3; background: #f0f8ff'>"
            f"<h4 style='margin-top: 0;'>🎨 Preset: {preset_type}</h4>"
            f"<p>{preset['description']}</p>"
            f"<p><strong>Base Prompt:</strong> {original_prompt}</p>"
            f"</div>"
        )
        print(f"Using preset: {preset_type} with {len(batch_prompts)} variations")
    else:
        # FIX: honor the "Output Count (Manual)" slider — previously
        # num_images_per_prompt was accepted but never used, so the slider did
        # nothing. Items after the first get noise/seed variation below.
        batch_prompts = [prompt] * max(1, int(num_images_per_prompt))

    # BUGFIX: enhancement used to run even when a preset was active, replacing
    # the whole preset batch with a single enhanced prompt (and overwriting the
    # preset info HTML). Restrict enhancement to the no-preset path.
    if rewrite_prompt and not using_preset:
        try:
            enhanced_instruction = polish_prompt(original_prompt)
            if enhanced_instruction and enhanced_instruction != original_prompt:
                prompt_info = (
                    f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #4CAF50; background: #f5f9fe'>"
                    f"<h4 style='margin-top: 0;'>🚀 Prompt Enhancement</h4>"
                    f"<p><strong>Original:</strong> {original_prompt}</p>"
                    f"<p><strong style='color:#2E7D32;'>Enhanced:</strong> {enhanced_instruction}</p>"
                    f"</div>"
                )
                batch_prompts = [enhanced_instruction] * len(batch_prompts)
            else:
                prompt_info = (
                    f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #FF9800; background: #fff8f0'>"
                    f"<h4 style='margin-top: 0;'>📝 Prompt Enhancement</h4>"
                    f"<p>No enhancement applied or enhancement failed</p>"
                    f"</div>"
                )
        except Exception as e:
            # Enhancement is best-effort: warn and continue with the original prompt.
            print(f"Prompt enhancement error: {str(e)}")
            gr.Warning(f"Prompt enhancement failed: {str(e)}")
            prompt_info = (
                f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #FF5252; background: #fef5f5'>"
                f"<h4 style='margin-top: 0;'>⚠️ Enhancement Not Applied</h4>"
                f"<p>Using original prompt. Error: {str(e)[:100]}</p>"
                f"</div>"
            )
    elif not using_preset:
        # No preset, enhancement disabled: just echo the prompt.
        # (FIX: this branch previously also clobbered the preset info HTML.)
        prompt_info = (
            f"<div style='margin:10px; padding:10px; border-radius:8px; background: #f8f9fa'>"
            f"<h4 style='margin-top: 0;'>📝 Original Prompt</h4>"
            f"<p>{original_prompt}</p>"
            f"</div>"
        )

    base_seed = seed if not randomize_seed else random.randint(0, MAX_SEED)

    try:
        edited_images = []

        for i, current_prompt in enumerate(batch_prompts):
            # Distinct but reproducible seed per batch item.
            generator = torch.Generator(device=device).manual_seed(base_seed + i * 1000)

            # Single-image runs edit the untouched source; batch items get a
            # touch of noise so outputs differ even with similar prompts.
            if i == 0 and len(batch_prompts) == 1:
                input_image = image
            else:
                input_image = add_noise_to_image(image, noise_level=0.01 + i * 0.003)

            # NOTE(review): this jitter makes guidance non-reproducible even
            # with a fixed seed — confirm that is intentional.
            varied_guidance = true_guidance_scale + random.uniform(-0.2, 0.2)
            varied_guidance = max(1.0, min(10.0, varied_guidance))

            result = pipe(
                image=input_image,
                prompt=current_prompt,
                negative_prompt=" ",
                num_inference_steps=num_inference_steps,
                generator=generator,
                true_cfg_scale=varied_guidance,
                num_images_per_prompt=1
            ).images
            edited_images.extend(result)

            print(f"Generated image {i+1}/{len(batch_prompts)} with prompt: {current_prompt[:50]}...")

        # Release GPU memory between requests.
        if device == "cuda":
            torch.cuda.empty_cache()
            gc.collect()

        return edited_images, base_seed, prompt_info
    except Exception as e:
        # Clean up GPU memory even on failure before reporting the error.
        if device == "cuda":
            torch.cuda.empty_cache()
            gc.collect()
        # NOTE(review): gr.Error is instantiated but not raised, so recent
        # Gradio versions may not display it; the HTML panel carries the error.
        gr.Error(f"Image generation failed: {str(e)}")
        return [], base_seed, (
            f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #dd2c00; background: #fef5f5'>"
            f"<h4 style='margin-top: 0;'>⚠️ Processing Error</h4>"
            f"<p>{str(e)[:200]}</p>"
            f"</div>"
        )
|
|
|
|
|
|
|
|
# ---- UI definition: single-page Gradio Blocks app ----
with gr.Blocks(title="Qwen Image Edit - Fast Lightning Mode w/ Batch") as demo:
    # Header banner.
    gr.Markdown("""
    <div style="text-align: center; background: linear-gradient(to right, #3a7bd5, #00d2ff); color: white; padding: 20px; border-radius: 8px;">
    <h1 style="margin-bottom: 5px;">⚡️ Qwen-Image-Edit Lightning</h1>
    <p>✨ 8-step inferencing with lightx2v's LoRA.</p>
    <p>📝 Local Prompt Enhancement, Batched Multi-image Generation, 🎨 Preset Batches</p>
    </div>
    """)

    with gr.Row(equal_height=True):
        # Left column: source image, prompt, preset and advanced controls.
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="Source Image",
                type="pil",
                height=300
            )

            prompt = gr.Textbox(
                label="Edit Instructions / Base Prompt",
                placeholder="e.g. Replace the background with a beach sunset... When a preset is selected, use as the base prompt, e.g. the lamborghini",
                lines=2,
                max_lines=4,
                scale=2
            )

            # Selecting a preset switches infer() into batch-variation mode.
            preset_dropdown = gr.Dropdown(
                choices=get_preset_choices(),
                value=None,
                label="Preset Batch Generation",
                interactive=True
            )
            rewrite_toggle = gr.Checkbox(
                label="Enable Prompt Enhancement",
                value=True,
                interactive=True
            )

            # Read-only preview of the prompts a preset would generate.
            prompt_preview = gr.Textbox(
                label="📋 Prompt Preview",
                interactive=False,
                lines=6,
                max_lines=10,
                value="Enter a base prompt and select a preset above to see how your prompt will be modified for batch generation.",
                placeholder="Prompt preview will appear here..."
            )
            run_button = gr.Button(
                "Generate Edit(s)",
                variant="primary"
            )
            with gr.Accordion("Advanced Parameters", open=False):
                with gr.Row():
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=42
                    )
                    randomize_seed = gr.Checkbox(
                        label="Random Seed",
                        value=True
                    )
                with gr.Row():
                    true_guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1.0,
                        maximum=10.0,
                        step=0.1,
                        value=4.0
                    )
                    num_inference_steps = gr.Slider(
                        label="Inference Steps",
                        minimum=4,
                        maximum=16,
                        step=1,
                        value=6
                    )
                    num_images_per_prompt = gr.Slider(
                        label="Output Count (Manual)",
                        minimum=1,
                        maximum=4,
                        step=1,
                        value=1
                    )

        # Right column: generated gallery and prompt-info HTML panel.
        with gr.Column(scale=2):
            result = gr.Gallery(
                label="Edited Images",
                # NOTE(review): Gradio documents `columns` as int/list, not a
                # callable — confirm the installed version accepts this.
                columns=lambda x: min(x, 2),
                height=500,
                object_fit="cover",
                preview=True
            )
            prompt_info = gr.HTML(
                value="<div style='padding:15px; margin-top:15px'>"
                "Prompt details will appear after generation. Ability to edit Preset Prompts on the fly will be implemented shortly.</div>"
            )

    # Live preview refresh when either the preset or the base prompt changes.
    preset_dropdown.change(
        fn=update_prompt_preview,
        inputs=[preset_dropdown, prompt],
        outputs=prompt_preview
    )

    prompt.change(
        fn=update_prompt_preview,
        inputs=[preset_dropdown, prompt],
        outputs=prompt_preview
    )

    # Shared wiring for both the button click and textbox submit events.
    # Note: infer's returned base_seed is written back into the seed slider.
    inputs = [
        input_image,
        prompt,
        seed,
        randomize_seed,
        true_guidance_scale,
        num_inference_steps,
        rewrite_toggle,
        num_images_per_prompt,
        preset_dropdown
    ]
    outputs = [result, seed, prompt_info]

    run_button.click(
        fn=infer,
        inputs=inputs,
        outputs=outputs
    )
    prompt.submit(
        fn=infer,
        inputs=inputs,
        outputs=outputs
    )

# Bound the request queue so concurrent GPU work stays limited.
demo.queue(max_size=5).launch()