nickdigger committed on
Commit
aa5f60b
·
verified ·
1 Parent(s): da35d44

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -123
app.py CHANGED
@@ -16,7 +16,6 @@ except Exception:
16
  import gradio as gr
17
  import torch
18
  from transformers import LlavaForConditionalGeneration, AutoProcessor
19
- from PIL import Image
20
  import tempfile, gc, os, shutil, json, time, re
21
  from urllib.parse import urlparse
22
  from typing import Optional
@@ -42,9 +41,7 @@ def fix_image_url(raw_url_or_path: str, host: Optional[str] = None) -> str:
42
  host = host.rstrip("/")
43
  if not (host.startswith("http://") or host.startswith("https://")):
44
  host = "https://" + host
45
- p = raw_url_or_path
46
- if p.startswith("/"):
47
- p = p[1:]
48
  return f"{host}/gradio_api/file=/{p}"
49
  return raw_url_or_path
50
 
@@ -69,49 +66,34 @@ def force_clear_all_caches():
69
  torch.cuda.empty_cache()
70
  torch.cuda.synchronize()
71
  gc.collect()
72
- except Exception as e:
73
- print(f"⚠️ Cache clear warning: {e}")
74
 
75
  force_clear_all_caches()
76
 
77
  # ===== SETUP =====
78
  _tmpdir = tempfile.gettempdir()
79
- os.environ["HF_HOME"] = os.path.join(_tmpdir, "hf_cache")
80
- os.environ["TRANSFORMERS_CACHE"] = os.path.join(_tmpdir, "transformers_cache")
81
- os.environ["HF_DATASETS_CACHE"] = os.path.join(_tmpdir, "datasets_cache")
82
- os.environ["TORCH_HOME"] = os.path.join(_tmpdir, "torch_cache")
 
 
 
83
 
84
  MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"
85
  SPACE_HOST = os.environ.get("SPACE_HOST") or os.environ.get("HF_SPACE_HOST") or None
86
 
87
- def cleanup_storage():
88
- try:
89
- for key in ["HF_HOME","TRANSFORMERS_CACHE","HF_DATASETS_CACHE","TORCH_HOME"]:
90
- d = os.environ.get(key)
91
- if d and os.path.exists(d):
92
- shutil.rmtree(d, ignore_errors=True)
93
- gc.collect()
94
- except Exception as e:
95
- print(f"⚠️ Storage cleanup warning: {e}")
96
-
97
- TITLE = """
98
- <div style="text-align:center;margin:20px 0;">
99
- <h1>🎨 JoyCaption Advanced Prompting System (v6.0)</h1>
100
- <p><strong>πŸŽ›οΈ Fully customizable prompts β€’ Template helpers β€’ Professional control</strong></p>
101
- </div><hr>
102
- """
103
-
104
  print("🚀 Loading JoyCaption model...")
105
  processor = AutoProcessor.from_pretrained(MODEL_PATH, cache_dir=None)
106
  model = LlavaForConditionalGeneration.from_pretrained(
107
- MODEL_PATH,
108
- torch_dtype=torch.bfloat16,
109
  device_map="auto",
110
  cache_dir=None,
111
  low_cpu_mem_usage=True
112
  )
113
  model.eval()
114
- cleanup_storage()
115
  print("✅ Model loaded successfully!")
116
 
117
  # ===== DEFAULT PROMPTS =====
@@ -130,6 +112,7 @@ DEFAULT_PROMPTS = {
130
  }
131
  }
132
 
 
133
  def safe_generate_caption_direct(image, system_prompt, user_prompt, max_chars=1200):
134
  try:
135
  if image is None:
@@ -144,7 +127,7 @@ def safe_generate_caption_direct(image, system_prompt, user_prompt, max_chars=12
144
  ]
145
  convo_string = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
146
  inputs = processor(text=[convo_string], images=[image], return_tensors="pt").to("cuda")
147
- inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
148
  with torch.no_grad():
