nickdigger committed on
Commit
2f41a1f
·
verified ·
1 Parent(s): 656a9c0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +457 -55
app.py CHANGED
@@ -1,88 +1,490 @@
1
- """
2
- Copy of the full `app.py` into the deploy folder for direct upload.
3
- This file is a snapshot of the application's main entrypoint and should be
4
- identical to the root `app.py` when uploading to Hugging Face Spaces.
5
- """
6
-
7
  try:
8
  import spaces
9
- # Ensure spaces.GPU exists and is a decorator
10
- return f
 
 
 
 
 
 
 
 
11
  return _wrap
12
  spaces.GPU = _spaces_gpu
13
 
 
 
 
 
14
 
 
 
 
 
 
 
15
 
 
 
 
 
 
 
16
 
 
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
 
 
 
 
 
 
 
19
 
20
- import gradio as gr
21
- import torch
22
- from transformers import LlavaForConditionalGeneration, AutoProcessor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  r'^(a photo of|an image of|a picture of|this is a photo of|this shows)\s*': '',
24
-
25
- # Nudity precision corrections
26
- r'\\btopless women\\b': lambda m: 'nude women' if 'naked' in text.lower() or 'nude' in text.lower() else 'topless women',
27
- r'\\btopless woman\\b': lambda m: 'nude woman' if 'naked' in text.lower() or 'nude' in text.lower() else 'topless woman',
28
-
29
- # Person count corrections
30
- r'\\bthree women\\b': lambda m: 'two women' if text.count('woman') + text.count('female') <= 2 else 'three women',
31
- r'\\bfour women\\b': lambda m: 'three women' if text.count('woman') + text.count('female') <= 3 else 'four women',
32
-
33
- # Clothing precision
34
- r'\\bwearing nothing\\b': 'nude',
35
- r'\\bnot wearing.*clothes\\b': 'nude',
36
- r'\\bcompletely naked\\b': 'nude',
37
- r'\\bfully nude\\b': 'nude',
38
  }
39
-
40
- corrected_text = text
41
-
42
- // Get all textareas and inputs from the page
43
- const allInputs = document.querySelectorAll('textarea, input[type="text"]');
44
-
45
- allInputs.forEach((field, index) => {
46
- const placeholder = (field.placeholder || '').toLowerCase();
47
- const value = field.value ? field.value.trim() : '';
48
- interactive=True,
49
- placeholder="Click the button above to generate engaging caption..."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  )
51
 
52
- # Casual Friend caption
53
  with gr.Row():
54
- with gr.Column(scale=4):
 
 
 
 
 
 
55
  interactive=True,
56
- placeholder="Click the button above to generate casual friend caption..."
57
  )
58
 
59
- # NSFW section removed - caused hallucination
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
- # Keywords caption
62
  with gr.Row():
63
- with gr.Column(scale=4):
 
 
 
 
 
 
 
64
  interactive=True,
65
- placeholder="Click the button above to generate keywords caption..."
66
  )
67
 
68
- # Body Parts Focus section removed - caused hallucination
 
 
 
69
 
70
- # Descriptive text removed for cleaner interface
 
 
 
 
 
 
71
 
72
- # Export functionality
73
  with gr.Row():
74
- export_btn = gr.Button(
75
- )
76
-
77
- # NSFW button handler removed
 
 
 
 
 
 
 
 
 
 
 
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  generate_uncensored_btn.click(
80
  generate_uncensored_keywords_only,
81
  inputs=[image_input, keywords_input, custom_instruction_input],
 
 
 
 
 
 
 
 
 
 
82
  )
83
-
84
- # Body Parts Focus button handler removed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
- # Individual reload buttons - using direct generation for consistency
87
- def reload_engaging_fn(image, custom_instruction):
88
- return safe_generate_caption_direct(image, "engaging", custom_instruction=custom_instruction) if image else "❌ Upload image first"
 
 
 
 
 
 
 
1
def _spaces_gpu(*args, **kwargs):
    """No-op replacement for ``spaces.GPU`` used when the real decorator is unavailable.

    Works both as ``@spaces.GPU`` (bare) and as ``@spaces.GPU(duration=...)``;
    in either form the decorated function is returned unchanged.
    """
    # Bare usage: the function being decorated is the only positional argument.
    if len(args) == 1 and callable(args[0]) and not kwargs:
        return args[0]

    def _wrap(f):
        return f

    return _wrap


