LPX55 committed on
Commit
4bd1e2d
·
verified ·
1 Parent(s): 7646ed2

Update app_local.py

Browse files
Files changed (1) hide show
  1. app_local.py +80 -56
app_local.py CHANGED
@@ -28,6 +28,7 @@ logger = logging.getLogger(__name__)
28
  REWRITER_MODEL = "Qwen/Qwen1.5-4B-Chat" # Upgraded to 4B for better JSON handling
29
  dtype = torch.bfloat16
30
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
31
 
32
  # Quantization configuration
33
  bnb_config = BitsAndBytesConfig(
@@ -82,29 +83,40 @@ Please provide the rewritten instruction in a clean `json` format as:
82
  }
83
  '''
84
 
 
85
  def extract_json_response(model_output: str) -> str:
86
  """Extract rewritten instruction from potentially messy JSON output"""
87
- # New: Remove code block markers first
88
  model_output = re.sub(r'```(?:json)?\s*', '', model_output)
89
-
90
  try:
91
- # Try to find the JSON portion in the output
92
  start_idx = model_output.find('{')
93
  end_idx = model_output.rfind('}')
94
- if start_idx == -1 or end_idx == -1:
95
- return None
96
 
 
 
 
 
 
97
  # Expand to the full object including outer braces
98
  end_idx += 1 # Include the closing brace
99
-
100
  json_str = model_output[start_idx:end_idx]
101
 
102
- # Improved quote handling for values
103
- json_str = re.sub(r'(\w+)\s*:', r'"\1":', json_str) # Quote keys
104
- json_str = re.sub(r':\s*([^"\s{[]+)', r': "\1"', json_str) # Quote unquoted string values
105
 
106
- # Parse JSON
107
- data = json.loads(json_str)
 
 
 
 
 
 
 
 
 
 
108
 
109
  # Extract rewritten prompt from possible key variations
110
  possible_keys = [
@@ -124,20 +136,20 @@ def extract_json_response(model_output: str) -> str:
124
  for value in data.values():
125
  if isinstance(value, dict) and "Rewritten" in value:
126
  return value["Rewritten"].strip()
127
-
128
  # Try to find any string value that looks like an instruction
129
  str_values = [v for v in data.values() if isinstance(v, str) and 10 < len(v) < 500]
130
  if str_values:
131
  return str_values[0].strip()
132
-
133
  except Exception as e:
134
  print(f"JSON parse error: {str(e)}")
135
-
136
  return None
137
 
 
138
  def polish_prompt(original_prompt: str) -> str:
139
  """Enhanced prompt rewriting using original system prompt with JSON handling"""
140
- # load_rewriter()
141
 
142
  # Format as Qwen chat
143
  messages = [
@@ -156,11 +168,11 @@ def polish_prompt(original_prompt: str) -> str:
156
  with torch.no_grad():
157
  generated_ids = rewriter_model.generate(
158
  **model_inputs,
159
- max_new_tokens=256, # Reduced for better quality
160
  do_sample=True,
161
- temperature=0.5, # Less creative but more focused
162
  top_p=0.8,
163
- repetition_penalty= 1.1,
164
  no_repeat_ngram_size=3,
165
  pad_token_id=rewriter_tokenizer.eos_token_id
166
  )
@@ -171,41 +183,34 @@ def polish_prompt(original_prompt: str) -> str:
171
  skip_special_tokens=True
172
  ).strip()
173
 
174
- # New: Last-resort JSON content extraction
175
- json_str = enhanced
176
- if '```' in enhanced:
177
- parts = enhanced.split('```')
178
- if len(parts) >= 3:
179
- json_str = parts[1] # Take content between first set of ```
180
 
181
  # Try to extract JSON content
182
- rewritten_prompt = extract_json_response(json_str if '```' in enhanced else enhanced)
183
 
184
  if rewritten_prompt:
185
  # Clean up remaining artifacts
186
  rewritten_prompt = re.sub(r'(Replace|Change|Add) "(.*?)"', r'\1 \2', rewritten_prompt)
187
  rewritten_prompt = rewritten_prompt.replace('\\"', '"').replace('\\n', ' ')
188
  return rewritten_prompt
189
-
190
- # Fallback cleanup if JSON extraction fails
191
- if '```' in enhanced:
192
- # Extract content from code blocks
193
- parts = enhanced.split('```')
194
- if len(parts) >= 3:
195
- rewritten_prompt = parts[1].strip()
 
196
  else:
197
  rewritten_prompt = enhanced
198
- else:
199
- rewritten_prompt = enhanced
200
-
201
- # Improved cleaning of fallback output
202
- rewritten_prompt = re.sub(r'.*{.*}.*', '', rewritten_prompt)
203
- rewritten_prompt = re.sub(r'\s\s+', ' ', rewritten_prompt).strip()
204
- if ': ' in rewritten_prompt:
205
- rewritten_prompt = rewritten_prompt.split(': ', 1)[-1].strip()
206
-
207
- return rewritten_prompt[:200] # Ensure reasonable length
208
-
209
  # Load main image editing pipeline
210
  pipe = QwenImageEditPipeline.from_pretrained(
211
  "Qwen/Qwen-Image-Edit",
@@ -233,7 +238,6 @@ else:
233
  # rewriter_model = None
234
  # torch.cuda.empty_cache()
235
  # gc.collect()
236
-
237
  @spaces.GPU()
238
  def infer(
239
  image,
@@ -246,6 +250,11 @@ def infer(
246
  num_images_per_prompt=1,
247
  ):
248
  """Image editing endpoint with optimized prompt handling"""
 
 
 
 
 
249
  original_prompt = prompt
250
  prompt_info = ""
251
 
@@ -253,15 +262,24 @@ def infer(
253
  if rewrite_prompt:
254
  try:
255
  enhanced_instruction = polish_prompt(original_prompt)
256
- prompt_info = (
257
- f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #4CAF50; background: #f5f9fe'>"
258
- f"<h4 style='margin-top: 0;'>🚀 Prompt Enhancement</h4>"
259
- f"<p><strong>Original:</strong> {original_prompt}</p>"
260
- f"<p><strong style='color:#2E7D32;'>Enhanced:</strong> {enhanced_instruction}</p>"
261
- f"</div>"
262
- )
263
- prompt = enhanced_instruction
 
 
 
 
 
 
 
 
264
  except Exception as e:
 
265
  gr.Warning(f"Prompt enhancement failed: {str(e)}")
266
  prompt_info = (
267
  f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #FF5252; background: #fef5f5'>"
@@ -277,7 +295,6 @@ def infer(
277
  f"</div>"
278
  )
279
 
280
-
281
  # Set seed for reproducibility
282
  seed_val = seed if not randomize_seed else random.randint(0, MAX_SEED)
283
  generator = torch.Generator(device=device).manual_seed(seed_val)
@@ -293,9 +310,18 @@ def infer(
293
  true_cfg_scale=true_guidance_scale,
294
  num_images_per_prompt=num_images_per_prompt
295
  ).images
296
- return edited_images, seed_val, prompt_info
297
 
 
 
 
 
 
 
298
  except Exception as e:
 
 
 
 
299
  gr.Error(f"Image generation failed: {str(e)}")
300
  return [], seed_val, (
301
  f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #dd2c00; background: #fef5f5'>"
@@ -304,8 +330,6 @@ def infer(
304
  f"</div>"
305
  )
306
 
307
- MAX_SEED = np.iinfo(np.int32).max
308
-
309
  with gr.Blocks(title="Qwen Image Editor Fast") as demo:
310
  gr.Markdown("""
311
  <div style="text-align: center; background: linear-gradient(to right, #3a7bd5, #00d2ff); color: white; padding: 20px; border-radius: 8px;">
 
28
  REWRITER_MODEL = "Qwen/Qwen1.5-4B-Chat" # Upgraded to 4B for better JSON handling
29
  dtype = torch.bfloat16
30
  device = "cuda" if torch.cuda.is_available() else "cpu"
31
+ MAX_SEED = np.iinfo(np.int32).max
32
 
33
  # Quantization configuration
34
  bnb_config = BitsAndBytesConfig(
 
83
  }
84
  '''
