nickdigger committed on
Commit
b693e6d
·
verified ·
1 Parent(s): 951b327

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -52
app.py CHANGED
@@ -45,7 +45,6 @@ def fix_image_url(raw_url_or_path: str, host: Optional[str] = None) -> str:
45
  return f"{host}/gradio_api/file=/{p}"
46
  return raw_url_or_path
47
 
48
-
49
  def postprocess_caption(caption: str, max_chars: int = 1200) -> str:
50
  if not caption or not isinstance(caption, str):
51
  return caption or ""
@@ -61,7 +60,6 @@ def postprocess_caption(caption: str, max_chars: int = 1200) -> str:
61
  result += "."
62
  return result
63
 
64
-
65
  def force_clear_all_caches():
66
  try:
67
  if torch.cuda.is_available():
@@ -71,7 +69,6 @@ def force_clear_all_caches():
71
  except Exception:
72
  pass
73
 
74
-
75
  force_clear_all_caches()
76
 
77
  # ===== SETUP =====
@@ -115,22 +112,27 @@ DEFAULT_PROMPTS = {
115
  }
116
  }
117
 
118
- # ===== CAPTION GENERATION =====
119
  def safe_generate_caption_direct(image, system_prompt, user_prompt, max_chars=1200):
 
120
  try:
121
  if image is None:
122
  return "❌ No image provided"
 
123
  if not system_prompt.strip() or not user_prompt.strip():
124
  return "❌ Both system and user prompts are required"
 
125
  torch.cuda.empty_cache()
126
  gc.collect()
 
127
  convo = [
128
  {"role": "system", "content": system_prompt.strip()},
129
  {"role": "user", "content": user_prompt.strip()}
130
  ]
 
131
  convo_string = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
132
  inputs = processor(text=[convo_string], images=[image], return_tensors="pt").to("cuda")
133
- inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
 
134
  with torch.no_grad():