149
  output = model.generate(
150
  **inputs,
@@ -152,14 +135,11 @@ def safe_generate_caption_direct(image, system_prompt, user_prompt, max_chars=12
152
  do_sample=True,
153
  temperature=0.6,
154
  top_p=0.9,
155
- top_k=None,
156
  use_cache=True,
157
  pad_token_id=processor.tokenizer.eos_token_id,
158
- eos_token_id=processor.tokenizer.eos_token_id
159
  )
160
- if not output or len(output) == 0:
161
- return "❌ No output generated"
162
- input_length = inputs['input_ids'].shape[1]
163
  result = processor.tokenizer.decode(output[0][input_length:], skip_special_tokens=True)
164
  del inputs, output
165
  torch.cuda.empty_cache()
@@ -170,40 +150,34 @@ def safe_generate_caption_direct(image, system_prompt, user_prompt, max_chars=12
170
  gc.collect()
171
  return f"❌ Error: {str(e)[:200]}"
172
 
173
- # ===== CAPTION FUNCTIONS =====
174
- @spaces.GPU(duration=60)
175
- @torch.no_grad()
176
- def generate_caption_1(image, system1, user1):
177
- if not image: return "❌ Upload image first"
178
- return safe_generate_caption_direct(image, system1, user1)
179
-
180
  @spaces.GPU(duration=60)
181
  @torch.no_grad()
182
- def generate_caption_2(image, system2, user2):
183
- if not image: return "❌ Upload image first"
184
- return safe_generate_caption_direct(image, system2, user2)
185
-
186
- @spaces.GPU(duration=60)
187
- @torch.no_grad()
188
- def generate_caption_3(image, system3, user3):
189
- if not image: return "❌ Upload image first"
190
- return safe_generate_caption_direct(image, system3, user3)
191
 
 
192
  @spaces.GPU(duration=40)
193
  @torch.no_grad()
194
  def answer_question(image, question):
195
- if not image: return "❌ Upload image first"
196
- if not question.strip(): return "❌ Please ask a question"
 
 
197
  try:
198
  torch.cuda.empty_cache()
199
  gc.collect()
200
- convo = [{"role": "system", "content": "You are a helpful image captioner."},
201
- {"role": "user", "content": question.strip()}]
 
 
202
  convo_string = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
203
  inputs = processor(text=[convo_string], images=[image], return_tensors="pt").to("cuda")
204
- inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
205
  output = model.generate(**inputs, max_new_tokens=300, do_sample=True, temperature=0.6, top_p=0.9)
206
- input_length = inputs['input_ids'].shape[1]
207
  result = processor.tokenizer.decode(output[0][input_length:], skip_special_tokens=True)
208
  del inputs, output
209
  torch.cuda.empty_cache()
@@ -214,72 +188,81 @@ def answer_question(image, question):
214
  gc.collect()
215
  return f"❌ Q&A Error: {str(e)[:200]}"
216
 
217
- # ===== TEMPLATE FUNCTIONS =====
218
  def insert_template(current_text, template_text, field_content):
219
- if not field_content.strip(): return current_text
220
- formatted_template = template_text.format(content=field_content.strip())
221
- if formatted_template in current_text: return current_text
222
- return (current_text.rstrip() + " " + formatted_template).strip()
 
 
223
 
224
  def create_template_functions():
225
- def insert_key(system_text, user_text, keywords_content):
226
- template = "Pay attention to these keywords: {content}."
227
- return (insert_template(system_text, template, keywords_content),
228
- insert_template(user_text, template, keywords_content))
229
- def insert_que(system_text, user_text, question_content):
230
- template = "Answer this question: {content}."
231
- return (insert_template(system_text, template, question_content),
232
- insert_template(user_text, template, question_content))
233
- def insert_use(system_text, user_text, custom_content):
234
- template = "Make sure that you mention: {content}."
235
- return (insert_template(system_text, template, custom_content),
236
- insert_template(user_text, template, custom_content))
237
- def insert_not(system_text, user_text, avoid_content):
238
- template = "Do NOT mention: {content}."
239
- return (insert_template(system_text, template, avoid_content),
240
- insert_template(user_text, template, avoid_content))
241
  return insert_key, insert_que, insert_use, insert_not
242
 
243
  # ===== EXPORT =====
244
- def export_joycaption_data(keywords, custom_instructions, avoid, question, cap1, cap2, cap3, qa_answer, image_path=""):
245
  try:
246
- data = {"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), "source":"JoyCaption","data":{}}
247
- if keywords.strip(): data["data"]["keywords"]=keywords.strip()
248
- if custom_instructions.strip(): data["data"]["custom_instructions"]=custom_instructions.strip()
249
- if avoid.strip(): data["data"]["avoid"]=avoid.strip()
250
- if question.strip(): data["data"]["question"]=question.strip()
251
- if image_path.strip():
252
- data["data"]["image_local_path"]=image_path
253
- image_url=fix_image_url(image_path, host=(SPACE_HOST or ""))
254
- if image_url: data["data"]["image_url"]=image_url
255
- if cap1.strip(): data["data"]["caption_casual"]=cap1.strip()
256
- if cap2.strip(): data["data"]["caption_friendly"]=cap2.strip()
257
- if cap3.strip(): data["data"]["caption_erotic"]=cap3.strip()
258
- if qa_answer.strip(): data["data"]["qa_answer"]=qa_answer.strip()
259
- if not data["data"]: return "❌ No data to export", None
 
 
260
  js = json.dumps(data, indent=2, ensure_ascii=False)
261
  fn = f"joycaption_{time.strftime('%Y%m%d_%H%M%S')}.json"
262
- return f"βœ… Exported {len(data['data'])} fields", (js, fn)
263
  except Exception as e:
264
- return f"❌ Export failed: {str(e)}", None
265
 
266
  # ===== UI =====
267
  with gr.Blocks(title="JoyCaption Advanced Prompting System", theme=gr.themes.Soft()) as demo:
268
- gr.HTML(TITLE)
 
 
 
 
 
 
269
  insert_key, insert_que, insert_use, insert_not = create_template_functions()
 
270
 
271
  with gr.Row():
 
272
  with gr.Column(scale=1):
273
  image_input = gr.Image(type="pil", label="πŸ“Έ Image", height=400)
274
  keywords_input = gr.Textbox(label="🏷️ Keywords", lines=2, placeholder="e.g. beach, sunset")
275
- custom_instruction_input = gr.Textbox(label="🎯 Custom", lines=2, placeholder="Add extra instructions")
276
- avoid_input = gr.Textbox(label="🚫 Avoid", lines=2, placeholder="What to avoid")
277
  question_input = gr.Textbox(label="❓ Question", lines=2, placeholder="Ask about image")
278
  ask_btn = gr.Button("Ask", variant="secondary")
279
  qa_output = gr.Textbox(label="Answer", lines=3, show_copy_button=True)
280
 
 
281
  with gr.Column(scale=1):
282
- # Template buttons moved above tabs
283
  gr.Markdown("**Insert Template**")
284
  with gr.Row():
285
  key_btn = gr.Button("key", size="sm")
@@ -287,22 +270,26 @@ with gr.Blocks(title="JoyCaption Advanced Prompting System", theme=gr.themes.Sof
287
  use_btn = gr.Button("use", size="sm")
288
  not_btn = gr.Button("not", size="sm")
289
 
290
- # Tabs with flexible Textboxes
291
- with gr.Tab("πŸ“ Casual"):
292
- system1 = gr.Textbox(label="System", show_label=True, placeholder="", value=DEFAULT_PROMPTS["casual"]["system"], lines=1, max_lines=5)
293
- user1 = gr.Textbox(label="User", show_label=True, placeholder="", value=DEFAULT_PROMPTS["casual"]["user"], lines=1, max_lines=8)
 
 
294
  gen1_btn = gr.Button("Generate Casual", variant="primary")
295
  out1 = gr.Textbox(lines=5, show_copy_button=True)
296
-
297
- with gr.Tab("🀝 Friendly"):
298
- system2 = gr.Textbox(label="System", show_label=True, placeholder="", value=DEFAULT_PROMPTS["friendly"]["system"], lines=1, max_lines=5)
299
- user2 = gr.Textbox(label="User", show_label=True, placeholder="", value=DEFAULT_PROMPTS["friendly"]["user"], lines=1, max_lines=8)
 
300
  gen2_btn = gr.Button("Generate Friendly", variant="primary")
301
  out2 = gr.Textbox(lines=5, show_copy_button=True)
302
-
303
- with gr.Tab("πŸ”₯ Erotic"):
304
- system3 = gr.Textbox(label="System", show_label=True, placeholder="", value=DEFAULT_PROMPTS["erotic"]["system"], lines=1, max_lines=5)
305
- user3 = gr.Textbox(label="User", show_label=True, placeholder="", value=DEFAULT_PROMPTS["erotic"]["user"], lines=1, max_lines=8)
 
306
  gen3_btn = gr.Button("Generate Erotic", variant="primary")
307
  out3 = gr.Textbox(lines=5, show_copy_button=True)
308
 
@@ -311,27 +298,54 @@ with gr.Blocks(title="JoyCaption Advanced Prompting System", theme=gr.themes.Sof
311
  export_out = gr.Textbox(visible=False)
312
  export_file = gr.File(visible=False)
313
 
 
 
 
 
 
314
  # === EVENTS ===
315
- gen1_btn.click(generate_caption_1, [image_input, system1, user1], out1)
316
- gen2_btn.click(generate_caption_2, [image_input, system2, user2], out2)
317
- gen3_btn.click(generate_caption_3, [image_input, system3, user3], out3)
318
  ask_btn.click(answer_question, [image_input, question_input], qa_output)
319
 
320
- key_btn.click(lambda s1,u1,k: insert_key(s1,u1,k), [system1,user1,keywords_input], [system1,user1])
321
- que_btn.click(lambda s1,u1,q: insert_que(s1,u1,q), [system1,user1,question_input], [system1,user1])
322
- use_btn.click(lambda s1,u1,c: insert_use(s1,u1,c), [system1,user1,custom_instruction_input], [system1,user1])
323
- not_btn.click(lambda s1,u1,a: insert_not(s1,u1,a), [system1,user1,avoid_input], [system1,user1])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
 
325
  def handle_export(k, c, a, q, c1, c2, c3, qa, img):
326
  msg, fd = export_joycaption_data(k, c, a, q, c1, c2, c3, qa, img)
327
  if fd:
328
  js, fn = fd
329
- p = os.path.join(tempfile.gettempdir(), fn)
330
- with open(p, "w", encoding="utf-8") as f: f.write(js)
331
- return gr.update(value=msg, visible=True), gr.update(value=p, visible=True)
 
332
  return gr.update(value=msg, visible=True), gr.update(visible=False)
333
 
334
- export_btn.click(handle_export,
 
335
  [keywords_input, custom_instruction_input, avoid_input, question_input,
336
  out1, out2, out3, qa_output, image_input],
337
  [export_out, export_file]
 
16
  import gradio as gr
17
  import torch
18
  from transformers import LlavaForConditionalGeneration, AutoProcessor
 
19
  import tempfile, gc, os, shutil, json, time, re
20
  from urllib.parse import urlparse
21
  from typing import Optional
 
41
  host = host.rstrip("/")
42
  if not (host.startswith("http://") or host.startswith("https://")):
43
  host = "https://" + host
44
+ p = raw_url_or_path.lstrip("/")
 
 
45
  return f"{host}/gradio_api/file=/{p}"
46
  return raw_url_or_path
47
 
 
66
  torch.cuda.empty_cache()
67
  torch.cuda.synchronize()
68
  gc.collect()
69
+ except Exception:
70
+ pass
71
 
72
  force_clear_all_caches()
73
 
74
  # ===== SETUP =====
75
  _tmpdir = tempfile.gettempdir()
76
+ for key, folder in {
77
+ "HF_HOME": "hf_cache",
78
+ "TRANSFORMERS_CACHE": "transformers_cache",
79
+ "HF_DATASETS_CACHE": "datasets_cache",
80
+ "TORCH_HOME": "torch_cache"
81
+ }.items():
82
+ os.environ[key] = os.path.join(_tmpdir, folder)
83
 
84
  MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"
85
  SPACE_HOST = os.environ.get("SPACE_HOST") or os.environ.get("HF_SPACE_HOST") or None
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  print("🚀 Loading JoyCaption model...")
88
  processor = AutoProcessor.from_pretrained(MODEL_PATH, cache_dir=None)
89
  model = LlavaForConditionalGeneration.from_pretrained(
90
+ MODEL_PATH,
91
+ torch_dtype=torch.bfloat16,
92
  device_map="auto",
93
  cache_dir=None,
94
  low_cpu_mem_usage=True
95
  )
96
  model.eval()
 
97
  print("✅ Model loaded successfully!")
98
 
99
  # ===== DEFAULT PROMPTS =====
 
112
  }
113
  }
114
 
115
+ # ===== CORE CAPTIONING =====
116
  def safe_generate_caption_direct(image, system_prompt, user_prompt, max_chars=1200):
117
  try:
118
  if image is None:
 
127
  ]
128
  convo_string = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
129
  inputs = processor(text=[convo_string], images=[image], return_tensors="pt").to("cuda")
130
+ inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
131
  with torch.no_grad():
132
  output = model.generate(
133
  **inputs,
 
135
  do_sample=True,
136
  temperature=0.6,
137
  top_p=0.9,
 
138
  use_cache=True,
139
  pad_token_id=processor.tokenizer.eos_token_id,
140
+ eos_token_id=processor.tokenizer.eos_token_id,
141
  )
142
+ input_length = inputs["input_ids"].shape[1]
 
 
143
  result = processor.tokenizer.decode(output[0][input_length:], skip_special_tokens=True)
144
  del inputs, output
145
  torch.cuda.empty_cache()
 
150
  gc.collect()
151
  return f"❌ Error: {str(e)[:200]}"
152
 
153
+ # ===== INDIVIDUAL CAPTION WRAPPERS =====
 
 
 
 
 
 
154
  @spaces.GPU(duration=60)
155
  @torch.no_grad()
156
+ def generate_caption(image, system, user):
157
+ if not image:
158
+ return "❌ Upload image first"
159
+ return safe_generate_caption_direct(image, system, user)
 
 
 
 
 
160
 
161
+ # ===== Q&A =====
162
  @spaces.GPU(duration=40)
163
  @torch.no_grad()
164
  def answer_question(image, question):
165
+ if not image:
166
+ return "❌ Upload image first"
167
+ if not question.strip():
168
+ return "❌ Please ask a question"
169
  try:
170
  torch.cuda.empty_cache()
171
  gc.collect()
172
+ convo = [
173
+ {"role": "system", "content": "You are a helpful image captioner."},
174
+ {"role": "user", "content": question.strip()},
175
+ ]
176
  convo_string = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
177
  inputs = processor(text=[convo_string], images=[image], return_tensors="pt").to("cuda")
178
+ inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
179
  output = model.generate(**inputs, max_new_tokens=300, do_sample=True, temperature=0.6, top_p=0.9)
180
+ input_length = inputs["input_ids"].shape[1]
181
  result = processor.tokenizer.decode(output[0][input_length:], skip_special_tokens=True)
182
  del inputs, output
183
  torch.cuda.empty_cache()
 
188
  gc.collect()
189
  return f"❌ Q&A Error: {str(e)[:200]}"
190
 
191
+ # ===== TEMPLATE HELPERS =====
192
  def insert_template(current_text, template_text, field_content):
193
+ if not field_content.strip():
194
+ return current_text
195
+ formatted = template_text.format(content=field_content.strip())
196
+ if formatted in current_text:
197
+ return current_text
198
+ return (current_text.rstrip() + " " + formatted).strip()
199
 
200
  def create_template_functions():
201
+ def insert_key(s, u, c):
202
+ t = "Pay attention to these keywords: {content}."
203
+ return insert_template(s, t, c), insert_template(u, t, c)
204
+ def insert_que(s, u, c):
205
+ t = "Answer this question: {content}."
206
+ return insert_template(s, t, c), insert_template(u, t, c)
207
+ def insert_use(s, u, c):
208
+ t = "Make sure that you mention: {content}."
209
+ return insert_template(s, t, c), insert_template(u, t, c)
210
+ def insert_not(s, u, c):
211
+ t = "Do NOT mention: {content}."
212
+ return insert_template(s, t, c), insert_template(u, t, c)
 
 
 
 
213
  return insert_key, insert_que, insert_use, insert_not
214
 
215
  # ===== EXPORT =====
216
+ def export_joycaption_data(keywords, custom_instructions, avoid, question, c1, c2, c3, qa, img):
217
  try:
218
+ data = {"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), "source": "JoyCaption", "data": {}}
219
+ add = data["data"]
220
+ if keywords.strip(): add["keywords"] = keywords.strip()
221
+ if custom_instructions.strip(): add["custom_instructions"] = custom_instructions.strip()
222
+ if avoid.strip(): add["avoid"] = avoid.strip()
223
+ if question.strip(): add["question"] = question.strip()
224
+ if img.strip():
225
+ add["image_local_path"] = img
226
+ url = fix_image_url(img, host=(SPACE_HOST or ""))
227
+ if url: add["image_url"] = url
228
+ if c1.strip(): add["caption_casual"] = c1.strip()
229
+ if c2.strip(): add["caption_friendly"] = c2.strip()
230
+ if c3.strip(): add["caption_erotic"] = c3.strip()
231
+ if qa.strip(): add["qa_answer"] = qa.strip()
232
+ if not add:
233
+ return "❌ No data to export", None
234
  js = json.dumps(data, indent=2, ensure_ascii=False)
235
  fn = f"joycaption_{time.strftime('%Y%m%d_%H%M%S')}.json"
236
+ return f"βœ… Exported {len(add)} fields", (js, fn)
237
  except Exception as e:
238
+ return f"❌ Export failed: {e}", None
239
 
240
  # ===== UI =====
241
  with gr.Blocks(title="JoyCaption Advanced Prompting System", theme=gr.themes.Soft()) as demo:
242
+ gr.HTML("""
243
+ <style>
244
+ textarea.autoresize {overflow-y:hidden!important;min-height:50px!important;height:auto!important;}
245
+ </style>
246
+ """)
247
+ gr.HTML("<h1 style='text-align:center;margin-top:10px;'>🎨 JoyCaption Advanced Prompting System (v6.0)</h1><hr>")
248
+
249
  insert_key, insert_que, insert_use, insert_not = create_template_functions()
250
+ active_tab = gr.State("casual")
251
 
252
  with gr.Row():
253
+ # LEFT
254
  with gr.Column(scale=1):
255
  image_input = gr.Image(type="pil", label="πŸ“Έ Image", height=400)
256
  keywords_input = gr.Textbox(label="🏷️ Keywords", lines=2, placeholder="e.g. beach, sunset")
257
+ custom_instruction_input = gr.Textbox(label="🎯 Custom", lines=2, placeholder="Extra instructions")
258
+ avoid_input = gr.Textbox(label="🚫 Avoid", lines=2, placeholder="Things to avoid")
259
  question_input = gr.Textbox(label="❓ Question", lines=2, placeholder="Ask about image")
260
  ask_btn = gr.Button("Ask", variant="secondary")
261
  qa_output = gr.Textbox(label="Answer", lines=3, show_copy_button=True)
262
 
263
+ # RIGHT
264
  with gr.Column(scale=1):
265
+ # template buttons top
266
  gr.Markdown("**Insert Template**")
267
  with gr.Row():
268
  key_btn = gr.Button("key", size="sm")
 
270
  use_btn = gr.Button("use", size="sm")
271
  not_btn = gr.Button("not", size="sm")
272
 
273
+ # Tabs
274
+ with gr.Tab("πŸ“ Casual") as tab1:
275
+ gr.Markdown("**System Prompt**")
276
+ system1 = gr.Textbox(show_label=False, value=DEFAULT_PROMPTS["casual"]["system"], elem_classes="autoresize")
277
+ gr.Markdown("**User Prompt**")
278
+ user1 = gr.Textbox(show_label=False, value=DEFAULT_PROMPTS["casual"]["user"], elem_classes="autoresize")
279
  gen1_btn = gr.Button("Generate Casual", variant="primary")
280
  out1 = gr.Textbox(lines=5, show_copy_button=True)
281
+ with gr.Tab("🀝 Friendly") as tab2:
282
+ gr.Markdown("**System Prompt**")
283
+ system2 = gr.Textbox(show_label=False, value=DEFAULT_PROMPTS["friendly"]["system"], elem_classes="autoresize")
284
+ gr.Markdown("**User Prompt**")
285
+ user2 = gr.Textbox(show_label=False, value=DEFAULT_PROMPTS["friendly"]["user"], elem_classes="autoresize")
286
  gen2_btn = gr.Button("Generate Friendly", variant="primary")
287
  out2 = gr.Textbox(lines=5, show_copy_button=True)
288
+ with gr.Tab("πŸ”₯ Erotic") as tab3:
289
+ gr.Markdown("**System Prompt**")
290
+ system3 = gr.Textbox(show_label=False, value=DEFAULT_PROMPTS["erotic"]["system"], elem_classes="autoresize")
291
+ gr.Markdown("**User Prompt**")
292
+ user3 = gr.Textbox(show_label=False, value=DEFAULT_PROMPTS["erotic"]["user"], elem_classes="autoresize")
293
  gen3_btn = gr.Button("Generate Erotic", variant="primary")
294
  out3 = gr.Textbox(lines=5, show_copy_button=True)
295
 
 
298
  export_out = gr.Textbox(visible=False)
299
  export_file = gr.File(visible=False)
300
 
301
+ # === TAB SWITCH HANDLERS ===
302
+ tab1.select(lambda: "casual", None, active_tab)
303
+ tab2.select(lambda: "friendly", None, active_tab)
304
+ tab3.select(lambda: "erotic", None, active_tab)
305
+
306
  # === EVENTS ===
307
+ gen1_btn.click(generate_caption, [image_input, system1, user1], out1)
308
+ gen2_btn.click(generate_caption, [image_input, system2, user2], out2)
309
+ gen3_btn.click(generate_caption, [image_input, system3, user3], out3)
310
  ask_btn.click(answer_question, [image_input, question_input], qa_output)
311
 
312
+ # Template logic — update only current tab
313
+ def handle_template(btn_type, tab, s1, u1, s2, u2, s3, u3, k, c, q, a):
314
+ key_f, que_f, use_f, not_f = create_template_functions()
315
+ mapping = {
316
+ "key": key_f, "que": que_f, "use": use_f, "not": not_f
317
+ }
318
+ fn = mapping.get(btn_type)
319
+ if not fn:
320
+ return s1, u1, s2, u2, s3, u3
321
+ if tab == "casual":
322
+ s1, u1 = fn(s1, u1, k or c or q or a)
323
+ elif tab == "friendly":
324
+ s2, u2 = fn(s2, u2, k or c or q or a)
325
+ elif tab == "erotic":
326
+ s3, u3 = fn(s3, u3, k or c or q or a)
327
+ return s1, u1, s2, u2, s3, u3
328
+
329
+ for b, t in [(key_btn, "key"), (que_btn, "que"), (use_btn, "use"), (not_btn, "not")]:
330
+ b.click(
331
+ handle_template,
332
+ [gr.State(t), active_tab, system1, user1, system2, user2, system3, user3,
333
+ keywords_input, custom_instruction_input, question_input, avoid_input],
334
+ [system1, user1, system2, user2, system3, user3],
335
+ )
336
 
337
  def handle_export(k, c, a, q, c1, c2, c3, qa, img):
338
  msg, fd = export_joycaption_data(k, c, a, q, c1, c2, c3, qa, img)
339
  if fd:
340
  js, fn = fd
341
+ path = os.path.join(tempfile.gettempdir(), fn)
342
+ with open(path, "w", encoding="utf-8") as f:
343
+ f.write(js)
344
+ return gr.update(value=msg, visible=True), gr.update(value=path, visible=True)
345
  return gr.update(value=msg, visible=True), gr.update(visible=False)
346
 
347
+ export_btn.click(
348
+ handle_export,
349
  [keywords_input, custom_instruction_input, avoid_input, question_input,
350
  out1, out2, out3, qa_output, image_input],
351
  [export_out, export_file]