nickdigger commited on
Commit
d0aa398
·
verified ·
1 Parent(s): c31d3ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -286
app.py CHANGED
@@ -18,22 +18,17 @@ import torch
18
  from transformers import LlavaForConditionalGeneration, AutoProcessor
19
  from PIL import Image
20
  import tempfile, gc, os, shutil, json, time, re
21
- from pathlib import Path
22
- from datetime import datetime
23
- from typing import Optional
24
  from urllib.parse import urlparse
 
25
 
26
  # ===== BUILT-IN UTILITY FUNCTIONS =====
27
  def fix_image_url(raw_url_or_path: str, host: Optional[str] = None) -> str:
28
- """Convert local image paths to URLs for export"""
29
  if not raw_url_or_path:
30
  return raw_url_or_path
31
-
32
  try:
33
  parsed = urlparse(raw_url_or_path)
34
  except Exception:
35
  parsed = None
36
-
37
  if parsed and parsed.scheme and parsed.netloc:
38
  full = raw_url_or_path
39
  if "/file=" in full and "/gradio_api/file=" not in full:
@@ -41,7 +36,6 @@ def fix_image_url(raw_url_or_path: str, host: Optional[str] = None) -> str:
41
  if "file=" in full and "/gradio_api/file=" not in full and "/gradio_api" not in full:
42
  full = full.replace("file=", "gradio_api/file=")
43
  return full
44
-
45
  if raw_url_or_path.startswith("/tmp/") or raw_url_or_path.startswith("tmp/"):
46
  if not host:
47
  return raw_url_or_path
@@ -52,18 +46,12 @@ def fix_image_url(raw_url_or_path: str, host: Optional[str] = None) -> str:
52
  if p.startswith("/"):
53
  p = p[1:]
54
  return f"{host}/gradio_api/file=/{p}"
55
-
56
  return raw_url_or_path
57
 
58
  def postprocess_caption(caption: str, max_chars: int = 1200) -> str:
59
- """Minimal caption post-processing - just basic cleanup"""
60
  if not caption or not isinstance(caption, str):
61
  return caption or ""
62
-
63
- # Only remove leading "a photo of" phrases
64
  result = re.sub(r'^(a photo of|an image of|a picture of|this is a photo of|this shows)\s*', '', caption.strip(), flags=re.IGNORECASE)
65
-
66
- # Only truncate if extremely long
67
  if max_chars and len(result) > max_chars:
68
  truncate_point = max_chars
69
  for i in range(len(result) - 1, max(0, max_chars - 100), -1):
@@ -71,27 +59,22 @@ def postprocess_caption(caption: str, max_chars: int = 1200) -> str:
71
  truncate_point = i + 1
72
  break
73
  result = result[:truncate_point].strip()
74
-
75
  if result and not result.endswith(('.', '!', '?')):
76
- result = result + "."
77
-
78
  return result
79
 
80
- # ===== CACHE CLEARING =====
81
  def force_clear_all_caches():
82
- """Force clear all possible caches"""
83
  try:
84
  if torch.cuda.is_available():
85
  torch.cuda.empty_cache()
86
  torch.cuda.synchronize()
87
  gc.collect()
88
- print("🧹 All caches cleared!")
89
  except Exception as e:
90
  print(f"⚠️ Cache clear warning: {e}")
91
 
92
  force_clear_all_caches()
93
 
94
- # ===== Setup =====
95
  _tmpdir = tempfile.gettempdir()
96
  os.environ["HF_HOME"] = os.path.join(_tmpdir, "hf_cache")
97
  os.environ["TRANSFORMERS_CACHE"] = os.path.join(_tmpdir, "transformers_cache")
@@ -148,26 +131,20 @@ DEFAULT_PROMPTS = {
148
  }
149
 
150
  def safe_generate_caption_direct(image, system_prompt, user_prompt, max_chars=1200):
151
- """Generate caption using custom prompts"""
152
  try:
153
  if image is None:
154
  return "❌ No image provided"
155
-
156
  if not system_prompt.strip() or not user_prompt.strip():
157
  return "❌ Both system and user prompts are required"
158
-
159
  torch.cuda.empty_cache()
160
  gc.collect()