135
  output = model.generate(
136
  **inputs,
@@ -138,22 +140,39 @@ def safe_generate_caption_direct(image, system_prompt, user_prompt, max_chars=12
138
  do_sample=True,
139
  temperature=0.6,
140
  top_p=0.9,
 
141
  use_cache=True,
142
  pad_token_id=processor.tokenizer.eos_token_id,
143
- eos_token_id=processor.tokenizer.eos_token_id,
144
  )
145
- input_length = inputs["input_ids"].shape[1]
146
- result = processor.tokenizer.decode(output[0][input_length:], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  del inputs, output
148
  torch.cuda.empty_cache()
149
  gc.collect()
150
- return postprocess_caption(result, max_chars=max_chars) or "❌ Empty result"
 
 
 
151
  except Exception as e:
152
  torch.cuda.empty_cache()
153
  gc.collect()
154
  return f"❌ Error: {str(e)[:200]}"
155
 
156
-
157
  @spaces.GPU(duration=60)
158
  @torch.no_grad()
159
  def generate_caption(image, system, user):
@@ -161,7 +180,6 @@ def generate_caption(image, system, user):
161
  return "❌ Upload image first"
162
  return safe_generate_caption_direct(image, system, user)
163
 
164
-
165
  # ===== Q&A =====
166
  @spaces.GPU(duration=40)
167
  @torch.no_grad()
@@ -180,7 +198,8 @@ def answer_question(image, question):
180
  convo_string = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
181
  inputs = processor(text=[convo_string], images=[image], return_tensors="pt").to("cuda")
182
  inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
183
- output = model.generate(**inputs, max_new_tokens=300, do_sample=True, temperature=0.6, top_p=0.9)
 
184
  input_length = inputs["input_ids"].shape[1]
185
  result = processor.tokenizer.decode(output[0][input_length:], skip_special_tokens=True)
186
  del inputs, output
@@ -192,7 +211,6 @@ def answer_question(image, question):
192
  gc.collect()
193
  return f"❌ Q&A Error: {str(e)[:200]}"
194
 
195
-
196
  # ===== TEMPLATE HELPERS =====
197
  def insert_template(current_text, template_text, field_content):
198
  if not field_content.strip():
@@ -202,7 +220,6 @@ def insert_template(current_text, template_text, field_content):
202
  return current_text
203
  return (current_text.rstrip() + " " + formatted).strip()
204
 
205
-
206
  def create_template_functions():
207
  def insert_key(s, u, c):
208
  t = "Pay attention to these keywords: {content}."
@@ -218,7 +235,6 @@ def create_template_functions():
218
  return insert_template(s, t, c), insert_template(u, t, c)
219
  return insert_key, insert_que, insert_use, insert_not
220
 
221
-
222
  # ===== EXPORT =====
223
  def export_joycaption_data(keywords, custom_instructions, avoid, question, c1, c2, c3, qa, img):
224
  try:
@@ -228,9 +244,9 @@ def export_joycaption_data(keywords, custom_instructions, avoid, question, c1, c
228
  if custom_instructions.strip(): add["custom_instructions"] = custom_instructions.strip()
229
  if avoid.strip(): add["avoid"] = avoid.strip()
230
  if question.strip(): add["question"] = question.strip()
231
- if img.strip():
232
- add["image_local_path"] = img
233
- url = fix_image_url(img, host=(SPACE_HOST or ""))
234
  if url: add["image_url"] = url
235
  if c1.strip(): add["caption_casual"] = c1.strip()
236
  if c2.strip(): add["caption_friendly"] = c2.strip()
@@ -244,10 +260,9 @@ def export_joycaption_data(keywords, custom_instructions, avoid, question, c1, c
244
  except Exception as e:
245
  return f"❌ Export failed: {e}", None
246
 
247
-
248
  # ===== UI =====
249
  with gr.Blocks(title="JoyCaption Advanced Prompting System", theme=gr.themes.Soft()) as demo:
250
- gr.HTML("<style>textarea{overflow-y:hidden!important;}</style>")
251
  gr.HTML("<h1 style='text-align:center;margin-top:10px;'>🎨 JoyCaption Advanced Prompting System (v6.0)</h1><hr>")
252
 
253
  insert_key, insert_que, insert_use, insert_not = create_template_functions()
@@ -273,25 +288,25 @@ with gr.Blocks(title="JoyCaption Advanced Prompting System", theme=gr.themes.Sof
273
 
274
  with gr.Tab("📝 Casual") as tab1:
275
  gr.Markdown("**System Prompt**")
276
- system1 = gr.Textbox(show_label=False, value=DEFAULT_PROMPTS["casual"]["system"])
277
  gr.Markdown("**User Prompt**")
278
- user1 = gr.Textbox(show_label=False, value=DEFAULT_PROMPTS["casual"]["user"])
279
  gen1_btn = gr.Button("Generate Casual", variant="primary")
280
  out1 = gr.Textbox(lines=5, show_copy_button=True)
281
 
282
  with gr.Tab("🤝 Friendly") as tab2:
283
  gr.Markdown("**System Prompt**")
284
- system2 = gr.Textbox(show_label=False, value=DEFAULT_PROMPTS["friendly"]["system"])
285
  gr.Markdown("**User Prompt**")
286
- user2 = gr.Textbox(show_label=False, value=DEFAULT_PROMPTS["friendly"]["user"])
287
  gen2_btn = gr.Button("Generate Friendly", variant="primary")
288
  out2 = gr.Textbox(lines=5, show_copy_button=True)
289
 
290
  with gr.Tab("🔥 Erotic") as tab3:
291
  gr.Markdown("**System Prompt**")
292
- system3 = gr.Textbox(show_label=False, value=DEFAULT_PROMPTS["erotic"]["system"])
293
  gr.Markdown("**User Prompt**")
294
- user3 = gr.Textbox(show_label=False, value=DEFAULT_PROMPTS["erotic"]["user"])
295
  gen3_btn = gr.Button("Generate Erotic", variant="primary")
296
  out3 = gr.Textbox(lines=5, show_copy_button=True)
297
 
@@ -300,37 +315,59 @@ with gr.Blocks(title="JoyCaption Advanced Prompting System", theme=gr.themes.Sof
300
  export_out = gr.Textbox(visible=False)
301
  export_file = gr.File(visible=False)
302
 
 
303
  tab1.select(lambda: "casual", None, active_tab)
304
  tab2.select(lambda: "friendly", None, active_tab)
305
  tab3.select(lambda: "erotic", None, active_tab)
306
 
 
307
  gen1_btn.click(generate_caption, [image_input, system1, user1], out1)
308
  gen2_btn.click(generate_caption, [image_input, system2, user2], out2)
309
  gen3_btn.click(generate_caption, [image_input, system3, user3], out3)
310
  ask_btn.click(answer_question, [image_input, question_input], qa_output)
311
 
312
- def handle_template(btn_type, tab, s1, u1, s2, u2, s3, u3, k, c, q, a):
 
313
  key_f, que_f, use_f, not_f = create_template_functions()
314
- mapping = {"key": key_f, "que": que_f, "use": use_f, "not": not_f}
315
- fn = mapping.get(btn_type)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
316
  if not fn:
317
  return s1, u1, s2, u2, s3, u3
 
 
318
  if tab == "casual":
319
- s1, u1 = fn(s1, u1, k or c or q or a)
320
  elif tab == "friendly":
321
- s2, u2 = fn(s2, u2, k or c or q or a)
322
  elif tab == "erotic":
323
- s3, u3 = fn(s3, u3, k or c or q or a)
 
324
  return s1, u1, s2, u2, s3, u3
325
 
326
- for b, t in [(key_btn, "key"), (que_btn, "que"), (use_btn, "use"), (not_btn, "not")]:
327
- b.click(
 
328
  handle_template,
329
- [gr.State(t), active_tab, system1, user1, system2, user2, system3, user3,
330
  keywords_input, custom_instruction_input, question_input, avoid_input],
331
  [system1, user1, system2, user2, system3, user3],
332
  )
333
 
 
334
  def handle_export(k, c, a, q, c1, c2, c3, qa, img):
335
  msg, fd = export_joycaption_data(k, c, a, q, c1, c2, c3, qa, img)
336
  if fd:
@@ -348,21 +385,5 @@ with gr.Blocks(title="JoyCaption Advanced Prompting System", theme=gr.themes.Sof
348
  [export_out, export_file]
349
  )
350
 
351
- # JS autoresize fix for all tabs
352
- demo.load(js="""
353
- () => {
354
- function resizeAll() {
355
- document.querySelectorAll('textarea').forEach(t=>{
356
- t.style.height='auto';
357
- t.style.height=(t.scrollHeight+5)+'px';
358
- });
359
- }
360
- resizeAll();
361
- document.querySelectorAll('[role="tab"]').forEach(tab=>{
362
- tab.addEventListener('click', ()=>setTimeout(resizeAll,300));
363
- });
364
- }
365
- """)
366
-
367
  if __name__ == "__main__":
368
- demo.launch()
 
45
  return f"{host}/gradio_api/file=/{p}"
46
  return raw_url_or_path
47
 
 
48
  def postprocess_caption(caption: str, max_chars: int = 1200) -> str:
49
  if not caption or not isinstance(caption, str):
50
  return caption or ""
 
60
  result += "."
61
  return result
62
 
 
63
  def force_clear_all_caches():
64
  try:
65
  if torch.cuda.is_available():
 
69
  except Exception:
70
  pass
71
 
 
72
  force_clear_all_caches()
73
 
74
  # ===== SETUP =====
 
112
  }
113
  }
114
 
 
115
  def safe_generate_caption_direct(image, system_prompt, user_prompt, max_chars=1200):
116
+ """Generate caption using custom prompts"""
117
  try:
118
  if image is None:
119
  return "❌ No image provided"
120
+
121
  if not system_prompt.strip() or not user_prompt.strip():
122
  return "❌ Both system and user prompts are required"
123
+
124
  torch.cuda.empty_cache()
125
  gc.collect()
126
+
127
  convo = [
128
  {"role": "system", "content": system_prompt.strip()},
129
  {"role": "user", "content": user_prompt.strip()}
130
  ]
131
+
132
  convo_string = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
133
  inputs = processor(text=[convo_string], images=[image], return_tensors="pt").to("cuda")
134
+ inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
135
+
136
  with torch.no_grad():
137
  output = model.generate(
138
  **inputs,
 
140
  do_sample=True,
141
  temperature=0.6,
142
  top_p=0.9,
143
+ top_k=None,
144
  use_cache=True,
145
  pad_token_id=processor.tokenizer.eos_token_id,
146
+ eos_token_id=processor.tokenizer.eos_token_id
147
  )
148
+
149
+ if output is None or len(output) == 0:
150
+ return "❌ No output generated"
151
+
152
+ if 'input_ids' in inputs and len(inputs['input_ids'].shape) >= 2:
153
+ input_length = inputs['input_ids'].shape[1]
154
+ if len(output[0]) > input_length:
155
+ generate_ids = output[0][input_length:]
156
+ result = processor.tokenizer.decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
157
+ else:
158
+ result = processor.tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
159
+ else:
160
+ result = processor.tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
161
+
162
+ result = result.strip()
163
+
164
  del inputs, output
165
  torch.cuda.empty_cache()
166
  gc.collect()
167
+
168
+ final_result = postprocess_caption(result, max_chars=max_chars)
169
+ return final_result if final_result else "❌ Empty result"
170
+
171
  except Exception as e:
172
  torch.cuda.empty_cache()
173
  gc.collect()
174
  return f"❌ Error: {str(e)[:200]}"
175
 
 
176
  @spaces.GPU(duration=60)
177
  @torch.no_grad()
178
  def generate_caption(image, system, user):
 
180
  return "❌ Upload image first"
181
  return safe_generate_caption_direct(image, system, user)
182
 
 
183
  # ===== Q&A =====
184
  @spaces.GPU(duration=40)
185
  @torch.no_grad()
 
198
  convo_string = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
199
  inputs = processor(text=[convo_string], images=[image], return_tensors="pt").to("cuda")
200
  inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
201
+ with torch.no_grad():
202
+ output = model.generate(**inputs, max_new_tokens=300, do_sample=True, temperature=0.6, top_p=0.9)
203
  input_length = inputs["input_ids"].shape[1]
204
  result = processor.tokenizer.decode(output[0][input_length:], skip_special_tokens=True)
205
  del inputs, output
 
211
  gc.collect()
212
  return f"❌ Q&A Error: {str(e)[:200]}"
213
 
 
214
  # ===== TEMPLATE HELPERS =====
215
  def insert_template(current_text, template_text, field_content):
216
  if not field_content.strip():
 
220
  return current_text
221
  return (current_text.rstrip() + " " + formatted).strip()
222
 
 
223
  def create_template_functions():
224
  def insert_key(s, u, c):
225
  t = "Pay attention to these keywords: {content}."
 
235
  return insert_template(s, t, c), insert_template(u, t, c)
236
  return insert_key, insert_que, insert_use, insert_not
237
 
 
238
  # ===== EXPORT =====
239
  def export_joycaption_data(keywords, custom_instructions, avoid, question, c1, c2, c3, qa, img):
240
  try:
 
244
  if custom_instructions.strip(): add["custom_instructions"] = custom_instructions.strip()
245
  if avoid.strip(): add["avoid"] = avoid.strip()
246
  if question.strip(): add["question"] = question.strip()
247
+ if hasattr(img, '__str__') and str(img).strip():
248
+ add["image_local_path"] = str(img)
249
+ url = fix_image_url(str(img), host=(SPACE_HOST or ""))
250
  if url: add["image_url"] = url
251
  if c1.strip(): add["caption_casual"] = c1.strip()
252
  if c2.strip(): add["caption_friendly"] = c2.strip()
 
260
  except Exception as e:
261
  return f"❌ Export failed: {e}", None
262
 
 
263
  # ===== UI =====
264
  with gr.Blocks(title="JoyCaption Advanced Prompting System", theme=gr.themes.Soft()) as demo:
265
+ gr.HTML("<style>textarea{resize:none!important;}</style>")
266
  gr.HTML("<h1 style='text-align:center;margin-top:10px;'>🎨 JoyCaption Advanced Prompting System (v6.0)</h1><hr>")
267
 
268
  insert_key, insert_que, insert_use, insert_not = create_template_functions()
 
288
 
289
  with gr.Tab("📝 Casual") as tab1:
290
  gr.Markdown("**System Prompt**")
291
+ system1 = gr.Textbox(show_label=False, value=DEFAULT_PROMPTS["casual"]["system"], lines=3)
292
  gr.Markdown("**User Prompt**")
293
+ user1 = gr.Textbox(show_label=False, value=DEFAULT_PROMPTS["casual"]["user"], lines=3)
294
  gen1_btn = gr.Button("Generate Casual", variant="primary")
295
  out1 = gr.Textbox(lines=5, show_copy_button=True)
296
 
297
  with gr.Tab("🤝 Friendly") as tab2:
298
  gr.Markdown("**System Prompt**")
299
+ system2 = gr.Textbox(show_label=False, value=DEFAULT_PROMPTS["friendly"]["system"], lines=3)
300
  gr.Markdown("**User Prompt**")
301
+ user2 = gr.Textbox(show_label=False, value=DEFAULT_PROMPTS["friendly"]["user"], lines=3)
302
  gen2_btn = gr.Button("Generate Friendly", variant="primary")
303
  out2 = gr.Textbox(lines=5, show_copy_button=True)
304
 
305
  with gr.Tab("🔥 Erotic") as tab3:
306
  gr.Markdown("**System Prompt**")
307
+ system3 = gr.Textbox(show_label=False, value=DEFAULT_PROMPTS["erotic"]["system"], lines=3)
308
  gr.Markdown("**User Prompt**")
309
+ user3 = gr.Textbox(show_label=False, value=DEFAULT_PROMPTS["erotic"]["user"], lines=3)
310
  gen3_btn = gr.Button("Generate Erotic", variant="primary")
311
  out3 = gr.Textbox(lines=5, show_copy_button=True)
312
 
 
315
  export_out = gr.Textbox(visible=False)
316
  export_file = gr.File(visible=False)
317
 
318
+ # Tab selection tracking
319
  tab1.select(lambda: "casual", None, active_tab)
320
  tab2.select(lambda: "friendly", None, active_tab)
321
  tab3.select(lambda: "erotic", None, active_tab)
322
 
323
+ # Caption generation
324
  gen1_btn.click(generate_caption, [image_input, system1, user1], out1)
325
  gen2_btn.click(generate_caption, [image_input, system2, user2], out2)
326
  gen3_btn.click(generate_caption, [image_input, system3, user3], out3)
327
  ask_btn.click(answer_question, [image_input, question_input], qa_output)
328
 
329
+ # Template insertion with proper field mapping
330
+ def handle_template(btn_type, tab, s1, u1, s2, u2, s3, u3, keywords, custom, question, avoid):
331
  key_f, que_f, use_f, not_f = create_template_functions()
332
+
333
+ # Map button type to field content
334
+ content_map = {
335
+ "key": keywords,
336
+ "que": question,
337
+ "use": custom,
338
+ "not": avoid
339
+ }
340
+
341
+ content = content_map.get(btn_type, "")
342
+ if not content.strip():
343
+ return s1, u1, s2, u2, s3, u3
344
+
345
+ # Map button type to function
346
+ fn_map = {"key": key_f, "que": que_f, "use": use_f, "not": not_f}
347
+ fn = fn_map.get(btn_type)
348
  if not fn:
349
  return s1, u1, s2, u2, s3, u3
350
+
351
+ # Apply to correct tab
352
  if tab == "casual":
353
+ s1, u1 = fn(s1, u1, content)
354
  elif tab == "friendly":
355
+ s2, u2 = fn(s2, u2, content)
356
  elif tab == "erotic":
357
+ s3, u3 = fn(s3, u3, content)
358
+
359
  return s1, u1, s2, u2, s3, u3
360
 
361
+ # Connect template buttons
362
+ for btn, btn_type in [(key_btn, "key"), (que_btn, "que"), (use_btn, "use"), (not_btn, "not")]:
363
+ btn.click(
364
  handle_template,
365
+ [gr.State(btn_type), active_tab, system1, user1, system2, user2, system3, user3,
366
  keywords_input, custom_instruction_input, question_input, avoid_input],
367
  [system1, user1, system2, user2, system3, user3],
368
  )
369
 
370
+ # Export functionality
371
  def handle_export(k, c, a, q, c1, c2, c3, qa, img):
372
  msg, fd = export_joycaption_data(k, c, a, q, c1, c2, c3, qa, img)
373
  if fd:
 
385
  [export_out, export_file]
386
  )
387
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
388
  if __name__ == "__main__":
389
+ demo.launch()