85
 
86
+
87
  def extract_json_response(model_output: str) -> str:
88
  """Extract rewritten instruction from potentially messy JSON output"""
89
+ # Remove code block markers first
90
  model_output = re.sub(r'```(?:json)?\s*', '', model_output)
 
91
  try:
92
+ # Find the JSON portion in the output
93
  start_idx = model_output.find('{')
94
  end_idx = model_output.rfind('}')
 
 
95
 
96
+ # Fix the condition - check if brackets were found
97
+ if start_idx == -1 or end_idx == -1 or start_idx >= end_idx:
98
+ print(f"No valid JSON structure found in output. Start: {start_idx}, End: {end_idx}")
99
+ return None
100
+
101
  # Expand to the full object including outer braces
102
  end_idx += 1 # Include the closing brace
 
103
  json_str = model_output[start_idx:end_idx]
104
 
105
+ # Handle potential markdown or other formatting
106
+ json_str = json_str.strip()
 
107
 
108
+ # Try to parse JSON directly first
109
+ try:
110
+ data = json.loads(json_str)
111
+ except json.JSONDecodeError as e:
112
+ print(f"Direct JSON parsing failed: {e}")
113
+ # If direct parsing fails, try cleanup
114
+ # Quote keys properly
115
+ json_str = re.sub(r'([^{}[\],\s"]+)(?=\s*:)', r'"\1"', json_str)
116
+ # Remove any trailing commas that might cause issues
117
+ json_str = re.sub(r',(\s*[}\]])', r'\1', json_str)
118
+ # Try parsing again
119
+ data = json.loads(json_str)
120
 