try:
    import spaces
except Exception:
    # Any import failure falls back to an empty namespace so the
    # ``@spaces.GPU(...)`` decorators below still resolve.
    import types

    spaces = types.SimpleNamespace()

# Install the shim only when no GPU decorator is present.
if not hasattr(spaces, "GPU"):
    spaces.GPU = _spaces_gpu
15
 
16
@spaces.GPU()
def _joycaption_register_gpu():
    """Do nothing and return ``None``.

    Exists only so that a ``spaces.GPU``-decorated function is defined at
    import time. NOTE(review): the original comment states this helps Spaces
    detect the GPU runtime -- confirm against the Spaces documentation.
    """
    return
20
 
21
import gc
import json
import os
import re
import shutil
import tempfile
import time
from pathlib import Path

# ---------- Caches → temp ----------
# These variables are set BEFORE gradio / transformers are imported:
# huggingface_hub and transformers resolve their default cache locations from
# the environment when they are first imported, so assigning the variables
# afterwards does not redirect the caches.
_tmpdir = tempfile.gettempdir()
os.environ["HF_HOME"] = os.path.join(_tmpdir, "hf_cache")
os.environ["TRANSFORMERS_CACHE"] = os.path.join(_tmpdir, "transformers_cache")
os.environ["HF_DATASETS_CACHE"] = os.path.join(_tmpdir, "datasets_cache")
os.environ["TORCH_HOME"] = os.path.join(_tmpdir, "torch_cache")

import gradio as gr  # noqa: E402  (must follow the cache environment setup)
import torch  # noqa: E402
from transformers import LlavaForConditionalGeneration, AutoProcessor  # noqa: E402
from PIL import Image  # noqa: E402

# Hub id of the captioning model loaded below.
MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"
# Host name taken from the environment when present, otherwise None.
SPACE_HOST = os.environ.get("SPACE_HOST") or os.environ.get("HF_SPACE_HOST") or None
37
 
38
+ # ---------- Cleanup ----------
39
def cleanup_storage():
    """Best-effort removal of cache directories and release of unused memory.

    Deletes the directories named by the ``HF_HOME``, ``TRANSFORMERS_CACHE``,
    ``HF_DATASETS_CACHE`` and ``TORCH_HOME`` environment variables when they
    exist, runs the garbage collector and, when CUDA is available, empties the
    CUDA cache. Any exception is reported with a printed warning and is not
    re-raised, so a failed cleanup does not abort the caller.
    """
    cache_keys = ("HF_HOME", "TRANSFORMERS_CACHE", "HF_DATASETS_CACHE", "TORCH_HOME")
    try:
        for key in cache_keys:
            path = os.environ.get(key)
            if path and os.path.exists(path):
                shutil.rmtree(path, ignore_errors=True)
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        print("✅ Storage cleanup completed")
    except Exception as e:
        # Deliberately broad: cleanup is optional and must never stop the app.
        print(f"⚠️ Cleanup warning: {e}")
52
 
53
# HTML banner rendered at the top of the page via gr.HTML(TITLE).
TITLE = """
<div style="text-align:center;margin:20px 0;">
  <h1>🎨 JoyCaption Three-Tone + Q&A (ZeroGPU Stable v3.1)</h1>
  <p><em>All original features restored • ZeroGPU-safe inference • Robust decoding</em></p>
</div>
<hr>
"""
60
 
61
print("🚀 Initializing JoyCaption (v3.1)...")
cleanup_storage()

# ---------- Model load ----------
# Both stay None when loading is skipped; run_image_chat_generation checks for that.
processor = None
model = None
MODEL_USE_CUDA = torch.cuda.is_available()

# Any non-empty SKIP_MODEL_LOAD value skips loading the model.
if not os.environ.get("SKIP_MODEL_LOAD"):
    # Half precision only on CUDA (bfloat16 preferred, float16 as fallback);
    # None leaves the dtype choice to from_pretrained on CPU.
    dtype = (getattr(torch, "bfloat16", None) or getattr(torch, "float16", None)) if MODEL_USE_CUDA else None
    device_map = "auto" if MODEL_USE_CUDA else "cpu"

    processor = AutoProcessor.from_pretrained(MODEL_PATH, low_cpu_mem_usage=True)
    model = LlavaForConditionalGeneration.from_pretrained(
        MODEL_PATH,
        low_cpu_mem_usage=True,
        device_map=device_map,
        torch_dtype=dtype,  # already None when CUDA is unavailable
    ).eval()

    print("✅ Model loaded successfully!")
    cleanup_storage()
