import gradio as gr
import numpy as np
import random
import torch
import spaces
from PIL import Image
from diffusers import QwenImageEditPipeline
from diffusers.utils import is_xformers_available
import os
import re
import gc
import json
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

#############################

os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
os.environ.setdefault('HF_HUB_DISABLE_TELEMETRY', '1')

# Model configuration
REWRITER_MODEL = "Qwen/Qwen1.5-7B-Chat"  # 7B chat model for more reliable JSON handling

# Lazily-populated globals for the prompt-enhancement model.
rewriter_tokenizer = None
rewriter_model = None

dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# 4-bit NF4 quantization so the rewriter fits in memory alongside the edit pipeline.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)


def load_rewriter():
    """Lazily load the prompt enhancement model (tokenizer + 4-bit LM).

    Idempotent: subsequent calls are no-ops once both globals are populated.
    """
    global rewriter_tokenizer, rewriter_model
    if rewriter_tokenizer is None or rewriter_model is None:
        print("🔄 Loading enhancement model...")
        rewriter_tokenizer = AutoTokenizer.from_pretrained(REWRITER_MODEL)
        rewriter_model = AutoModelForCausalLM.from_pretrained(
            REWRITER_MODEL,
            torch_dtype=dtype,
            device_map="auto",
            quantization_config=bnb_config,
        )
        print("✅ Enhancement model loaded")


SYSTEM_PROMPT_EDIT = '''
# Edit Instruction Rewriter
You are a professional edit instruction rewriter. Your task is to generate a precise, concise, and visually achievable instruction based on the user's intent and the input image.

## 1. General Principles
- Keep the rewritten instruction **concise** and clear.
- Avoid contradictions, vagueness, or unachievable instructions.
- Maintain the core logic of the original instruction; only enhance clarity and feasibility.
- Ensure new added elements or modifications align with the image's original context and art style.

## 2. Task Types
### Add, Delete, Replace:
- When the input is detailed, only refine grammar and clarity.
- For vague instructions, infer minimal but sufficient details.
- For replacement, use the format: `"Replace X with Y"`.

### Text Editing (e.g., text replacement):
- Enclose text content in quotes, e.g., `Replace "abc" with "xyz"`.
- Preserving the original structure and language—**do not translate** or alter style.

### Human Editing (e.g., change a person's face/hair):
- Preserve core visual identity (gender, ethnic features).
- Describe expressions in subtle and natural terms.
- Maintain key clothing or styling details unless explicitly replaced.

### Style Transformation:
- If a style is specified, e.g., `Disco style`, rewrite it to encapsulate the essential visual traits.
- Use a fixed template for **coloring/restoration**: `"Restore old photograph, remove scratches, reduce noise, enhance details, high resolution, realistic, natural skin tones, clear facial features, no distortion, vintage photo restoration"` if applicable.

## 4. Output Format
Please provide the rewritten instruction in a clean `json` format as:
{
    "Rewritten": "..."
}
'''


def extract_json_response(model_output: str):
    """Extract the rewritten instruction from potentially messy model output.

    Strips markdown code fences, isolates the outermost ``{...}`` span, and
    parses it as JSON — attempting a strict parse first and only applying
    best-effort repair regexes when strict parsing fails (repairing valid
    JSON would corrupt string values that contain colons).

    Returns:
        The instruction string, or None when nothing parseable is found.
    """
    # Remove code block markers first so the brace search sees raw JSON.
    model_output = re.sub(r'```(?:json)?\s*', '', model_output)
    try:
        start_idx = model_output.find('{')
        end_idx = model_output.rfind('}')
        if start_idx == -1 or end_idx == -1:
            return None
        # Include the closing brace.
        json_str = model_output[start_idx:end_idx + 1]

        try:
            # Well-formed JSON: parse as-is.
            data = json.loads(json_str)
        except json.JSONDecodeError:
            # Best-effort repair: quote bare keys, then bare string values.
            repaired = re.sub(r'(\w+)\s*:', r'"\1":', json_str)
            repaired = re.sub(r':\s*([^"\s{[]+)', r': "\1"', repaired)
            data = json.loads(repaired)

        # A top-level array (or scalar) has no keyed instruction to extract.
        if not isinstance(data, dict):
            return None

        # Key variations the model has been observed to emit.
        possible_keys = [
            "Rewritten", "rewritten", "Rewrited", "rewrited",
            "Output", "output", "Enhanced", "enhanced",
        ]
        for key in possible_keys:
            if key in data:
                return data[key].strip()

        # Nested variants, e.g. {"Response": {"Rewritten": "..."}}.
        for value in data.values():
            if isinstance(value, dict) and "Rewritten" in value:
                return value["Rewritten"].strip()

        # Last resort: any string value sized like an instruction.
        str_values = [
            v for v in data.values()
            if isinstance(v, str) and 10 < len(v) < 500
        ]
        if str_values:
            return str_values[0].strip()
    except Exception as e:
        print(f"JSON parse error: {str(e)}")
    return None


