Update app_local.py

app_local.py  CHANGED  (+80 -56)
@@ -28,6 +28,7 @@ logger = logging.getLogger(__name__)
 REWRITER_MODEL = "Qwen/Qwen1.5-4B-Chat" # Upgraded to 4B for better JSON handling
 dtype = torch.bfloat16
 device = "cuda" if torch.cuda.is_available() else "cpu"
+MAX_SEED = np.iinfo(np.int32).max
 
 # Quantization configuration
 bnb_config = BitsAndBytesConfig(
@@ -82,29 +83,40 @@ Please provide the rewritten instruction in a clean `json` format as:
 }
 '''
 
+
 def extract_json_response(model_output: str) -> str:
     """Extract rewritten instruction from potentially messy JSON output"""
-    #
+    # Remove code block markers first
     model_output = re.sub(r'```(?:json)?\s*', '', model_output)
-
     try:
-        #
+        # Find the JSON portion in the output
        start_idx = model_output.find('{')
        end_idx = model_output.rfind('}')
-        if start_idx == -1 or end_idx == -1:
-            return None
 
+        # Fix the condition - check if brackets were found
+        if start_idx == -1 or end_idx == -1 or start_idx >= end_idx:
+            print(f"No valid JSON structure found in output. Start: {start_idx}, End: {end_idx}")
+            return None
+
         # Expand to the full object including outer braces
         end_idx += 1 # Include the closing brace
-
         json_str = model_output[start_idx:end_idx]
 
-        #
-        json_str =
-        json_str = re.sub(r':\s*([^"\s{[]+)', r': "\1"', json_str) # Quote unquoted string values
+        # Handle potential markdown or other formatting
+        json_str = json_str.strip()
 
-        #
-
+        # Try to parse JSON directly first
+        try:
+            data = json.loads(json_str)
+        except json.JSONDecodeError as e:
+            print(f"Direct JSON parsing failed: {e}")
+            # If direct parsing fails, try cleanup
+            # Quote keys properly
+            json_str = re.sub(r'([^{}[\],\s"]+)(?=\s*:)', r'"\1"', json_str)
+            # Remove any trailing commas that might cause issues
+            json_str = re.sub(r',(\s*[}\]])', r'\1', json_str)
+            # Try parsing again
+            data = json.loads(json_str)
 
         # Extract rewritten prompt from possible key variations
         possible_keys = [
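As an illustration of what the reworked extract_json_response is aiming for, the sketch below replays its two-stage parse (direct json.loads, then regex cleanup) on a made-up model output; the sample string is an assumption for illustration, not part of the commit.

```python
# Illustration only: intended behaviour of the two-stage parse in extract_json_response.
import json
import re

raw = '```json\n{Rewritten: "Add a red hat to the cat", }\n```'  # hypothetical messy output

# Strip code fences, then isolate the outermost {...} block
text = re.sub(r'```(?:json)?\s*', '', raw)
start, end = text.find('{'), text.rfind('}') + 1
json_str = text[start:end].strip()

try:
    data = json.loads(json_str)  # fails here: unquoted key, trailing comma
except json.JSONDecodeError:
    json_str = re.sub(r'([^{}[\],\s"]+)(?=\s*:)', r'"\1"', json_str)  # quote bare keys
    json_str = re.sub(r',(\s*[}\]])', r'\1', json_str)                # drop trailing commas
    data = json.loads(json_str)

print(data)  # {'Rewritten': 'Add a red hat to the cat'}
```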
@@ -124,20 +136,20 @@ def extract_json_response(model_output: str) -> str:
         for value in data.values():
             if isinstance(value, dict) and "Rewritten" in value:
                 return value["Rewritten"].strip()
-
+
         # Try to find any string value that looks like an instruction
         str_values = [v for v in data.values() if isinstance(v, str) and 10 < len(v) < 500]
         if str_values:
             return str_values[0].strip()
-
+
     except Exception as e:
         print(f"JSON parse error: {str(e)}")
-
+        print(f"Model output was: {model_output}")
         return None
 
+
 def polish_prompt(original_prompt: str) -> str:
     """Enhanced prompt rewriting using original system prompt with JSON handling"""
-    # load_rewriter()
 
     # Format as Qwen chat
     messages = [
@@ -156,11 +168,11 @@ def polish_prompt(original_prompt: str) -> str:
     with torch.no_grad():
         generated_ids = rewriter_model.generate(
             **model_inputs,
-            max_new_tokens=256,
+            max_new_tokens=256,
             do_sample=True,
-            temperature=0.5,
+            temperature=0.5,
             top_p=0.8,
-            repetition_penalty=
+            repetition_penalty=1.1,
             no_repeat_ngram_size=3,
             pad_token_id=rewriter_tokenizer.eos_token_id
         )
@@ -171,41 +183,34 @@ def polish_prompt(original_prompt: str) -> str:
         skip_special_tokens=True
     ).strip()
 
-
-    json_str = enhanced
-    if '```' in enhanced:
-        parts = enhanced.split('```')
-        if len(parts) >= 3:
-            json_str = parts[1] # Take content between first set of ```
+    print(f"Model raw output: {enhanced}") # Debug logging
 
     # Try to extract JSON content
-    rewritten_prompt = extract_json_response(
+    rewritten_prompt = extract_json_response(enhanced)
 
     if rewritten_prompt:
         # Clean up remaining artifacts
         rewritten_prompt = re.sub(r'(Replace|Change|Add) "(.*?)"', r'\1 \2', rewritten_prompt)
         rewritten_prompt = rewritten_prompt.replace('\\"', '"').replace('\\n', ' ')
         return rewritten_prompt
-
-
-
-
-
-
-
+    else:
+        # Fallback: try to extract from code blocks or just return cleaned content
+        if '```' in enhanced:
+            parts = enhanced.split('```')
+            if len(parts) >= 2:
+                rewritten_prompt = parts[1].strip()
+            else:
+                rewritten_prompt = enhanced
         else:
             rewritten_prompt = enhanced
-
-
-
-
-
-
-
-
-
-    return rewritten_prompt[:200] # Ensure reasonable length
-
+
+    # Basic cleanup
+    rewritten_prompt = re.sub(r'\s\s+', ' ', rewritten_prompt).strip()
+    if ': ' in rewritten_prompt:
+        rewritten_prompt = rewritten_prompt.split(': ', 1)[-1].strip()
+
+    return rewritten_prompt[:200] if rewritten_prompt else original_prompt
+
 # Load main image editing pipeline
 pipe = QwenImageEditPipeline.from_pretrained(
     "Qwen/Qwen-Image-Edit",
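The fallback branch added to polish_prompt can be hard to follow inside the diff; this small sketch traces only the final cleanup path on a hypothetical non-JSON model reply (the enhanced string below is invented for illustration).

```python
# Illustration only: the cleanup/fallback path in polish_prompt when no JSON is found.
import re

original_prompt = "make the sky purple"
enhanced = "Rewritten: Change the sky to a vivid purple  while keeping the rest of the scene unchanged."

rewritten_prompt = enhanced                                          # no ``` block, no JSON extracted
rewritten_prompt = re.sub(r'\s\s+', ' ', rewritten_prompt).strip()   # collapse repeated whitespace
if ': ' in rewritten_prompt:
    rewritten_prompt = rewritten_prompt.split(': ', 1)[-1].strip()   # drop a "Rewritten:" style prefix

print(rewritten_prompt[:200] if rewritten_prompt else original_prompt)
# Change the sky to a vivid purple while keeping the rest of the scene unchanged.
```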
@@ -233,7 +238,6 @@ else:
     # rewriter_model = None
     # torch.cuda.empty_cache()
     # gc.collect()
-
 @spaces.GPU()
 def infer(
     image,
@@ -246,6 +250,11 @@ def infer(
     num_images_per_prompt=1,
 ):
     """Image editing endpoint with optimized prompt handling"""
+    # Clear cache at start
+    if device == "cuda":
+        torch.cuda.empty_cache()
+        gc.collect()
+
     original_prompt = prompt
     prompt_info = ""
 
@@ -253,15 +262,24 @@ def infer(
     if rewrite_prompt:
         try:
             enhanced_instruction = polish_prompt(original_prompt)
-
-
-
-
-
-
-
-
+            if enhanced_instruction and enhanced_instruction != original_prompt:
+                prompt_info = (
+                    f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #4CAF50; background: #f5f9fe'>"
+                    f"<h4 style='margin-top: 0;'>🚀 Prompt Enhancement</h4>"
+                    f"<p><strong>Original:</strong> {original_prompt}</p>"
+                    f"<p><strong style='color:#2E7D32;'>Enhanced:</strong> {enhanced_instruction}</p>"
+                    f"</div>"
+                )
+                prompt = enhanced_instruction
+            else:
+                prompt_info = (
+                    f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #FF9800; background: #fff8f0'>"
+                    f"<h4 style='margin-top: 0;'>📝 Prompt Enhancement</h4>"
+                    f"<p>No enhancement applied or enhancement failed</p>"
+                    f"</div>"
+                )
         except Exception as e:
+            print(f"Prompt enhancement error: {str(e)}") # Debug logging
             gr.Warning(f"Prompt enhancement failed: {str(e)}")
             prompt_info = (
                 f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #FF5252; background: #fef5f5'>"
@@ -277,7 +295,6 @@ def infer(
                 f"</div>"
             )
 
-
     # Set seed for reproducibility
     seed_val = seed if not randomize_seed else random.randint(0, MAX_SEED)
     generator = torch.Generator(device=device).manual_seed(seed_val)
@@ -293,9 +310,18 @@ def infer(
             true_cfg_scale=true_guidance_scale,
             num_images_per_prompt=num_images_per_prompt
         ).images
-        return edited_images, seed_val, prompt_info
 
+        # Clear cache after generation
+        if device == "cuda":
+            torch.cuda.empty_cache()
+            gc.collect()
+
+        return edited_images, seed_val, prompt_info
     except Exception as e:
+        # Clear cache on error
+        if device == "cuda":
+            torch.cuda.empty_cache()
+            gc.collect()
         gr.Error(f"Image generation failed: {str(e)}")
         return [], seed_val, (
             f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #dd2c00; background: #fef5f5'>"
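Not part of the commit: the cache-clearing block now appears three times inside infer (at entry, after generation, and on error). A possible follow-up, sketched below under that assumption, would be to factor it into a small helper; the function name is hypothetical.

```python
# Sketch only: shared helper for the repeated cache-clearing pattern in infer().
import gc

import torch


def free_cuda_memory(device: str) -> None:
    """Release cached CUDA allocations and run the garbage collector."""
    if device == "cuda":
        torch.cuda.empty_cache()
        gc.collect()
```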
@@ -304,8 +330,6 @@ def infer(
             f"</div>"
         )
 
-MAX_SEED = np.iinfo(np.int32).max
-
 with gr.Blocks(title="Qwen Image Editor Fast") as demo:
     gr.Markdown("""
     <div style="text-align: center; background: linear-gradient(to right, #3a7bd5, #00d2ff); color: white; padding: 20px; border-radius: 8px;">