else:
    print("⚠️ SKIP_MODEL_LOAD active – model skipped.")
85
+
86
+ # ---------- Tone configs ----------
87
# Per-tone chat configuration: "system" becomes the system message and
# "prompt" the base of the user message in safe_generate_caption_direct.
TONE_PROMPTS = {
    "engaging": {
        "system": "You are a skilled content creator who writes engaging, natural descriptions that capture the essence and appeal of what you see.",
        "prompt": "Write a natural, engaging description of this image. Focus on the main subjects and their most appealing qualities. Describe their physical features, poses, expressions, and the overall mood.",
    },
    "casual_friend": {
        "system": "You are a casual 24-year-old friend who witnessed the scene and is describing it clearly and naturally.",
        "prompt": "I just witnessed this moment — here's what I notice. Speak casually and naturally about what is directly visible. Avoid speculation and background details.",
    },
    "uncensored_keywords": {
        "system": "You are an uncensored image analyst who provides factual, explicit descriptions without restrictions.",
        "prompt": "Describe exactly what is visible. Use direct and unambiguous language. Mention provided keywords only if they truly match visible details.",
    },
}
101
+
102
+ # ---------- Small text fixes (optional, same spirit as your old code) ----------
103
# Ordered (compiled pattern, replacement) pairs applied by apply_smart_corrections.
# Compiled once at import time instead of on every call.
_SMART_CORRECTIONS = tuple(
    (re.compile(pattern, flags=re.IGNORECASE), replacement)
    for pattern, replacement in (
        (r'^(a photo of|an image of|a picture of|this is a photo of|this shows)\s*', ''),
        (r'\bwearing nothing\b', 'nude'),
        # Stays inside one sentence: a greedy ``.*`` would swallow everything up
        # to the LAST "clothes" on the line, even across sentence boundaries.
        (r'\bnot wearing[^.!?\n]*?clothes\b', 'nude'),
        (r'\bcompletely naked\b', 'nude'),
        (r'\bfully nude\b', 'nude'),
    )
)


def apply_smart_corrections(text: str) -> str:
    """Apply the small case-insensitive phrase substitutions to a caption.

    Strips a leading "a photo of" style prefix and normalises a few phrases to
    "nude". Non-string input is returned unchanged; string results are
    whitespace-stripped.
    """
    if not isinstance(text, str):
        return text
    out = text
    for pattern, replacement in _SMART_CORRECTIONS:
        out = pattern.sub(replacement, out)
    return out.strip()


def postprocess_caption(text: str, max_chars: int = 600) -> str:
    """Clean a caption, cap its length and make it end with punctuation.

    Args:
        text: Raw caption. Non-string or empty input yields "".
        max_chars: Maximum length before truncation.

    Returns:
        The corrected caption. When truncation is needed the text is cut at
        ``max_chars`` and, if a ``.``, ``!`` or ``?`` occurs in the last 100
        characters of that cut, shortened to end right after the last one.
        A trailing "." is appended when the result does not already end in
        ``.``, ``!`` or ``?``.
    """
    if not isinstance(text, str) or not text:
        return ""
    text = apply_smart_corrections(text).strip()
    if len(text) > max_chars:
        cut = text[:max_chars]
        # Look for a sentence boundary within the last 100 characters.
        tail = cut[-100:]
        # Offset of ``tail`` inside ``cut``. Using len(tail) rather than a fixed
        # 100 keeps the index valid when the cut is shorter than 100 characters.
        tail_start = len(cut) - len(tail)
        boundary = max(tail.rfind(mark) for mark in ".!?")
        if boundary != -1:
            cut = cut[:tail_start + boundary + 1]
        text = cut.strip()
    if text and text[-1] not in ".!?":
        text += "."
    return text