def polish_prompt(original_prompt: str) -> str:
    """Rewrite an edit instruction via the enhancement model.

    Uses the system prompt above and expects JSON output; falls back to
    heuristic text cleanup when JSON extraction fails.
    """
    load_rewriter()

    # Format as a Qwen chat exchange.
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT_EDIT},
        {"role": "user", "content": original_prompt},
    ]
    text = rewriter_tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    model_inputs = rewriter_tokenizer(text, return_tensors="pt").to(device)

    with torch.no_grad():
        generated_ids = rewriter_model.generate(
            **model_inputs,
            max_new_tokens=150,        # short budget keeps output focused
            do_sample=True,
            temperature=0.4,           # less creative but more focused
            top_p=0.9,
            no_repeat_ngram_size=3,
            pad_token_id=rewriter_tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens.
    enhanced = rewriter_tokenizer.decode(
        generated_ids[0][model_inputs.input_ids.shape[1]:],
        skip_special_tokens=True,
    ).strip()

    # extract_json_response strips code fences itself, so pass raw output.
    rewritten_prompt = extract_json_response(enhanced)
    if rewritten_prompt:
        # Clean up remaining artifacts (drop quotes around edit targets,
        # un-escape quotes and newlines).
        rewritten_prompt = re.sub(
            r'(Replace|Change|Add) "(.*?)"', r'\1 \2', rewritten_prompt
        )
        rewritten_prompt = rewritten_prompt.replace('\\"', '"').replace('\\n', ' ')
        return rewritten_prompt

    # Fallback when JSON extraction fails: salvage plain text.
    if '```' in enhanced:
        parts = enhanced.split('```')
        # Take the content between the first pair of fences when present.
        rewritten_prompt = parts[1].strip() if len(parts) >= 3 else enhanced
    else:
        rewritten_prompt = enhanced

    # Drop any leftover JSON-looking line, collapse whitespace, and strip a
    # leading "Key: " prefix if one survived.
    rewritten_prompt = re.sub(r'.*{.*}.*', '', rewritten_prompt)
    rewritten_prompt = re.sub(r'\s\s+', ' ', rewritten_prompt).strip()
    if ': ' in rewritten_prompt:
        rewritten_prompt = rewritten_prompt.split(': ', 1)[-1].strip()
    return rewritten_prompt[:200]  # ensure reasonable length


# Load main image editing pipeline.
pipe = QwenImageEditPipeline.from_pretrained(
    "Qwen/Qwen-Image-Edit", torch_dtype=dtype
).to(device)

# Load and fuse the Lightning LoRA for accelerated 8-step inference.
pipe.load_lora_weights(
    "lightx2v/Qwen-Image-Lightning",
    weight_name="Qwen-Image-Lightning-8steps-V1.1.safetensors",
)
pipe.fuse_lora()

if is_xformers_available():
    pipe.enable_xformers_memory_efficient_attention()
else:
    print("xformers not available")
available") def unload_rewriter(): """Clear enhancement model from memory""" global rewriter_tokenizer, rewriter_model if rewriter_model: del rewriter_tokenizer, rewriter_model rewriter_tokenizer = None rewriter_model = None torch.cuda.empty_cache() gc.collect() @spaces.GPU(duration=60) def infer( image, prompt, seed=42, randomize_seed=False, true_guidance_scale=1.0, num_inference_steps=8, rewrite_prompt=False, num_images_per_prompt=1, ): """Image editing endpoint with optimized prompt handling""" original_prompt = prompt prompt_info = "" # Handle prompt rewriting if rewrite_prompt: try: enhanced_instruction = polish_prompt(original_prompt) prompt_info = ( f"
Original: {original_prompt}
" f"Enhanced: {enhanced_instruction}
" f"Using original prompt. Error: {str(e)[:100]}
" f"{original_prompt}
" f"{str(e)[:200]}
" f"8-step inferencing • Local Prompt Enhancement • H200 Optimized