161
-
162
  convo = [
163
  {"role": "system", "content": system_prompt.strip()},
164
  {"role": "user", "content": user_prompt.strip()}
165
  ]
166
-
167
  convo_string = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
168
  inputs = processor(text=[convo_string], images=[image], return_tensors="pt").to("cuda")
169
  inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
170
-
171
  with torch.no_grad():
172
  output = model.generate(
173
  **inputs,
@@ -180,167 +157,90 @@ def safe_generate_caption_direct(image, system_prompt, user_prompt, max_chars=12
180
  pad_token_id=processor.tokenizer.eos_token_id,
181
  eos_token_id=processor.tokenizer.eos_token_id
182
  )
183
-
184
- if output is None or len(output) == 0:
185
  return "❌ No output generated"
186
-
187
- if 'input_ids' in inputs and len(inputs['input_ids'].shape) >= 2:
188
- input_length = inputs['input_ids'].shape[1]
189
- if len(output[0]) > input_length:
190
- generate_ids = output[0][input_length:]
191
- result = processor.tokenizer.decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
192
- else:
193
- result = processor.tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
194
- else:
195
- result = processor.tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
196
-
197
- result = result.strip()
198
-
199
  del inputs, output
200
  torch.cuda.empty_cache()
201
  gc.collect()
202
-
203
- final_result = postprocess_caption(result, max_chars=max_chars)
204
- return final_result if final_result else "❌ Empty result"
205
-
206
  except Exception as e:
207
  torch.cuda.empty_cache()
208
  gc.collect()
209
  return f"❌ Error: {str(e)[:200]}"
210
 
211
- # Individual caption generation functions
212
  @spaces.GPU(duration=60)
213
  @torch.no_grad()
214
  def generate_caption_1(image, system1, user1):
215
- if not image:
216
- return "❌ Upload image first"
217
  return safe_generate_caption_direct(image, system1, user1)
218
 
219
  @spaces.GPU(duration=60)
220
  @torch.no_grad()
221
  def generate_caption_2(image, system2, user2):
222
- if not image:
223
- return "❌ Upload image first"
224
  return safe_generate_caption_direct(image, system2, user2)
225
 
226
  @spaces.GPU(duration=60)
227
  @torch.no_grad()
228
  def generate_caption_3(image, system3, user3):
229
- if not image:
230
- return "❌ Upload image first"
231
  return safe_generate_caption_direct(image, system3, user3)
232
 
233
  @spaces.GPU(duration=40)
234
  @torch.no_grad()
235
  def answer_question(image, question):
236
- """Q&A function"""
237
- if not image:
238
- return "❌ Upload image first"
239
- if not question or not question.strip():
240
- return "❌ Please ask a question"
241
-
242
  try:
243
  torch.cuda.empty_cache()
244
  gc.collect()
245
-
246
- convo = [
247
- {"role": "system", "content": "You are a helpful image captioner."},
248
- {"role": "user", "content": question.strip()}
249
- ]
250
-
251
  convo_string = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
252
  inputs = processor(text=[convo_string], images=[image], return_tensors="pt").to("cuda")
253
  inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
254
-
255
- with torch.no_grad():
256
- output = model.generate(
257
- **inputs,
258
- max_new_tokens=300,
259
- do_sample=True,
260
- temperature=0.6,
261
- top_p=0.9,
262
- top_k=None,
263
- use_cache=True,
264
- pad_token_id=processor.tokenizer.eos_token_id,
265
- eos_token_id=processor.tokenizer.eos_token_id
266
- )
267
-
268
- if 'input_ids' in inputs and len(inputs['input_ids'].shape) >= 2:
269
- input_length = inputs['input_ids'].shape[1]
270
- if len(output[0]) > input_length:
271
- generate_ids = output[0][input_length:]
272
- result = processor.tokenizer.decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
273
- else:
274
- result = processor.tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
275
- else:
276
- result = processor.tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
277
-
278
- result = result.strip()
279
-
280
  del inputs, output
281
  torch.cuda.empty_cache()
282
  gc.collect()
283
-
284
- final_result = postprocess_caption(result, max_chars=500)
285
- return final_result if final_result else "❌ No answer generated"
286
-
287
  except Exception as e:
288
  torch.cuda.empty_cache()
289
  gc.collect()
290
  return f"❌ Q&A Error: {str(e)[:200]}"
291
 
292
- # Helper functions for template insertion
293
  def insert_template(current_text, template_text, field_content):
294
- """Insert template at the end of current text if not already present"""
295
- if not field_content.strip():
296
- return current_text
297
-
298
  formatted_template = template_text.format(content=field_content.strip())
299
-
300
- # Check if this template is already in the text (prevent duplicates)
301
- if formatted_template in current_text:
302
- return current_text
303
-
304
- # Add template at the end with proper spacing
305
- if current_text.strip():
306
- return current_text.rstrip() + " " + formatted_template
307
- else:
308
- return formatted_template
309
 
310
  def create_template_functions():
311
- """Create template insertion functions for each button type"""
312
-
313
  def insert_key(system_text, user_text, keywords_content):
314
  template = "Pay attention to these keywords: {content}."
315
- return (
316
- insert_template(system_text, template, keywords_content),
317
- insert_template(user_text, template, keywords_content)
318
- )
319
-
320
  def insert_que(system_text, user_text, question_content):
321
  template = "Answer this question: {content}."
322
- return (
323
- insert_template(system_text, template, question_content),
324
- insert_template(user_text, template, question_content)
325
- )
326
-
327
  def insert_use(system_text, user_text, custom_content):
328
  template = "Make sure that you mention: {content}."
329
- return (
330
- insert_template(system_text, template, custom_content),
331
- insert_template(user_text, template, custom_content)
332
- )
333
-
334
  def insert_not(system_text, user_text, avoid_content):
335
  template = "Do NOT mention: {content}."
336
- return (
337
- insert_template(system_text, template, avoid_content),
338
- insert_template(user_text, template, avoid_content)
339
- )
340
-
341
  return insert_key, insert_que, insert_use, insert_not
342
 
343
- # Export function
344
  def export_joycaption_data(keywords, custom_instructions, avoid, question, cap1, cap2, cap3, qa_answer, image_path=""):
345
  try:
346
  data = {"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), "source":"JoyCaption","data":{}}
@@ -356,184 +256,84 @@ def export_joycaption_data(keywords, custom_instructions, avoid, question, cap1,
356
  if cap2.strip(): data["data"]["caption_friendly"]=cap2.strip()
357
  if cap3.strip(): data["data"]["caption_erotic"]=cap3.strip()
358
  if qa_answer.strip(): data["data"]["qa_answer"]=qa_answer.strip()
359
- if not data["data"]:
360
- return "❌ No data to export", None
361
  js = json.dumps(data, indent=2, ensure_ascii=False)
362
  fn = f"joycaption_{time.strftime('%Y%m%d_%H%M%S')}.json"
363
  return f"✅ Exported {len(data['data'])} fields", (js, fn)
364
  except Exception as e:
365
  return f"❌ Export failed: {str(e)}", None
366
 
367
- # Create the Gradio interface
368
  with gr.Blocks(title="JoyCaption Advanced Prompting System", theme=gr.themes.Soft()) as demo:
369
-
370
  gr.HTML(TITLE)
371
-
372
- # Get template functions
373
  insert_key, insert_que, insert_use, insert_not = create_template_functions()
374
-
375
  with gr.Row():
376
- # Left Column - Input Fields
377
  with gr.Column(scale=1):
378
- image_input = gr.Image(type="pil", label="📸 Upload Image", height=400)
379
-
380
- keywords_input = gr.Textbox(
381
- label="🏷️ Keywords",
382
- lines=2,
383
- placeholder="Enter keywords (available as 'key' template)",
384
- info="Use 'key' button to insert into prompts"
385
- )
386
-
387
- custom_instruction_input = gr.Textbox(
388
- label="🎯 Custom Instruction",
389
- lines=2,
390
- placeholder="Enter custom instructions (available as 'use' template)",
391
- info="Use 'use' button to insert into prompts"
392
- )
393
-
394
- avoid_input = gr.Textbox(
395
- label="🚫 Avoid",
396
- lines=2,
397
- placeholder="Things to avoid mentioning (available as 'not' template)",
398
- info="Use 'not' button to insert into prompts"
399
- )
400
-
401
- question_input = gr.Textbox(
402
- label="❓ Question",
403
- lines=2,
404
- placeholder="Ask a question about the image (available as 'que' template)",
405
- info="Use 'que' button to insert into prompts"
406
- )
407
-
408
- ask_btn = gr.Button("❓ Ask Question", variant="secondary")
409
- qa_output = gr.Textbox(label="Q&A Answer", lines=4, show_copy_button=True)
410
-
411
- # Right Column - Caption Generation
412
- with gr.Column(scale=1):
413
-
414
- # Caption 1 - Casual
415
- gr.HTML("<h4 style='margin: 15px 0 10px 0; color: #374151;'>📝 Casual Caption</h4>")
416
-
417
- system1 = gr.Textbox(
418
- label="System Prompt",
419
- lines=2,
420
- value=DEFAULT_PROMPTS["casual"]["system"],
421
- placeholder="How should the AI behave?"
422
- )
423
-
424
- user1 = gr.Textbox(
425
- label="User Prompt",
426
- lines=2,
427
- value=DEFAULT_PROMPTS["casual"]["user"],
428
- placeholder="What should the AI do with this image?"
429
- )
430
-
431
- with gr.Row():
432
- key1_btn = gr.Button("key", size="sm")
433
- que1_btn = gr.Button("que", size="sm")
434
- use1_btn = gr.Button("use", size="sm")
435
- not1_btn = gr.Button("not", size="sm")
436
- gen1_btn = gr.Button("📝 Generate Casual Caption", variant="primary")
437
-
438
- out1 = gr.Textbox(lines=5, show_copy_button=True)
439
-
440
- # Caption 2 - Friendly
441
- gr.HTML("<h4 style='margin: 15px 0 10px 0; color: #374151;'>🤝 Friendly Caption</h4>")
442
-
443
- system2 = gr.Textbox(
444
- label="System Prompt",
445
- lines=2,
446
- value=DEFAULT_PROMPTS["friendly"]["system"],
447
- placeholder="How should the AI behave?"
448
- )
449
-
450
- user2 = gr.Textbox(
451
- label="User Prompt",
452
- lines=2,
453
- value=DEFAULT_PROMPTS["friendly"]["user"],
454
- placeholder="What kind of description do you want?"
455
- )
456
-
457
- with gr.Row():
458
- key2_btn = gr.Button("key", size="sm")
459
- que2_btn = gr.Button("que", size="sm")
460
- use2_btn = gr.Button("use", size="sm")
461
- not2_btn = gr.Button("not", size="sm")
462
- gen2_btn = gr.Button("🤝 Generate Friendly Caption", variant="primary")
463
-
464
- out2 = gr.Textbox(lines=5, show_copy_button=True)
465
-
466
- # Caption 3 - Erotic
467
- gr.HTML("<h4 style='margin: 15px 0 10px 0; color: #374151;'>🔥 Erotic Caption</h4>")
468
-
469
- system3 = gr.Textbox(
470
- label="System Prompt",
471
- lines=2,
472
- value=DEFAULT_PROMPTS["erotic"]["system"],
473
- placeholder="How should the AI behave?"
474
- )
475
-
476
- user3 = gr.Textbox(
477
- label="User Prompt",
478
- lines=2,
479
- value=DEFAULT_PROMPTS["erotic"]["user"],
480
- placeholder="What kind of explicit description do you want?"
481
- )
482
-
483
  with gr.Row():
484
- key3_btn = gr.Button("key", size="sm")
485
- que3_btn = gr.Button("que", size="sm")
486
- use3_btn = gr.Button("use", size="sm")
487
- not3_btn = gr.Button("not", size="sm")
488
- gen3_btn = gr.Button("🔥 Generate Erotic Caption", variant="primary")
489
-
490
- out3 = gr.Textbox(lines=5, show_copy_button=True)
491
-
492
- # Export section
493
- gr.HTML("<h4 style='margin: 20px 0 10px 0; color: #374151;'>📅 Export</h4>")
494
- export_btn = gr.Button("📅 Export All Data", variant="secondary")
 
 
 
 
 
 
 
 
 
 
 
 
 
495
  export_out = gr.Textbox(visible=False)
496
  export_file = gr.File(visible=False)
497
-
498
- # Connect generation buttons
499
  gen1_btn.click(generate_caption_1, [image_input, system1, user1], out1)
500
  gen2_btn.click(generate_caption_2, [image_input, system2, user2], out2)
501
  gen3_btn.click(generate_caption_3, [image_input, system3, user3], out3)
502
  ask_btn.click(answer_question, [image_input, question_input], qa_output)
503
-
504
- # Template insertion buttons for Caption 1
505
- key1_btn.click(lambda s, u, k: insert_key(s, u, k), [system1, user1, keywords_input], [system1, user1])
506
- que1_btn.click(lambda s, u, q: insert_que(s, u, q), [system1, user1, question_input], [system1, user1])
507
- use1_btn.click(lambda s, u, c: insert_use(s, u, c), [system1, user1, custom_instruction_input], [system1, user1])
508
- not1_btn.click(lambda s, u, a: insert_not(s, u, a), [system1, user1, avoid_input], [system1, user1])
509
-
510
- # Template insertion buttons for Caption 2
511
- key2_btn.click(lambda s, u, k: insert_key(s, u, k), [system2, user2, keywords_input], [system2, user2])
512
- que2_btn.click(lambda s, u, q: insert_que(s, u, q), [system2, user2, question_input], [system2, user2])
513
- use2_btn.click(lambda s, u, c: insert_use(s, u, c), [system2, user2, custom_instruction_input], [system2, user2])
514
- not2_btn.click(lambda s, u, a: insert_not(s, u, a), [system2, user2, avoid_input], [system2, user2])
515
-
516
- # Template insertion buttons for Caption 3
517
- key3_btn.click(lambda s, u, k: insert_key(s, u, k), [system3, user3, keywords_input], [system3, user3])
518
- que3_btn.click(lambda s, u, q: insert_que(s, u, q), [system3, user3, question_input], [system3, user3])
519
- use3_btn.click(lambda s, u, c: insert_use(s, u, c), [system3, user3, custom_instruction_input], [system3, user3])
520
- not3_btn.click(lambda s, u, a: insert_not(s, u, a), [system3, user3, avoid_input], [system3, user3])
521
-
522
- # Export functionality
523
  def handle_export(k, c, a, q, c1, c2, c3, qa, img):
524
  msg, fd = export_joycaption_data(k, c, a, q, c1, c2, c3, qa, img)
525
  if fd:
526
  js, fn = fd
527
  p = os.path.join(tempfile.gettempdir(), fn)
528
- with open(p, "w", encoding="utf-8") as f:
529
- f.write(js)
530
  return gr.update(value=msg, visible=True), gr.update(value=p, visible=True)
531
  return gr.update(value=msg, visible=True), gr.update(visible=False)
532
-
533
- export_btn.click(handle_export, [keywords_input, custom_instruction_input, avoid_input, question_input, out1, out2, out3, qa_output, image_input], [export_out, export_file])
534
-
535
- # Simple instructions
536
- gr.HTML("<hr><h3>Instructions</h3><p>Upload an image, customize the prompts, use template buttons (key/use/not/que) to add formatted text, then generate captions.</p>")
 
537
 
538
  if __name__ == "__main__":
539
- demo.launch()
 
18
  from transformers import LlavaForConditionalGeneration, AutoProcessor
19
  from PIL import Image
20
  import tempfile, gc, os, shutil, json, time, re
 
 
 
21
  from urllib.parse import urlparse
22
+ from typing import Optional
23
 
24
  # ===== BUILT-IN UTILITY FUNCTIONS =====
25
  def fix_image_url(raw_url_or_path: str, host: Optional[str] = None) -> str:
 
26
  if not raw_url_or_path:
27
  return raw_url_or_path
 
28
  try:
29
  parsed = urlparse(raw_url_or_path)
30
  except Exception:
31
  parsed = None
 
32
  if parsed and parsed.scheme and parsed.netloc:
33
  full = raw_url_or_path
34
  if "/file=" in full and "/gradio_api/file=" not in full:
 
36
  if "file=" in full and "/gradio_api/file=" not in full and "/gradio_api" not in full:
37
  full = full.replace("file=", "gradio_api/file=")
38
  return full
 
39
  if raw_url_or_path.startswith("/tmp/") or raw_url_or_path.startswith("tmp/"):
40
  if not host:
41
  return raw_url_or_path
 
46
  if p.startswith("/"):
47
  p = p[1:]
48
  return f"{host}/gradio_api/file=/{p}"
 
49
  return raw_url_or_path
50
 
51
  def postprocess_caption(caption: str, max_chars: int = 1200) -> str:
 
52
  if not caption or not isinstance(caption, str):
53
  return caption or ""
 
 
54
  result = re.sub(r'^(a photo of|an image of|a picture of|this is a photo of|this shows)\s*', '', caption.strip(), flags=re.IGNORECASE)
 
 
55
  if max_chars and len(result) > max_chars:
56
  truncate_point = max_chars
57
  for i in range(len(result) - 1, max(0, max_chars - 100), -1):
 
59
  truncate_point = i + 1
60
  break
61
  result = result[:truncate_point].strip()
 
62
  if result and not result.endswith(('.', '!', '?')):
63
+ result += "."
 
64
  return result
65
 
 
66
  def force_clear_all_caches():
 
67
  try:
68
  if torch.cuda.is_available():
69
  torch.cuda.empty_cache()
70
  torch.cuda.synchronize()
71
  gc.collect()
 
72
  except Exception as e:
73
  print(f"⚠️ Cache clear warning: {e}")
74
 
75
  force_clear_all_caches()
76
 
77
+ # ===== SETUP =====
78
  _tmpdir = tempfile.gettempdir()
79
  os.environ["HF_HOME"] = os.path.join(_tmpdir, "hf_cache")
80
  os.environ["TRANSFORMERS_CACHE"] = os.path.join(_tmpdir, "transformers_cache")
 
131
  }
132
 
133
  def safe_generate_caption_direct(image, system_prompt, user_prompt, max_chars=1200):
 
134
  try:
135
  if image is None:
136
  return "❌ No image provided"
 
137
  if not system_prompt.strip() or not user_prompt.strip():
138
  return "❌ Both system and user prompts are required"
 
139
  torch.cuda.empty_cache()
140
  gc.collect()
 
141
  convo = [
142
  {"role": "system", "content": system_prompt.strip()},
143
  {"role": "user", "content": user_prompt.strip()}
144
  ]
 
145
  convo_string = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
146
  inputs = processor(text=[convo_string], images=[image], return_tensors="pt").to("cuda")
147
  inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
 
148
  with torch.no_grad():
149
  output = model.generate(
150
  **inputs,
 
157
  pad_token_id=processor.tokenizer.eos_token_id,
158
  eos_token_id=processor.tokenizer.eos_token_id
159
  )
160
+ if not output or len(output) == 0:
 
161
  return "❌ No output generated"
162
+ input_length = inputs['input_ids'].shape[1]
163
+ result = processor.tokenizer.decode(output[0][input_length:], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
164
  del inputs, output
165
  torch.cuda.empty_cache()
166
  gc.collect()
167
+ return postprocess_caption(result, max_chars=max_chars) or "❌ Empty result"
 
 
 
168
  except Exception as e:
169
  torch.cuda.empty_cache()
170
  gc.collect()
171
  return f"❌ Error: {str(e)[:200]}"
172
 
173
+ # ===== GENERATION FUNCTIONS =====
174
  @spaces.GPU(duration=60)
175
  @torch.no_grad()
176
  def generate_caption_1(image, system1, user1):
177
+ if not image: return "❌ Upload image first"
 
178
  return safe_generate_caption_direct(image, system1, user1)
179
 
180
  @spaces.GPU(duration=60)
181
  @torch.no_grad()
182
  def generate_caption_2(image, system2, user2):
183
+ if not image: return "❌ Upload image first"
 
184
  return safe_generate_caption_direct(image, system2, user2)
185
 
186
  @spaces.GPU(duration=60)
187
  @torch.no_grad()
188
  def generate_caption_3(image, system3, user3):
189
+ if not image: return "❌ Upload image first"
 
190
  return safe_generate_caption_direct(image, system3, user3)
191
 
192
  @spaces.GPU(duration=40)
193
  @torch.no_grad()
194
  def answer_question(image, question):
195
+ if not image: return " Upload image first"
196
+ if not question.strip(): return "❌ Please ask a question"
 
 
 
 
197
  try:
198
  torch.cuda.empty_cache()
199
  gc.collect()
200
+ convo = [{"role": "system", "content": "You are a helpful image captioner."},
201
+ {"role": "user", "content": question.strip()}]
 
 
 
 
202
  convo_string = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
203
  inputs = processor(text=[convo_string], images=[image], return_tensors="pt").to("cuda")
204
  inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
205
+ output = model.generate(**inputs, max_new_tokens=300, do_sample=True, temperature=0.6, top_p=0.9)
206
+ input_length = inputs['input_ids'].shape[1]
207
+ result = processor.tokenizer.decode(output[0][input_length:], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  del inputs, output
209
  torch.cuda.empty_cache()
210
  gc.collect()
211
+ return postprocess_caption(result, max_chars=500) or "❌ No answer generated"
 
 
 
212
  except Exception as e:
213
  torch.cuda.empty_cache()
214
  gc.collect()
215
  return f"❌ Q&A Error: {str(e)[:200]}"
216
 
217
+ # ===== TEMPLATE FUNCTIONS =====
218
  def insert_template(current_text, template_text, field_content):
219
+ if not field_content.strip(): return current_text
 
 
 
220
  formatted_template = template_text.format(content=field_content.strip())
221
+ if formatted_template in current_text: return current_text
222
+ return (current_text.rstrip() + " " + formatted_template).strip()
 
 
 
 
 
 
 
 
223
 
224
  def create_template_functions():
 
 
225
  def insert_key(system_text, user_text, keywords_content):
226
  template = "Pay attention to these keywords: {content}."
227
+ return (insert_template(system_text, template, keywords_content),
228
+ insert_template(user_text, template, keywords_content))
 
 
 
229
  def insert_que(system_text, user_text, question_content):
230
  template = "Answer this question: {content}."
231
+ return (insert_template(system_text, template, question_content),
232
+ insert_template(user_text, template, question_content))
 
 
 
233
  def insert_use(system_text, user_text, custom_content):
234
  template = "Make sure that you mention: {content}."
235
+ return (insert_template(system_text, template, custom_content),
236
+ insert_template(user_text, template, custom_content))
 
 
 
237
  def insert_not(system_text, user_text, avoid_content):
238
  template = "Do NOT mention: {content}."
239
+ return (insert_template(system_text, template, avoid_content),
240
+ insert_template(user_text, template, avoid_content))
 
 
 
241
  return insert_key, insert_que, insert_use, insert_not
242
 
243
+ # ===== EXPORT =====
244
  def export_joycaption_data(keywords, custom_instructions, avoid, question, cap1, cap2, cap3, qa_answer, image_path=""):
245
  try:
246
  data = {"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), "source":"JoyCaption","data":{}}
 
256
  if cap2.strip(): data["data"]["caption_friendly"]=cap2.strip()
257
  if cap3.strip(): data["data"]["caption_erotic"]=cap3.strip()
258
  if qa_answer.strip(): data["data"]["qa_answer"]=qa_answer.strip()
259
+ if not data["data"]: return "❌ No data to export", None
 
260
  js = json.dumps(data, indent=2, ensure_ascii=False)
261
  fn = f"joycaption_{time.strftime('%Y%m%d_%H%M%S')}.json"
262
  return f"✅ Exported {len(data['data'])} fields", (js, fn)
263
  except Exception as e:
264
  return f"❌ Export failed: {str(e)}", None
265
 
266
+ # ===== GRADIO UI =====
267
  with gr.Blocks(title="JoyCaption Advanced Prompting System", theme=gr.themes.Soft()) as demo:
 
268
  gr.HTML(TITLE)
 
 
269
  insert_key, insert_que, insert_use, insert_not = create_template_functions()
270
+
271
  with gr.Row():
 
272
  with gr.Column(scale=1):
273
+ image_input = gr.Image(type="pil", label="📸 Image", height=400)
274
+ keywords_input = gr.Textbox(label="🏷️ Keywords", lines=2, placeholder="e.g. beach, sunset")
275
+ custom_instruction_input = gr.Textbox(label="🎯 Custom", lines=2, placeholder="Add extra instructions")
276
+ avoid_input = gr.Textbox(label="🚫 Avoid", lines=2, placeholder="What to avoid")
277
+ question_input = gr.Textbox(label="❓ Question", lines=2, placeholder="Ask about image")
278
+ ask_btn = gr.Button("Ask", variant="secondary")
279
+ qa_output = gr.Textbox(label="Answer", lines=3, show_copy_button=True)
280
+
281
+ gr.Markdown("---")
282
+ gr.Markdown("**Insert Template**")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
  with gr.Row():
284
+ key_btn = gr.Button("key", size="sm")
285
+ que_btn = gr.Button("que", size="sm")
286
+ use_btn = gr.Button("use", size="sm")
287
+ not_btn = gr.Button("not", size="sm")
288
+
289
+ with gr.Column(scale=1):
290
+ with gr.Tab("📝 Casual"):
291
+ system1 = gr.Textbox(label="System", lines=2, value=DEFAULT_PROMPTS["casual"]["system"])
292
+ user1 = gr.Textbox(label="User", lines=2, value=DEFAULT_PROMPTS["casual"]["user"])
293
+ gen1_btn = gr.Button("Generate Casual", variant="primary")
294
+ out1 = gr.Textbox(lines=5, show_copy_button=True)
295
+ with gr.Tab("🤝 Friendly"):
296
+ system2 = gr.Textbox(label="System", lines=2, value=DEFAULT_PROMPTS["friendly"]["system"])
297
+ user2 = gr.Textbox(label="User", lines=2, value=DEFAULT_PROMPTS["friendly"]["user"])
298
+ gen2_btn = gr.Button("Generate Friendly", variant="primary")
299
+ out2 = gr.Textbox(lines=5, show_copy_button=True)
300
+ with gr.Tab("🔥 Erotic"):
301
+ system3 = gr.Textbox(label="System", lines=2, value=DEFAULT_PROMPTS["erotic"]["system"])
302
+ user3 = gr.Textbox(label="User", lines=2, value=DEFAULT_PROMPTS["erotic"]["user"])
303
+ gen3_btn = gr.Button("Generate Erotic", variant="primary")
304
+ out3 = gr.Textbox(lines=5, show_copy_button=True)
305
+
306
+ gr.Markdown("---")
307
+ export_btn = gr.Button("📦 Export All", variant="secondary")
308
  export_out = gr.Textbox(visible=False)
309
  export_file = gr.File(visible=False)
310
+
311
+ # === Event Bindings ===
312
  gen1_btn.click(generate_caption_1, [image_input, system1, user1], out1)
313
  gen2_btn.click(generate_caption_2, [image_input, system2, user2], out2)
314
  gen3_btn.click(generate_caption_3, [image_input, system3, user3], out3)
315
  ask_btn.click(answer_question, [image_input, question_input], qa_output)
316
+
317
+ # Shared template bar
318
+ key_btn.click(lambda s1,u1,k: insert_key(s1,u1,k), [system1,user1,keywords_input], [system1,user1])
319
+ que_btn.click(lambda s1,u1,q: insert_que(s1,u1,q), [system1,user1,question_input], [system1,user1])
320
+ use_btn.click(lambda s1,u1,c: insert_use(s1,u1,c), [system1,user1,custom_instruction_input], [system1,user1])
321
+ not_btn.click(lambda s1,u1,a: insert_not(s1,u1,a), [system1,user1,avoid_input], [system1,user1])
322
+
 
 
 
 
 
 
 
 
 
 
 
 
 
323
  def handle_export(k, c, a, q, c1, c2, c3, qa, img):
324
  msg, fd = export_joycaption_data(k, c, a, q, c1, c2, c3, qa, img)
325
  if fd:
326
  js, fn = fd
327
  p = os.path.join(tempfile.gettempdir(), fn)
328
+ with open(p, "w", encoding="utf-8") as f: f.write(js)
 
329
  return gr.update(value=msg, visible=True), gr.update(value=p, visible=True)
330
  return gr.update(value=msg, visible=True), gr.update(visible=False)
331
+
332
+ export_btn.click(handle_export,
333
+ [keywords_input, custom_instruction_input, avoid_input, question_input,
334
+ out1, out2, out3, qa_output, image_input],
335
+ [export_out, export_file]
336
+ )
337
 
338
  if __name__ == "__main__":
339
+ demo.launch()