134
+
135
+ # ---------- Core: prepare inputs (ZeroGPU-safe) ----------
136
def _prepare_inputs_and_device(convo, image):
    """Turn a conversation and an image into model-ready tensors.

    Args:
        convo: List of chat message dicts passed to the processor's chat template.
        image: A PIL image, or a ``str``/``Path`` that is opened with PIL.

    Returns:
        The processor output with every tensor batched, bool tensors cast to
        int, floating-point tensors cast to the model's parameter dtype, and
        everything moved to the model's device.

    Raises:
        ValueError: If ``image`` is neither a path nor a PIL image.
    """
    if isinstance(image, (str, Path)):
        image = Image.open(image)
    elif not isinstance(image, Image.Image):
        raise ValueError("Invalid image input type")
    # Normalise every input (not only file paths) to 3-channel RGB.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Build conversation string via chat template.
    try:
        convo_string = processor.apply_chat_template(
            convo, tokenize=False, add_generation_prompt=True
        )
    except Exception as e:
        # Fallback: join the raw message contents, and say so instead of
        # silently switching prompt format.
        print(f"⚠️ Chat template fallback: {e}")
        convo_string = "\n".join(str(x.get("content", "")) for x in convo)

    # Always pass lists so the processor returns batched tensors.
    inputs = processor(text=[convo_string], images=[image], return_tensors="pt")

    first_param = next(model.parameters())
    device, model_dtype = first_param.device, first_param.dtype

    for k, v in list(inputs.items()):
        if not torch.is_tensor(v):
            continue
        # Ensure a leading batch dimension.
        # NOTE(review): the original comment says ZeroGPU requires 2D/4D shapes.
        if v.ndim == 1:
            v = v.unsqueeze(0)  # -> [1, seq_len]
        elif k == "pixel_values" and v.ndim == 3:
            v = v.unsqueeze(0)  # -> [1, C, H, W]
        if v.dtype == torch.bool:
            # Original note: bool masks can confuse generate(); cast to int.
            v = v.to(torch.int)
        elif v.is_floating_point():
            # The model is loaded in half precision on CUDA; float32 image
            # tensors would not match its weights, so cast to the parameter dtype.
            v = v.to(model_dtype)
        inputs[k] = v.to(device, non_blocking=True)

    return inputs