121
  # Extract rewritten prompt from possible key variations
122
  possible_keys = [
 
136
  for value in data.values():
137
  if isinstance(value, dict) and "Rewritten" in value:
138
  return value["Rewritten"].strip()
139
+
140
  # Try to find any string value that looks like an instruction
141
  str_values = [v for v in data.values() if isinstance(v, str) and 10 < len(v) < 500]
142
  if str_values:
143
  return str_values[0].strip()
144
+
145
  except Exception as e:
146
  print(f"JSON parse error: {str(e)}")
147
+ print(f"Model output was: {model_output}")
148
  return None
149
 
150
+
151
  def polish_prompt(original_prompt: str) -> str:
152
  """Enhanced prompt rewriting using original system prompt with JSON handling"""
 
153
 
154
  # Format as Qwen chat
155
  messages = [
 
168
  with torch.no_grad():
169
  generated_ids = rewriter_model.generate(
170
  **model_inputs,
171
+ max_new_tokens=256,
172
  do_sample=True,
173
+ temperature=0.5,
174
  top_p=0.8,
175
+ repetition_penalty=1.1,
176
  no_repeat_ngram_size=3,
177
  pad_token_id=rewriter_tokenizer.eos_token_id
178
  )
 
183
  skip_special_tokens=True
184
  ).strip()
185
 
186
+ print(f"Model raw output: {enhanced}") # Debug logging
 
 
 
 
 
187
 
188
  # Try to extract JSON content
189
+ rewritten_prompt = extract_json_response(enhanced)
190
 
191
  if rewritten_prompt:
192
  # Clean up remaining artifacts
193
  rewritten_prompt = re.sub(r'(Replace|Change|Add) "(.*?)"', r'\1 \2', rewritten_prompt)
194
  rewritten_prompt = rewritten_prompt.replace('\\"', '"').replace('\\n', ' ')
195
  return rewritten_prompt
196
+ else:
197
+ # Fallback: try to extract from code blocks or just return cleaned content
198
+ if '```' in enhanced:
199
+ parts = enhanced.split('```')
200
+ if len(parts) >= 2:
201
+ rewritten_prompt = parts[1].strip()
202
+ else:
203
+ rewritten_prompt = enhanced
204
  else:
205
  rewritten_prompt = enhanced
206
+
207
+ # Basic cleanup
208
+ rewritten_prompt = re.sub(r'\s\s+', ' ', rewritten_prompt).strip()
209
+ if ': ' in rewritten_prompt:
210
+ rewritten_prompt = rewritten_prompt.split(': ', 1)[-1].strip()
211
+
212
+ return rewritten_prompt[:200] if rewritten_prompt else original_prompt
213
+
 
 
 
214
  # Load main image editing pipeline
215
  pipe = QwenImageEditPipeline.from_pretrained(
216
  "Qwen/Qwen-Image-Edit",
 
238
  # rewriter_model = None
239
  # torch.cuda.empty_cache()
240
  # gc.collect()
 
241
  @spaces.GPU()
242
  def infer(
243
  image,
 
250
  num_images_per_prompt=1,
251
  ):
252
  """Image editing endpoint with optimized prompt handling"""
253
+ # Clear cache at start
254
+ if device == "cuda":
255
+ torch.cuda.empty_cache()
256
+ gc.collect()
257
+
258
  original_prompt = prompt
259
  prompt_info = ""
260
 
 
262
  if rewrite_prompt:
263
  try:
264
  enhanced_instruction = polish_prompt(original_prompt)
265
+ if enhanced_instruction and enhanced_instruction != original_prompt:
266
+ prompt_info = (
267
+ f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #4CAF50; background: #f5f9fe'>"
268
+ f"<h4 style='margin-top: 0;'>🚀 Prompt Enhancement</h4>"
269
+ f"<p><strong>Original:</strong> {original_prompt}</p>"
270
+ f"<p><strong style='color:#2E7D32;'>Enhanced:</strong> {enhanced_instruction}</p>"
271
+ f"</div>"
272
+ )
273
+ prompt = enhanced_instruction
274
+ else:
275
+ prompt_info = (
276
+ f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #FF9800; background: #fff8f0'>"
277
+ f"<h4 style='margin-top: 0;'>📝 Prompt Enhancement</h4>"
278
+ f"<p>No enhancement applied or enhancement failed</p>"
279
+ f"</div>"
280
+ )
281
  except Exception as e:
282
+ print(f"Prompt enhancement error: {str(e)}") # Debug logging
283
  gr.Warning(f"Prompt enhancement failed: {str(e)}")
284
  prompt_info = (
285
  f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #FF5252; background: #fef5f5'>"
 
295
  f"</div>"
296
  )
297
 
 
298
  # Set seed for reproducibility
299
  seed_val = seed if not randomize_seed else random.randint(0, MAX_SEED)
300
  generator = torch.Generator(device=device).manual_seed(seed_val)
 
310
  true_cfg_scale=true_guidance_scale,
311
  num_images_per_prompt=num_images_per_prompt
312
  ).images
 
313
 
314
+ # Clear cache after generation
315
+ if device == "cuda":
316
+ torch.cuda.empty_cache()
317
+ gc.collect()
318
+
319
+ return edited_images, seed_val, prompt_info
320
  except Exception as e:
321
+ # Clear cache on error
322
+ if device == "cuda":
323
+ torch.cuda.empty_cache()
324
+ gc.collect()
325
  gr.Error(f"Image generation failed: {str(e)}")
326
  return [], seed_val, (
327
  f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #dd2c00; background: #fef5f5'>"
 
330
  f"</div>"
331
  )
332
 
 
 
333
  with gr.Blocks(title="Qwen Image Editor Fast") as demo:
334
  gr.Markdown("""
335
  <div style="text-align: center; background: linear-gradient(to right, #3a7bd5, #00d2ff); color: white; padding: 20px; border-radius: 8px;">