174
+
175
+ # ---------- Core: decode (robust to 1D/2D) ----------
176
def _decode_output(inputs, output):
    """Decode the newly generated part of ``output`` into stripped text.

    The first sequence of ``output`` is sliced past the prompt length (last
    dimension of ``inputs["input_ids"]``, or 0 when that is not a tensor with
    at least one dimension) and decoded without special tokens. If that fails,
    a warning is printed and the whole first sequence is decoded instead; if
    that fails too, "" is returned. ``None`` or empty ``output`` yields "".
    """
    if output is None or len(output) == 0:
        return ""

    try:
        prompt_ids = inputs.get("input_ids")
        if isinstance(prompt_ids, torch.Tensor) and prompt_ids.ndim > 0:
            prompt_len = prompt_ids.shape[-1]
        else:
            prompt_len = 0
        generated = output[0][prompt_len:]
        decoded = processor.tokenizer.decode(
            generated,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return decoded.strip()
    except Exception as exc:
        print(f"⚠️ Decode fallback: {exc}")

    # Second attempt: decode the full first sequence, prompt included.
    try:
        return processor.tokenizer.decode(output[0], skip_special_tokens=True).strip()
    except Exception:
        return ""
194
+
195
def cleanup_after_inference():
    """Run the garbage collector and, when CUDA is available, empty its cache."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
200
+
201
+ # ---------- Core: generate (no invalid flags on ZeroGPU) ----------
202
def run_image_chat_generation(convo, image, max_new_tokens=180):
    """Generate text for ``convo`` + ``image`` and return ``(text, error)``.

    Exactly one element of the returned pair is meaningful: on success the
    decoded text and ``None``; on failure ``None`` and an error message whose
    exception text is truncated to 300 characters. Memory cleanup runs on both
    paths.
    """
    if processor is None or model is None:
        return None, "❌ Model not initialized."

    try:
        model_inputs = _prepare_inputs_and_device(convo, image)
        eos_id = processor.tokenizer.eos_token_id

        # Only length and pad/eos ids are passed; no sampling flags.
        # NOTE(review): the original comment says temperature/top_p may be
        # ignored or trigger warnings on ZeroGPU backends.
        with torch.no_grad():
            generated = model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                pad_token_id=eos_id,
                eos_token_id=eos_id,
            )

        caption = _decode_output(model_inputs, generated)
        cleanup_after_inference()
    except Exception as exc:
        cleanup_after_inference()
        return None, f"❌ Generation error: {str(exc)[:300]}"

    return caption, None
226
+
227
+ # ---------- Caption helpers (features restored) ----------
228
def safe_generate_caption_direct(image, tone, keywords_text="", custom_instruction="", max_chars=600):
    """Generate one caption in the requested tone and return it as a string.

    Unknown tones fall back to "engaging". Keywords are appended to the prompt
    only for the "uncensored_keywords" tone, and the custom instruction
    whenever it is non-blank. Returns the generation error message on failure,
    or "❌ Empty result" when post-processing leaves nothing.
    """
    config = TONE_PROMPTS.get(tone, TONE_PROMPTS["engaging"])

    prompt_parts = [config["prompt"]]
    if tone == "uncensored_keywords" and keywords_text and keywords_text.strip():
        prompt_parts.append(f"Keywords (ONLY if truly visible): {keywords_text.strip()}")
    if custom_instruction and custom_instruction.strip():
        prompt_parts.append(f"Include this detail: {custom_instruction.strip()}")
    user_prompt = "\n\n".join(prompt_parts)

    messages = [
        {"role": "system", "content": config["system"]},
        {"role": "user", "content": user_prompt},
    ]
    caption, error = run_image_chat_generation(messages, image, max_new_tokens=220)
    if error:
        return error

    cleaned = postprocess_caption(caption or "", max_chars=max_chars)
    return cleaned if cleaned else "❌ Empty result"
244
+
245
@spaces.GPU(duration=45)
@torch.no_grad()
def generate_engaging_only(image, custom_instruction=""):
    """Return an "engaging" caption for ``image``, or a prompt to upload one."""
    if not image:
        return "❌ Upload image first"
    return safe_generate_caption_direct(image, "engaging", custom_instruction=custom_instruction)
249
+
250
@spaces.GPU(duration=45)
@torch.no_grad()
def generate_casual_friend_only(image, custom_instruction=""):
    """Return a "casual_friend" caption for ``image``, or a prompt to upload one."""
    if not image:
        return "❌ Upload image first"
    return safe_generate_caption_direct(image, "casual_friend", custom_instruction=custom_instruction)
254
+
255
@spaces.GPU(duration=45)
@torch.no_grad()
def generate_uncensored_keywords_only(image, keywords_text, custom_instruction=""):
    """Return an "uncensored_keywords" caption for ``image``, or a prompt to upload one."""
    if not image:
        return "❌ Upload image first"
    return safe_generate_caption_direct(
        image,
        "uncensored_keywords",
        keywords_text=keywords_text,
        custom_instruction=custom_instruction,
    )
259
+
260
@spaces.GPU(duration=45)
@torch.no_grad()
def answer_question(image, question):
    """Answer a free-form question about ``image``.

    Returns a user-facing message when the image or question is missing, the
    generation error message on failure, and otherwise the stripped answer
    (or "❌ No answer" when it is empty).
    """
    if not image:
        return "❌ Upload image first"

    asked = question.strip() if question else ""
    if not asked:
        return "❌ Please ask a question"

    messages = [
        {"role": "system", "content": "You are an image analyst who answers honestly and directly."},
        {"role": "user", "content": f"Answer this question about the image clearly and directly: {asked}"},
    ]
    answer, error = run_image_chat_generation(messages, image, max_new_tokens=220)
    if error:
        return error
    return answer.strip() or "❌ No answer"
271
+
272
+ # ---------- Export ----------
273
def export_joycaption_data(keywords, custom_instructions, question, engaging_caption, casual_caption, keywords_caption, qa_answer, image_reference=""):
    """Bundle the non-blank UI fields into a JSON document.

    Args:
        keywords, custom_instructions, question, engaging_caption,
        casual_caption, keywords_caption, qa_answer, image_reference:
            Text values; blank or falsy ones are left out of the export.

    Returns:
        ``(message, payload)``. On success ``payload`` is
        ``(json_string, filename)`` and ``message`` lists the exported field
        names. When every field is blank, or when anything raises, ``payload``
        is ``None`` and ``message`` explains why.
    """
    try:
        # Export key -> raw value, in the order the keys appear in the output.
        candidates = (
            ("keywords", keywords),
            ("custom_instructions", custom_instructions),
            ("question", question),
            ("image_reference", image_reference),
            ("caption_engaging", engaging_caption),
            ("caption_casual_friend", casual_caption),
            ("caption_keywords", keywords_caption),
            ("qa_answer", qa_answer),
        )
        fields = {name: value.strip() for name, value in candidates if value and value.strip()}

        if not fields:
            return "❌ No data to export. Generate some captions first!", None

        data = {"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), "source": "JoyCaption", "data": fields}
        json_string = json.dumps(data, indent=2, ensure_ascii=False)
        filename = f"joycaption_data_{time.strftime('%Y%m%d_%H%M%S')}.json"
        return f"✅ Exported {len(fields)} fields: {', '.join(fields)}", (json_string, filename)
    except Exception as e:
        # Reported to the UI rather than raised, e.g. for non-string inputs.
        return f"❌ Export failed: {str(e)}", None
293
+
294
+ # ---------- Gradio UI (full features restored) ----------
295
# ---------- Gradio UI ----------
# NOTE(review): SOURCE carries no indentation, so the nesting of the layout
# containers below (in particular where the export row sits) is reconstructed
# and should be checked against the deployed app.
with gr.Blocks(title="JoyCaption ZeroGPU Stable", theme=gr.themes.Soft()) as demo:
    gr.HTML(TITLE)

    with gr.Row():
        # Left: image upload, free-text inputs and Q&A
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="📸 Upload Image", height=400)

            filename_display = gr.Textbox(
                label="📂 Uploaded Filename",
                interactive=False,
                visible=True,
                info="Auto-filled when you upload an image"
            )

            keywords_input = gr.Textbox(
                placeholder="e.g., sensual, curves, intimate, alluring...",
                label="🏷️ Keywords (used only by Uncensored tone)",
                lines=2
            )

            image_reference_input = gr.Textbox(
                placeholder="e.g., blonde_girl_001.jpg (optional override)",
                label="🖼️ Image Reference (Manual Override)",
                lines=1
            )

            custom_instruction_input = gr.Textbox(
                placeholder="e.g., 'from instagram', 'left girl has red hair', 'beach setting'...",
                label="🎯 Make sure to mention:",
                lines=2
            )

            question_input = gr.Textbox(
                placeholder="e.g., 'What are they doing?', 'Describe her pose'...",
                label="❓ Ask a Question",
                lines=2
            )

            with gr.Row():
                ask_question_btn = gr.Button("❓ Ask Question", variant="secondary", size="sm")
                clear_qa_btn = gr.Button("🗑️", size="sm", variant="secondary")

            qa_output = gr.Textbox(
                label="Q&A Answer",
                lines=5,
                show_copy_button=True,
                interactive=True,
                placeholder="Q&A answers will appear here..."
            )

        # Right: one generate / reload / clear row plus output box per tone
        with gr.Column(scale=1):
            with gr.Row():
                generate_engaging_btn = gr.Button("✨ Engaging", variant="primary", size="sm")
                reload_engaging = gr.Button("🔄", size="sm", variant="secondary")
                clear_engaging_btn = gr.Button("🗑️", size="sm", variant="secondary")

            engaging_output = gr.Textbox(
                label="Engaging Caption",
                lines=5,
                show_copy_button=True,
                interactive=True,
                placeholder="Generate engaging caption..."
            )

            with gr.Row():
                generate_friend_btn = gr.Button("😎 Casual Friend", variant="primary", size="sm")
                reload_friend = gr.Button("🔄", size="sm", variant="secondary")
                clear_friend_btn = gr.Button("🗑️", size="sm", variant="secondary")

            friend_output = gr.Textbox(
                label="Casual Friend Caption",
                lines=5,
                show_copy_button=True,
                interactive=True,
                placeholder="Generate casual caption..."
            )

            with gr.Row():
                generate_uncensored_btn = gr.Button("🔴 Uncensored + Keywords", variant="secondary", size="sm")
                reload_uncensored = gr.Button("🔄", size="sm", variant="secondary")
                clear_uncensored_btn = gr.Button("🗑️", size="sm", variant="secondary")

            uncensored_output = gr.Textbox(
                label="Uncensored + Keywords Caption",
                lines=5,
                show_copy_button=True,
                interactive=True,
                placeholder="Generate uncensored caption..."
            )

    with gr.Row():
        export_btn = gr.Button("📥 Export All Data (JSON)", variant="primary", size="lg")

    # Both start hidden; handle_export decides their visibility.
    export_output = gr.Textbox(label="Export Status", lines=2, interactive=False, visible=False)
    export_file = gr.File(label="Download JSON", visible=False)

    # Filename extraction on upload
    def extract_filename(image):
        """Return the base name of ``image.filename``, or a generic placeholder."""
        if image is None:
            return ""
        try:
            if hasattr(image, "filename") and image.filename:
                return os.path.basename(image.filename)
        except Exception:
            # Any problem reading the attribute falls through to the placeholder.
            pass
        return "uploaded_image.jpg"

    image_input.change(extract_filename, inputs=[image_input], outputs=filename_display)

    # Generation, reload and Q&A handlers, registered in the original order.
    # Reload buttons call the same functions as their generate buttons.
    caption_inputs = [image_input, custom_instruction_input]
    keyword_inputs = [image_input, keywords_input, custom_instruction_input]
    click_bindings = (
        (generate_engaging_btn, generate_engaging_only, caption_inputs, engaging_output),
        (generate_friend_btn, generate_casual_friend_only, caption_inputs, friend_output),
        (generate_uncensored_btn, generate_uncensored_keywords_only, keyword_inputs, uncensored_output),
        (reload_engaging, generate_engaging_only, caption_inputs, engaging_output),
        (reload_friend, generate_casual_friend_only, caption_inputs, friend_output),
        (reload_uncensored, generate_uncensored_keywords_only, keyword_inputs, uncensored_output),
        (ask_question_btn, answer_question, [image_input, question_input], qa_output),
    )
    for trigger, handler, handler_inputs, target in click_bindings:
        trigger.click(handler, inputs=handler_inputs, outputs=target, show_progress=True)

    # Clear buttons
    def clear_text():
        """Return an empty string to blank the bound textbox."""
        return ""

    clear_bindings = (
        (clear_qa_btn, qa_output),
        (clear_engaging_btn, engaging_output),
        (clear_friend_btn, friend_output),
        (clear_uncensored_btn, uncensored_output),
    )
    for trigger, target in clear_bindings:
        trigger.click(clear_text, outputs=target)

    # Export (writes into the temp dir so it works on Spaces)
    def handle_export(keywords, custom_instructions, question, engaging_caption, casual_caption, keywords_caption, qa_answer, image_reference, upload_filename):
        """Build the export and return updates for the status box and the file widget.

        The manually entered image reference wins over the auto-filled upload
        filename, matching the field's "Manual Override" label. On failure the
        status message is shown and the download widget is hidden.
        """
        image_ref = (image_reference or "").strip() or (upload_filename or "").strip()
        message, file_data = export_joycaption_data(
            keywords, custom_instructions, question,
            engaging_caption, casual_caption, keywords_caption, qa_answer,
            image_ref
        )
        if not file_data:
            return gr.update(value=message, visible=True), gr.update(visible=False)

        json_string, fname = file_data
        # A fresh directory per export: the file name only has one-second
        # resolution, so a shared path could be overwritten by another export.
        export_dir = tempfile.mkdtemp(prefix="joycaption_export_")
        temp_path = os.path.join(export_dir, fname)
        with open(temp_path, "w", encoding="utf-8") as f:
            f.write(json_string)
        return gr.update(value=message, visible=True), gr.update(value=temp_path, visible=True)

    # No follow-up .then() steps: handle_export already sets the visibility of
    # both outputs, and forcing them visible afterwards would re-show the
    # download widget after a failed export.
    export_btn.click(
        handle_export,
        inputs=[
            keywords_input, custom_instruction_input, question_input,
            engaging_output, friend_output, uncensored_output, qa_output,
            image_reference_input, filename_display
        ],
        outputs=[export_output, export_file]
    )

if __name__ == "__main__":
    demo.launch()