nickdigger commited on
Commit
82fef69
·
verified ·
1 Parent(s): 3c448f0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -858
app.py CHANGED
@@ -1,6 +1,5 @@
1
  try:
2
  import spaces
3
- # Ensure spaces.GPU exists and is a decorator
4
  if not hasattr(spaces, 'GPU'):
5
  def _spaces_gpu(*args, **kwargs):
6
  def _wrap(f):
@@ -8,7 +7,6 @@ try:
8
  return _wrap
9
  spaces.GPU = _spaces_gpu
10
  except Exception:
11
- # Provide a no-op spaces with a GPU decorator fallback so app can run outside HF Spaces
12
  import types
13
  spaces = types.SimpleNamespace()
14
  def _spaces_gpu(*args, **kwargs):
@@ -16,46 +14,34 @@ except Exception:
16
  return f
17
  return _wrap
18
  spaces.GPU = _spaces_gpu
19
-
20
- # Some Spaces runtimes require at least one function decorated with @spaces.GPU
21
- # Register a no-op GPU-decorated function so the platform detection succeeds.
22
  @spaces.GPU()
23
  def _joycaption_register_gpu():
24
- """No-op function decorated with @spaces.GPU to satisfy Spaces startup detection."""
25
  return None
 
26
  import gradio as gr
27
  import torch
28
  from transformers import LlavaForConditionalGeneration, AutoProcessor
29
  from PIL import Image
30
- import tempfile
31
- import gc
32
- import time
33
- import gc
34
- import os
35
- import shutil
36
- import json
37
  from pathlib import Path
38
- import re
39
-
40
  from hf_space_utils import fix_image_url, postprocess_caption
41
 
42
- # Storage optimization - redirect cache to temporary directories (platform independent)
43
  _tmpdir = tempfile.gettempdir()
44
  os.environ["HF_HOME"] = os.path.join(_tmpdir, "hf_cache")
45
  os.environ["TRANSFORMERS_CACHE"] = os.path.join(_tmpdir, "transformers_cache")
46
  os.environ["HF_DATASETS_CACHE"] = os.path.join(_tmpdir, "datasets_cache")
47
  os.environ["TORCH_HOME"] = os.path.join(_tmpdir, "torch_cache")
48
 
49
- # Model configuration
50
  MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"
51
-
52
- # Optional public host for converting /tmp/gradio paths to public gradio_api URLs
53
  SPACE_HOST = os.environ.get("SPACE_HOST") or os.environ.get("HF_SPACE_HOST") or None
54
 
 
55
  def cleanup_storage():
56
- """Clean up temporary files and caches to prevent storage overflow"""
57
  try:
58
- # Clean up temporary caches using the configured environment paths
59
  temp_dirs = [
60
  os.environ.get("HF_HOME"),
61
  os.environ.get("TRANSFORMERS_CACHE"),
@@ -63,26 +49,15 @@ def cleanup_storage():
63
  os.environ.get("TORCH_HOME")
64
  ]
65
  for temp_dir in temp_dirs:
66
- if not temp_dir:
67
- continue
68
- if os.path.exists(temp_dir):
69
- try:
70
- shutil.rmtree(temp_dir, ignore_errors=True)
71
- except Exception:
72
- # best-effort cleanup
73
- pass
74
-
75
- # Force garbage collection
76
  gc.collect()
77
-
78
- # Clear GPU cache if available
79
  if torch.cuda.is_available():
80
  torch.cuda.empty_cache()
81
  torch.cuda.synchronize()
82
-
83
  print("✅ Storage cleanup completed")
84
  except Exception as e:
85
- print(f"⚠️ Storage cleanup warning: {e}")
86
 
87
  TITLE = """
88
  <div style="text-align: center; margin: 20px 0;">
@@ -93,26 +68,18 @@ TITLE = """
93
  <hr>
94
  """
95
 
96
- print("🚀 Loading Sequential Three-Tone JoyCaption system... v2.1")
 
97
 
98
- # Load model and processor at startup
99
- print("📦 Loading model and processor at startup...")
100
  processor = None
101
  model = None
102
  MODEL_TORCH_DTYPE = None
103
  MODEL_USE_CUDA = False
104
- # Force cleanup before loading model to avoid tokenizer desync
105
- cleanup_storage()
106
- torch.cuda.empty_cache()
107
- gc.collect()
108
 
109
-
110
- # Allow skipping model loading for tests or light-weight runs by setting SKIP_MODEL_LOAD=1
111
  if not os.environ.get("SKIP_MODEL_LOAD"):
112
- # Determine target device for model loading. On zero-GPU spaces, fall back to CPU.
113
  use_cuda = torch.cuda.is_available()
114
  if use_cuda:
115
- # Prefer bf16 on supported GPUs, otherwise try float16
116
  torch_dtype = getattr(torch, 'bfloat16', None) or getattr(torch, 'float16', None)
117
  device_map = "auto"
118
  MODEL_USE_CUDA = True
@@ -121,912 +88,182 @@ if not os.environ.get("SKIP_MODEL_LOAD"):
121
  device_map = "cpu"
122
  MODEL_USE_CUDA = False
123
 
124
- processor = AutoProcessor.from_pretrained(
125
- MODEL_PATH,
126
- low_cpu_mem_usage=True
127
- )
128
-
129
  model_kwargs = dict(low_cpu_mem_usage=True, device_map=device_map)
130
  if torch_dtype is not None and use_cuda:
131
  model_kwargs['torch_dtype'] = torch_dtype
132
 
133
- model = LlavaForConditionalGeneration.from_pretrained(
134
- MODEL_PATH,
135
- **model_kwargs
136
- )
137
  model.eval()
138
- # remember dtype for later tensor conversions
139
  MODEL_TORCH_DTYPE = model_kwargs.get('torch_dtype', None)
140
  print("✅ Model loaded and ready!")
141
-
142
- # Initial cleanup after model loading
143
  cleanup_storage()
144
  else:
145
- print("⚠️ SKIP_MODEL_LOAD is set — skipping heavy model initialization (test mode)")
146
 
147
- # Optimized 5-tone prompts with better temperature control
148
- # Temperature: Lower for prompt adherence, higher for word variety
149
  TONE_PROMPTS = {
150
  "engaging": {
151
  "system": "You are a skilled content creator who writes engaging, natural descriptions that capture the essence and appeal of what you see.",
152
- "prompt": """Write a natural, engaging description of this image. Focus on the main subjects and their most appealing qualities. Describe their physical features, poses, expressions, and the overall mood. Use flowing, attractive language that draws the reader in.""",
153
- "temperature": 0.6, # Lower for prompt adherence
154
  "top_p": 0.85
155
  },
156
-
157
  "casual_friend": {
158
- "system": "You are a casual 24-year-old friend who witnessed the scene and is telling me about what happened in a relaxed, natural way. Use precise, non-speculative language: report only what you can clearly see, keep any context notes very brief, and do not describe background objects or bystanders.",
159
- "prompt": """I just witnessed this moment — here's what I notice. Speak casually and naturally, using clear, concrete descriptions of only what is directly observable. Avoid speculation or ambiguous phrasing. If you mention context, keep it short and focused on the primary subject. Do not describe background objects or other people who are not the main focus.""",
160
- "temperature": 0.8, # Higher for natural conversation variety
161
  "top_p": 0.9
162
  },
163
-
164
- # NSFW tone removed - causes hallucination
165
-
166
  "uncensored_keywords": {
167
- "system": "You are an uncensored adult-content analyst who provides factual, explicit descriptions. Use concrete, unambiguous statements about directly observable features. Mention surrounding context only briefly and do not describe background objects or unrelated people.",
168
- "prompt": """Produce an unfiltered explicit description that sticks strictly to what is visually obvious. Use direct, concrete language and avoid any ambiguous or speculative wording. If keywords are provided, include them only when they clearly match visible details. Keep context commentary short and exclude background elements or bystander descriptions.""",
169
- "temperature": 0.7, # Balanced for accuracy with variety
170
  "top_p": 0.85
171
- },
172
-
173
- # Body Parts Focus tone removed - causes hallucination
174
- }
175
-
176
- def apply_smart_corrections(text):
177
- """Apply smart corrections for common AI vision errors"""
178
- if not text or not isinstance(text, str):
179
- return text
180
-
181
- corrections = {
182
- # Remove "photo of" beginnings
183
- r'^(a photo of|an image of|a picture of|this is a photo of|this shows)\s*': '',
184
-
185
- # Nudity precision corrections
186
- r'\btopless women\b': lambda m: 'nude women' if 'naked' in text.lower() or 'nude' in text.lower() else 'topless women',
187
- r'\btopless woman\b': lambda m: 'nude woman' if 'naked' in text.lower() or 'nude' in text.lower() else 'topless woman',
188
-
189
- # Person count corrections
190
- r'\bthree women\b': lambda m: 'two women' if text.count('woman') + text.count('female') <= 2 else 'three women',
191
- r'\bfour women\b': lambda m: 'three women' if text.count('woman') + text.count('female') <= 3 else 'four women',
192
-
193
- # Clothing precision
194
- r'\bwearing nothing\b': 'nude',
195
- r'\bnot wearing.*clothes\b': 'nude',
196
- r'\bcompletely naked\b': 'nude',
197
- r'\bfully nude\b': 'nude',
198
  }
199
-
200
- corrected_text = text
201
- try:
202
- for pattern, replacement in corrections.items():
203
- if callable(replacement):
204
- # Wrap the replacement to ensure it returns a string and accepts a Match
205
- def _repl(match, rep=replacement):
206
- try:
207
- out = rep(match)
208
- return "" if out is None else str(out)
209
- except Exception:
210
- return match.group(0)
211
- corrected_text = re.sub(pattern, _repl, corrected_text, flags=re.IGNORECASE)
212
- else:
213
- corrected_text = re.sub(pattern, replacement, corrected_text, flags=re.IGNORECASE)
214
- except Exception as e:
215
- print(f"Error in smart corrections: {e}")
216
- return text
217
-
218
- return corrected_text
219
 
 
220
  def _prepare_inputs_and_device(convo_or_convo_string, image):
221
- """Prepare processor inputs and move tensors to the model device."""
222
- # Accept either a convo list or a pre-built convo string
 
 
 
 
223
  convo_string = convo_or_convo_string
224
- try:
225
- if isinstance(convo_or_convo_string, list):
226
  convo_string = processor.apply_chat_template(convo_or_convo_string, tokenize=False, add_generation_prompt=True)
227
- except Exception:
228
- # If processor is not ready or fails, let the caller handle the missing model/processor
229
- pass
230
 
231
  inputs = processor(text=[convo_string], images=[image], return_tensors="pt")
232
  device = next(model.parameters()).device
233
  inputs = {k: v.to(device, non_blocking=True) if hasattr(v, 'to') else v for k, v in inputs.items()}
234
 
235
- # Ensure pixel_values dtype is safe for the runtime
236
  if 'pixel_values' in inputs:
237
- if MODEL_USE_CUDA and MODEL_TORCH_DTYPE is not None:
238
- try:
239
- inputs['pixel_values'] = inputs['pixel_values'].to(MODEL_TORCH_DTYPE)
240
- except Exception:
241
- inputs['pixel_values'] = inputs['pixel_values'].to(torch.float32)
242
- else:
243
- inputs['pixel_values'] = inputs['pixel_values'].to(torch.float32)
244
  return inputs
245
 
 
246
  def _decode_output(inputs, output):
247
- """Decode generate output safely, removing input prompt tokens if present."""
248
- if output is None or len(output) == 0:
249
  return ""
250
  try:
251
- if 'input_ids' in inputs and len(inputs['input_ids'].shape) >= 2:
252
- input_length = inputs['input_ids'].shape[1]
253
- if len(output[0]) > input_length:
254
- generate_ids = output[0][input_length:]
255
- return processor.tokenizer.decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False).strip()
256
- else:
257
- return processor.tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False).strip()
258
- else:
259
- return processor.tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False).strip()
260
  except Exception:
261
- # Fallback: try a direct decode
262
- try:
263
- return processor.tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False).strip()
264
- except Exception:
265
- return ""
266
 
267
- def cleanup_after_inference(inputs=None, output=None):
268
- """Lightweight cleanup after an inference run."""
269
- try:
270
- del inputs
271
- except Exception:
272
- pass
273
- try:
274
- del output
275
- except Exception:
276
- pass
277
- try:
278
- if torch.cuda.is_available():
279
- torch.cuda.empty_cache()
280
- torch.cuda.synchronize()
281
- except Exception:
282
- pass
283
  gc.collect()
284
 
 
285
  def run_image_chat_generation(convo, image, max_new_tokens=150, temperature=0.7, top_p=0.9):
286
- """
287
- Centralized helper to run an image+chat generation using the loaded processor/model.
288
- Returns the decoded string (possibly empty) or None and an error message string on failure.
289
- """
290
  if processor is None or model is None:
291
- return None, "❌ Model or processor not initialized. Make sure model is loaded (unset SKIP_MODEL_LOAD) and dependencies are installed."
292
 
293
  try:
294
- # Prepare inputs
295
  inputs = _prepare_inputs_and_device(convo, image)
296
-
297
- # Run generation
298
  with torch.no_grad():
299
  output = model.generate(
300
  **inputs,
301
  max_new_tokens=max_new_tokens,
302
- do_sample=True,
303
  temperature=temperature,
304
  top_p=top_p,
 
305
  use_cache=True,
306
  pad_token_id=processor.tokenizer.eos_token_id,
307
  eos_token_id=processor.tokenizer.eos_token_id
308
  )
309
-
310
  decoded = _decode_output(inputs, output)
311
-
312
- # Cleanup
313
- cleanup_after_inference(inputs, output)
314
-
315
  return decoded, None
316
  except Exception as e:
317
- try:
318
- cleanup_after_inference(None, None)
319
- except:
320
- pass
321
- return None, f"❌ Error during generation: {str(e)[:200]}"
322
 
 
323
  def safe_generate_caption_direct(image, tone, max_chars=600, keywords_text="", custom_instruction=""):
324
- """Generate caption directly with keywords and custom instructions support"""
325
  try:
326
- if image is None:
327
- return f"❌ No image provided for {tone} caption"
328
-
329
- # Get tone configuration
330
- tone_config = TONE_PROMPTS.get(tone, TONE_PROMPTS["engaging"])
331
-
332
- # Modify prompt based on tone and provided keywords/instructions
333
- base_prompt = tone_config["prompt"]
334
-
335
- # Add keywords instruction for uncensored_keywords tone
336
- if tone == "uncensored_keywords" and keywords_text and keywords_text.strip():
337
- base_prompt += f"\n\nKeywords to mention IF applicable: {keywords_text.strip()}"
338
-
339
- # Add custom instruction to any tone if provided
340
- if custom_instruction and custom_instruction.strip():
341
- base_prompt += f"\n\nMake sure to mention: {custom_instruction.strip()}\nInclude this detail naturally in your description."
342
-
343
- # Create conversation
344
  convo = [
345
- {"role": "system", "content": tone_config["system"]},
346
  {"role": "user", "content": base_prompt}
347
  ]
348
-
349
- # Ensure model and processor are loaded
350
- if processor is None or model is None:
351
- return "❌ Model or processor not initialized. Make sure model is loaded (unset SKIP_MODEL_LOAD) and dependencies are installed."
352
 
353
- # Use centralized generation helper
354
- decoded, err = run_image_chat_generation(convo, image, max_new_tokens=150, temperature=tone_config.get("temperature", 0.7), top_p=tone_config.get("top_p", 0.9))
355
  if err:
356
  return err
357
- result = (decoded or "").strip()
358
-
359
- # Post-process caption (sanitize + smart corrections + truncation)
360
- result = postprocess_caption(result, max_chars=max_chars)
361
-
362
- # Aggressive cleanup to prevent storage overflow
363
- del inputs, output
364
- if torch.cuda.is_available():
365
- torch.cuda.empty_cache()
366
- torch.cuda.synchronize()
367
- gc.collect()
368
-
369
- return result if result else f"❌ Empty result for {tone}"
370
-
371
  except Exception as e:
372
- try:
373
- if torch.cuda.is_available():
374
- torch.cuda.empty_cache()
375
- torch.cuda.synchronize()
376
- gc.collect()
377
- except:
378
- pass
379
- return f"❌ Error: {str(e)[:50]}..."
380
 
381
  @torch.no_grad()
382
  def generate_engaging_only(image, custom_instruction=""):
383
- """Generate only engaging caption"""
384
  return safe_generate_caption_direct(image, "engaging", custom_instruction=custom_instruction) if image else "❌ Upload image first"
385
 
386
  @torch.no_grad()
387
  def generate_casual_friend_only(image, custom_instruction=""):
388
- """Generate only casual friend caption"""
389
  return safe_generate_caption_direct(image, "casual_friend", custom_instruction=custom_instruction) if image else "❌ Upload image first"
390
 
391
- # NSFW function removed - caused hallucination
392
-
393
  @torch.no_grad()
394
  def generate_uncensored_keywords_only(image, keywords_text, custom_instruction=""):
395
- """Generate only uncensored with keywords caption"""
396
  return safe_generate_caption_direct(image, "uncensored_keywords", keywords_text=keywords_text, custom_instruction=custom_instruction) if image else "❌ Upload image first"
397
 
398
- @torch.no_grad()
399
- def generate_test_tone(image, system_text, prompt_text):
400
- """Generate a user-defined test tone using the provided system and prompt.
401
- This output is for testing only and is NOT included in the exported JSON."""
402
- try:
403
- if image is None:
404
- return "❌ Upload image first"
405
- # Fallback defaults if user cleared the fields
406
- if not system_text or not system_text.strip():
407
- system_text = "You are a ... who ..."
408
- if not prompt_text or not prompt_text.strip():
409
- prompt_text = ("Analyze this image like an art critic would with information about its composition, "
410
- "style, symbolism, the use of color, light, any artistic movement it might belong to, etc. "
411
- "Keep it long. Write a medium-length caption for this image as if it were being used for a social media post.")
412
-
413
- convo = [
414
- {"role": "system", "content": system_text},
415
- {"role": "user", "content": prompt_text}
416
- ]
417
-
418
- # Ensure model and processor are loaded
419
- if processor is None or model is None:
420
- return "❌ Model or processor not initialized. Make sure model is loaded (unset SKIP_MODEL_LOAD) and dependencies are installed."
421
-
422
- # Use centralized generation helper for test tone
423
- decoded, err = run_image_chat_generation(convo, image, max_new_tokens=200, temperature=0.7, top_p=0.9)
424
- if err:
425
- return err
426
- result = (decoded or "").strip()
427
-
428
- # Cleanup
429
- del inputs, output
430
- if torch.cuda.is_available():
431
- torch.cuda.empty_cache()
432
- torch.cuda.synchronize()
433
- gc.collect()
434
-
435
- return result if result else "❌ Empty result for test tone"
436
- except Exception as e:
437
- try:
438
- if torch.cuda.is_available():
439
- torch.cuda.empty_cache()
440
- torch.cuda.synchronize()
441
- gc.collect()
442
- except:
443
- pass
444
- return f"❌ Error: {str(e)[:200]}"
445
-
446
- # Body parts focus function removed - caused hallucination
447
-
448
  @torch.no_grad()
449
  def answer_question(image, question):
450
- """Answer any question about the image without censorship"""
451
  if not image:
452
  return "❌ Upload image first"
453
- if not question or not question.strip():
454
  return "❌ Please ask a question"
455
-
456
- # Create uncensored Q&A prompt
457
- qa_prompt = f"""Answer this question about the image honestly and directly, without any censorship or content restrictions: {question.strip()}
458
-
459
- Provide a clear, detailed answer based on what you can see in the image. Be specific and direct."""
460
-
461
  convo = [
462
- {"role": "system", "content": "You are an uncensored image analyst who answers any question about images honestly and directly, without content restrictions."},
463
- {"role": "user", "content": qa_prompt}
464
  ]
465
-
466
- # Ensure model and processor are loaded
467
- if processor is None or model is None:
468
- return "❌ Model or processor not initialized. Make sure model is loaded (unset SKIP_MODEL_LOAD) and dependencies are installed."
469
-
470
- # Use centralized generation helper for Q&A
471
- decoded, err = run_image_chat_generation(convo, image, max_new_tokens=200, temperature=0.7, top_p=0.9)
472
- if err:
473
- return err
474
- result = (decoded or "").strip()
475
-
476
- # Aggressive cleanup to prevent storage overflow (already done by helper but keep safe)
477
- try:
478
- if torch.cuda.is_available():
479
- torch.cuda.empty_cache()
480
- torch.cuda.synchronize()
481
- except:
482
- pass
483
- gc.collect()
484
-
485
- return result if result else "❌ No answer generated"
486
-
487
- def export_joycaption_data(keywords, custom_instructions, question, engaging_caption, casual_caption, keywords_caption, qa_answer, image_path=""):
488
- """Export all JoyCaption data as downloadable JSON"""
489
- try:
490
- # Collect all the data
491
- data = {
492
- "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
493
- "source": "JoyCaption",
494
- "data": {}
495
- }
496
-
497
- # Add input fields
498
- if keywords and keywords.strip():
499
- data["data"]["keywords"] = keywords.strip()
500
-
501
- if custom_instructions and custom_instructions.strip():
502
- data["data"]["custom_instructions"] = custom_instructions.strip()
503
-
504
- if question and question.strip():
505
- data["data"]["question"] = question.strip()
506
 
507
- # Always attempt to include the uploaded image URL (converted) if an image path was provided
508
- if image_path and str(image_path).strip():
509
- # include the raw local path
510
- data["data"]["image_local_path"] = str(image_path)
511
- # pass empty string when no host is configured (fix_image_url treats falsy host as no conversion)
512
- image_url_converted = fix_image_url(image_path, host=(SPACE_HOST or ""))
513
- if image_url_converted and str(image_url_converted).strip():
514
- data["data"]["image_url"] = str(image_url_converted).strip()
515
- # Add generated captions
516
- if engaging_caption and engaging_caption.strip():
517
- data["data"]["caption_engaging"] = engaging_caption.strip()
518
-
519
- if casual_caption and casual_caption.strip():
520
- data["data"]["caption_casual_friend"] = casual_caption.strip()
521
-
522
- if keywords_caption and keywords_caption.strip():
523
- data["data"]["caption_keywords"] = keywords_caption.strip()
524
-
525
- if qa_answer and qa_answer.strip():
526
- data["data"]["qa_answer"] = qa_answer.strip()
527
-
528
- # Check if we have any data to export
529
- if not data["data"]:
530
- return "❌ No data export. Generate some captions first!", None
531
-
532
- # Create JSON string
533
- json_string = json.dumps(data, indent=2, ensure_ascii=False)
534
-
535
- # Create filename with timestamp
536
- filename = f"joycaption_data_{time.strftime('%Y%m%d_%H%M%S')}.json"
537
-
538
- # Return success message and file data
539
- fields_count = len(data["data"])
540
- return f"✅ Exported {fields_count} fields: {', '.join(data['data'].keys())}", (json_string, filename)
541
-
542
- except Exception as e:
543
- return f"❌ Export failed: {str(e)}", None
544
-
545
- # JavaScript for export functionality
546
- EXPORT_JS = """
547
- <script>
548
- // JoyCaption Export System
549
- (function() {
550
- console.log('🚀 Initializing JoyCaption Export System...');
551
-
552
- // Extract data from page fields
553
- window.getJoyCaptionData = function() {
554
- console.log('📊 Extracting JoyCaption data...');
555
- const data = {};
556
-
557
- // Get all textareas and inputs from the page
558
- const allInputs = document.querySelectorAll('textarea, input[type="text"]');
559
-
560
- allInputs.forEach((field, index) => {
561
- const placeholder = (field.placeholder || '').toLowerCase();
562
- const value = field.value ? field.value.trim() : '';
563
-
564
- // Skip empty fields
565
- if (!value) return;
566
-
567
- // Map based on placeholder text and content length
568
- if (placeholder.includes('engaging') || (value.length > 50 && placeholder.includes('generate engaging'))) {
569
- data.caption_engaging = value;
570
- } else if (placeholder.includes('casual') || placeholder.includes('friend') || (value.length > 50 && placeholder.includes('generate casual'))) {
571
- data.caption_casual_friend = value;
572
- } else if (placeholder.includes('keyword') && value.length > 50) {
573
- data.caption_keywords = value;
574
- } else if (placeholder.includes('keyword') && value.length <= 50) {
575
- data.keywords = value;
576
- } else if (placeholder.includes('custom') || placeholder.includes('make sure') || placeholder.includes('mention')) {
577
- data.custom_instructions = value;
578
- } else if (placeholder.includes('question')) {
579
- data.question = value;
580
- } else if (value.length > 50) {
581
- // Long text likely a caption
582
- if (!data.caption_engaging) data.caption_engaging = value;
583
- else if (!data.caption_casual_friend) data.caption_casual_friend = value;
584
- else if (!data.caption_keywords) data.caption_keywords = value;
585
- }
586
- });
587
-
588
- // Add image URLs if present
589
- const images = document.querySelectorAll('img');
590
- const imageUrls = [];
591
- images.forEach(img => {
592
- if (img.src && !img.src.includes('data:') && !img.src.includes('blob:')) {
593
- imageUrls.push(img.src);
594
- }
595
- });
596
-
597
- if (imageUrls.length > 0) {
598
- data.image_urls = imageUrls;
599
- }
600
-
601
- console.log('📦 Extracted data:', data);
602
- return data;
603
- };
604
-
605
- // Listen for extension requests
606
- window.addEventListener('message', function(event) {
607
- if (event.data && event.data.action === 'getJoyCaptionData') {
608
- const data = window.getJoyCaptionData();
609
- event.source.postMessage({
610
- action: 'joyCaptionData',
611
- data: data,
612
- success: Object.keys(data).length > 0
613
- }, event.origin);
614
- }
615
- });
616
-
617
- // Export functionality
618
- window.downloadJoyCaptionData = function() {
619
- try {
620
- const rawData = window.getJoyCaptionData();
621
-
622
- if (Object.keys(rawData).length === 0) {
623
- alert('❌ No data found to export. Make sure you have generated captions first.');
624
- return;
625
- }
626
-
627
- // Package data for export
628
- const exportData = {
629
- timestamp: new Date().toISOString(),
630
- source: 'JoyCaption',
631
- data: rawData
632
- };
633
-
634
- // Create and download JSON file
635
- const jsonString = JSON.stringify(exportData, null, 2);
636
- const blob = new Blob([jsonString], { type: 'application/json' });
637
- const url = URL.createObjectURL(blob);
638
-
639
- const a = document.createElement('a');
640
- a.href = url;
641
- a.download = `joycaption_data_${new Date().toISOString().slice(0, 16).replace(/:/g, '-')}.json`;
642
- document.body.appendChild(a);
643
- a.click();
644
- document.body.removeChild(a);
645
- URL.revokeObjectURL(url);
646
-
647
- alert(`✅ Downloaded JoyCaption data with ${Object.keys(rawData).length} fields!`);
648
- console.log('📥 Downloaded data:', exportData);
649
-
650
- } catch (error) {
651
- console.error('❌ Export error:', error);
652
- alert('❌ Export failed: ' + error.message);
653
- }
654
- };
655
-
656
- // Create export button
657
- function createExportButton() {
658
- // Remove any existing button first
659
- const existingBtn = document.getElementById('joyCaption-export-btn');
660
- if (existingBtn) existingBtn.remove();
661
-
662
- // Create a floating export button
663
- const exportBtn = document.createElement('button');
664
- exportBtn.id = 'joyCaption-export-btn';
665
- exportBtn.innerHTML = '📥 Export JoyCaption Data';
666
- exportBtn.style.cssText = `
667
- position: fixed;
668
- top: 20px;
669
- right: 20px;
670
- z-index: 9999;
671
- background: linear-gradient(135deg, #ff6b35, #f7931e);
672
- color: white;
673
- border: none;
674
- padding: 12px 20px;
675
- border-radius: 25px;
676
- font-weight: 600;
677
- cursor: pointer;
678
- box-shadow: 0 4px 12px rgba(255, 107, 53, 0.3);
679
- transition: all 0.3s ease;
680
- `;
681
-
682
- exportBtn.addEventListener('mouseover', () => {
683
- exportBtn.style.transform = 'translateY(-2px)';
684
- exportBtn.style.boxShadow = '0 6px 16px rgba(255, 107, 53, 0.4)';
685
- });
686
-
687
- exportBtn.addEventListener('mouseout', () => {
688
- exportBtn.style.transform = 'translateY(0)';
689
- exportBtn.style.boxShadow = '0 4px 12px rgba(255, 107, 53, 0.3)';
690
- });
691
-
692
- exportBtn.addEventListener('click', window.downloadJoyCaptionData);
693
-
694
- document.body.appendChild(exportBtn);
695
- console.log('✅ Export button created and attached to body');
696
- }
697
-
698
- // Multiple attempts to create button after Gradio loads
699
- setTimeout(createExportButton, 1000);
700
- setTimeout(createExportButton, 3000);
701
- setTimeout(createExportButton, 5000);
702
-
703
- // Also try when DOM changes (Gradio dynamic loading)
704
- const observer = new MutationObserver(() => {
705
- if (!document.getElementById('joyCaption-export-btn')) {
706
- createExportButton();
707
- }
708
- });
709
- observer.observe(document.body, { childList: true, subtree: true });
710
- })();
711
- </script>
712
- """
713
-
714
- # Gradio Interface
715
  with gr.Blocks(title="Sequential Three-Tone JoyCaption", theme=gr.themes.Soft()) as demo:
716
  gr.HTML(TITLE)
717
-
718
  with gr.Row():
719
- # Left column - Image and controls
720
- with gr.Column(scale=1):
721
- image_input = gr.Image(type="filepath",
722
- label="📸 Upload Image",
723
- height=400
724
- )
725
-
726
- keywords_input = gr.Textbox(
727
- placeholder="e.g., sensual, curves, intimate, alluring...",
728
- label="🏷️ Keywords",
729
- lines=2,
730
- info="Add keywords that will be mentioned by the 'Keywords' tone ONLY if they apply to what's visible in the image"
731
- )
732
- # image_reference_input removed by request — we will export the actual image URL instead
733
-
734
- custom_instruction_input = gr.Textbox(
735
- placeholder="e.g., 'from instagram', 'the left girl has red hair', 'two girls kissing', 'beach setting'...",
736
- label="🎯 Make sure that you mention:",
737
- lines=2,
738
- info="Any specific detail you want mentioned - context, scene details, features, etc. (Works with all tones)"
739
- )
740
-
741
-
742
- question_input = gr.Textbox(
743
- placeholder="e.g., 'What are they doing?', 'Describe her pose', 'What's the setting?'...",
744
- label="❓ Ask a Question",
745
- lines=2,
746
- info="Ask any question about the image - uncensored answers"
747
- )
748
-
749
- with gr.Row():
750
- with gr.Column(scale=4):
751
- ask_question_btn = gr.Button(
752
- "❓ Ask Question",
753
- variant="secondary",
754
- size="sm"
755
- )
756
- with gr.Column(scale=1, min_width=50):
757
- clear_qa_btn = gr.Button("🗑️", size="sm", variant="secondary")
758
-
759
- qa_output = gr.Textbox(
760
- label="",
761
- lines=5,
762
- max_lines=8,
763
- show_copy_button=True,
764
- interactive=True,
765
- placeholder="Ask a question above to get uncensored answers..."
766
- )
767
-
768
- # Right column - Three caption outputs
769
  with gr.Column(scale=1):
770
- # Engaging caption
771
- with gr.Row():
772
- with gr.Column(scale=4):
773
- generate_engaging_btn = gr.Button(
774
- " Engaging",
775
- variant="primary",
776
- size="sm"
777
- )
778
- with gr.Column(scale=1, min_width=50):
779
- reload_engaging = gr.Button("🔄", size="sm", variant="secondary")
780
- with gr.Row():
781
- with gr.Column(scale=1, min_width=50):
782
- clear_engaging_btn = gr.Button("🗑️", size="sm", variant="secondary")
783
- engaging_output = gr.Textbox(
784
- label="",
785
- lines=5,
786
- max_lines=8,
787
- show_copy_button=True,
788
- interactive=True,
789
- placeholder="Click the button above to generate engaging caption..."
790
- )
791
-
792
- # Casual Friend caption
793
- with gr.Row():
794
- with gr.Column(scale=4):
795
- generate_friend_btn = gr.Button(
796
- "😎 Casual Friend",
797
- variant="primary",
798
- size="sm"
799
- )
800
- with gr.Column(scale=1, min_width=50):
801
- reload_friend = gr.Button("🔄", size="sm", variant="secondary")
802
- with gr.Row():
803
- with gr.Column(scale=1, min_width=50):
804
- clear_friend_btn = gr.Button("🗑️", size="sm", variant="secondary")
805
- friend_output = gr.Textbox(
806
- label="",
807
- lines=5,
808
- max_lines=8,
809
- show_copy_button=True,
810
- interactive=True,
811
- placeholder="Click the button above to generate casual friend caption..."
812
- )
813
-
814
- # NSFW section removed - caused hallucination
815
-
816
- # Keywords caption
817
- with gr.Row():
818
- with gr.Column(scale=4):
819
- generate_uncensored_btn = gr.Button(
820
- "🔴 Keywords",
821
- variant="secondary",
822
- size="sm"
823
- )
824
- with gr.Column(scale=1, min_width=50):
825
- reload_uncensored = gr.Button("🔄", size="sm", variant="secondary")
826
- with gr.Row():
827
- with gr.Column(scale=1, min_width=50):
828
- clear_uncensored_btn = gr.Button("🗑️", size="sm", variant="secondary")
829
- uncensored_output = gr.Textbox(
830
- label="",
831
- lines=5,
832
- max_lines=8,
833
- show_copy_button=True,
834
- interactive=True,
835
- placeholder="Click the button above to generate keywords caption..."
836
- )
837
-
838
- # Body Parts Focus section removed - caused hallucination
839
-
840
- # Descriptive text removed for cleaner interface
841
-
842
- # Export functionality
843
- with gr.Row():
844
- export_btn = gr.Button(
845
- "📥 Export All Data (JSON)",
846
- variant="primary",
847
- size="lg"
848
- )
849
-
850
- export_output = gr.Textbox(
851
- label="Export Status",
852
- lines=2,
853
- interactive=False,
854
- visible=False
855
- )
856
-
857
- export_file = gr.File(
858
- label="Download JSON",
859
- visible=False
860
- )
861
-
862
- # --- Test Tone UI (for local testing only; not included in exported JSON) ---
863
- system_test = gr.Textbox(
864
- label="Test Tone System",
865
- value="You are a ... who ...",
866
- lines=1,
867
- info="Editable system message for the test tone"
868
- )
869
- prompt_test = gr.Textbox(
870
- label="Test Tone Prompt",
871
- value="""Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc. Keep it long. Write a medium-length caption for this image as if it were being used for a social media post.""",
872
- lines=6,
873
- info="Editable user prompt for the test tone"
874
- )
875
- test_btn = gr.Button(
876
- "🔬 Test Tone",
877
- variant="secondary",
878
- size="sm"
879
- )
880
- test_output = gr.Textbox(
881
- label="Test Tone Output",
882
- lines=5,
883
- max_lines=8,
884
- show_copy_button=True,
885
- interactive=True,
886
- placeholder="Click the button above to run the test tone..."
887
- )
888
-
889
- # Individual generate button handlers
890
- generate_engaging_btn.click(
891
- generate_engaging_only,
892
- inputs=[image_input, custom_instruction_input],
893
- outputs=engaging_output,
894
- show_progress=True
895
- )
896
-
897
- generate_friend_btn.click(
898
- generate_casual_friend_only,
899
- inputs=[image_input, custom_instruction_input],
900
- outputs=friend_output,
901
- show_progress=True
902
- )
903
-
904
- # NSFW button handler removed
905
-
906
- generate_uncensored_btn.click(
907
- generate_uncensored_keywords_only,
908
- inputs=[image_input, keywords_input, custom_instruction_input],
909
- outputs=uncensored_output,
910
- show_progress=True
911
- )
912
-
913
- test_btn.click(
914
- generate_test_tone,
915
- inputs=[image_input, system_test, prompt_test],
916
- outputs=test_output,
917
- show_progress=True
918
- )
919
-
920
- # Body Parts Focus button handler removed
921
-
922
- # Individual reload buttons - using direct generation for consistency
923
- def reload_engaging_fn(image, custom_instruction):
924
- return generate_engaging_only(image, custom_instruction) if image else "❌ Upload image first"
925
-
926
- def reload_friend_fn(image, custom_instruction):
927
- return generate_casual_friend_only(image, custom_instruction) if image else "❌ Upload image first"
928
-
929
- # NSFW reload function removed
930
-
931
- def reload_uncensored_fn(image, keywords, custom_instruction):
932
- return generate_uncensored_keywords_only(image, keywords, custom_instruction) if image else "❌ Upload image first"
933
-
934
- # Body Parts Focus reload function removed
935
-
936
- reload_engaging.click(
937
- reload_engaging_fn,
938
- inputs=[image_input, custom_instruction_input],
939
- outputs=engaging_output,
940
- show_progress=True
941
- )
942
-
943
- reload_friend.click(
944
- reload_friend_fn,
945
- inputs=[image_input, custom_instruction_input],
946
- outputs=friend_output,
947
- show_progress=True
948
- )
949
-
950
- # NSFW reload click handler removed
951
-
952
- reload_uncensored.click(
953
- reload_uncensored_fn,
954
- inputs=[image_input, keywords_input, custom_instruction_input],
955
- outputs=uncensored_output,
956
- show_progress=True
957
- )
958
-
959
- # Body Parts Focus reload click handler removed
960
-
961
- # Q&A functionality
962
- ask_question_btn.click(
963
- answer_question,
964
- inputs=[image_input, question_input],
965
- outputs=qa_output,
966
- show_progress=True
967
- )
968
-
969
- # Clear button functions
970
- def clear_text():
971
- return ""
972
-
973
- clear_qa_btn.click(
974
- clear_text,
975
- outputs=qa_output
976
- )
977
-
978
- clear_engaging_btn.click(
979
- clear_text,
980
- outputs=engaging_output
981
- )
982
-
983
- clear_friend_btn.click(
984
- clear_text,
985
- outputs=friend_output
986
- )
987
-
988
- # NSFW clear button handler removed
989
-
990
- clear_uncensored_btn.click(
991
- clear_text,
992
- outputs=uncensored_output
993
- )
994
-
995
- # Export functionality
996
- def handle_export(keywords, custom_instructions, question, engaging_caption, casual_caption, keywords_caption, qa_answer, image_path):
997
- """Handle export and return proper file download (cross-platform, uses tempdir)"""
998
- message, file_data = export_joycaption_data(
999
- keywords, custom_instructions, question,
1000
- engaging_caption, casual_caption, keywords_caption, qa_answer, image_path
1001
- )
1002
 
1003
- if file_data:
1004
- json_string, filename = file_data
1005
- # Use the OS temp directory so this works on Windows, macOS, Linux and in Spaces
1006
- base_dir = tempfile.gettempdir()
1007
- temp_file = os.path.join(base_dir, filename)
1008
- with open(temp_file, 'w', encoding='utf-8') as f:
1009
- f.write(json_string)
1010
- return gr.update(value=message, visible=True), gr.update(value=temp_file, visible=True)
1011
- else:
1012
- return gr.update(value=message, visible=True), gr.update(visible=False)
1013
-
1014
- export_btn.click(
1015
- handle_export,
1016
- inputs=[
1017
- keywords_input,
1018
- custom_instruction_input,
1019
- question_input,
1020
- engaging_output,
1021
- friend_output,
1022
- uncensored_output,
1023
- qa_output,
1024
- image_input
1025
- ],
1026
- outputs=[export_output, export_file]
1027
- )
1028
-
1029
- # Body Parts Focus clear button handler removed
1030
 
1031
  if __name__ == "__main__":
1032
  demo.launch()
 
1
  try:
2
  import spaces
 
3
  if not hasattr(spaces, 'GPU'):
4
  def _spaces_gpu(*args, **kwargs):
5
  def _wrap(f):
 
7
  return _wrap
8
  spaces.GPU = _spaces_gpu
9
  except Exception:
 
10
  import types
11
  spaces = types.SimpleNamespace()
12
  def _spaces_gpu(*args, **kwargs):
 
14
  return f
15
  return _wrap
16
  spaces.GPU = _spaces_gpu
17
+
 
 
18
  @spaces.GPU()
19
  def _joycaption_register_gpu():
20
+ """No-op GPU registration for Spaces"""
21
  return None
22
+
23
  import gradio as gr
24
  import torch
25
  from transformers import LlavaForConditionalGeneration, AutoProcessor
26
  from PIL import Image
27
+ import tempfile, gc, time, os, shutil, json, re
 
 
 
 
 
 
28
  from pathlib import Path
 
 
29
  from hf_space_utils import fix_image_url, postprocess_caption
30
 
31
+ # --- Cache dirs redirected to temp ---
32
  _tmpdir = tempfile.gettempdir()
33
  os.environ["HF_HOME"] = os.path.join(_tmpdir, "hf_cache")
34
  os.environ["TRANSFORMERS_CACHE"] = os.path.join(_tmpdir, "transformers_cache")
35
  os.environ["HF_DATASETS_CACHE"] = os.path.join(_tmpdir, "datasets_cache")
36
  os.environ["TORCH_HOME"] = os.path.join(_tmpdir, "torch_cache")
37
 
38
+ # --- Model path ---
39
  MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"
 
 
40
  SPACE_HOST = os.environ.get("SPACE_HOST") or os.environ.get("HF_SPACE_HOST") or None
41
 
42
+ # --- Cleanup helper ---
43
  def cleanup_storage():
 
44
  try:
 
45
  temp_dirs = [
46
  os.environ.get("HF_HOME"),
47
  os.environ.get("TRANSFORMERS_CACHE"),
 
49
  os.environ.get("TORCH_HOME")
50
  ]
51
  for temp_dir in temp_dirs:
52
+ if temp_dir and os.path.exists(temp_dir):
53
+ shutil.rmtree(temp_dir, ignore_errors=True)
 
 
 
 
 
 
 
 
54
  gc.collect()
 
 
55
  if torch.cuda.is_available():
56
  torch.cuda.empty_cache()
57
  torch.cuda.synchronize()
 
58
  print("✅ Storage cleanup completed")
59
  except Exception as e:
60
+ print(f"⚠️ Cleanup warning: {e}")
61
 
62
  TITLE = """
63
  <div style="text-align: center; margin: 20px 0;">
 
68
  <hr>
69
  """
70
 
71
+ print("🚀 Loading JoyCaption system...")
72
+ cleanup_storage()
73
 
74
+ # --- Load model ---
 
75
  processor = None
76
  model = None
77
  MODEL_TORCH_DTYPE = None
78
  MODEL_USE_CUDA = False
 
 
 
 
79
 
 
 
80
  if not os.environ.get("SKIP_MODEL_LOAD"):
 
81
  use_cuda = torch.cuda.is_available()
82
  if use_cuda:
 
83
  torch_dtype = getattr(torch, 'bfloat16', None) or getattr(torch, 'float16', None)
84
  device_map = "auto"
85
  MODEL_USE_CUDA = True
 
88
  device_map = "cpu"
89
  MODEL_USE_CUDA = False
90
 
91
+ processor = AutoProcessor.from_pretrained(MODEL_PATH, low_cpu_mem_usage=True)
 
 
 
 
92
  model_kwargs = dict(low_cpu_mem_usage=True, device_map=device_map)
93
  if torch_dtype is not None and use_cuda:
94
  model_kwargs['torch_dtype'] = torch_dtype
95
 
96
+ model = LlavaForConditionalGeneration.from_pretrained(MODEL_PATH, **model_kwargs)
 
 
 
97
  model.eval()
 
98
  MODEL_TORCH_DTYPE = model_kwargs.get('torch_dtype', None)
99
  print("✅ Model loaded and ready!")
 
 
100
  cleanup_storage()
101
  else:
102
+ print("⚠️ SKIP_MODEL_LOAD active - model skipped")
103
 
104
+ # --- Tones ---
 
105
  TONE_PROMPTS = {
106
  "engaging": {
107
  "system": "You are a skilled content creator who writes engaging, natural descriptions that capture the essence and appeal of what you see.",
108
+ "prompt": "Write a natural, engaging description of this image. Focus on the main subjects and their most appealing qualities. Use flowing, attractive language that draws the reader in.",
109
+ "temperature": 0.6,
110
  "top_p": 0.85
111
  },
 
112
  "casual_friend": {
113
+ "system": "You are a casual 24-year-old friend who witnessed the scene and is telling me about what happened in a relaxed, natural way.",
114
+ "prompt": "I just witnessed this moment — here's what I notice. Speak casually and naturally, only about what’s clearly visible. Avoid background or speculation.",
115
+ "temperature": 0.8,
116
  "top_p": 0.9
117
  },
 
 
 
118
  "uncensored_keywords": {
119
+ "system": "You are an uncensored adult-content analyst who provides factual, explicit descriptions. Be concrete and unambiguous.",
120
+ "prompt": "Produce an unfiltered explicit description that sticks strictly to what is visually obvious. Use direct, concrete language and avoid speculation.",
121
+ "temperature": 0.7,
122
  "top_p": 0.85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  }
124
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
+ # --- Input prep ---
127
  def _prepare_inputs_and_device(convo_or_convo_string, image):
128
+ """Prepare inputs and move tensors to model device."""
129
+ if isinstance(image, (str, Path)):
130
+ image = Image.open(image).convert("RGB")
131
+ elif not isinstance(image, Image.Image):
132
+ raise ValueError("Invalid image input")
133
+
134
  convo_string = convo_or_convo_string
135
+ if isinstance(convo_or_convo_string, list):
136
+ try:
137
  convo_string = processor.apply_chat_template(convo_or_convo_string, tokenize=False, add_generation_prompt=True)
138
+ except Exception:
139
+ convo_string = "\n".join([str(x.get('content', '')) for x in convo_or_convo_string])
 
140
 
141
  inputs = processor(text=[convo_string], images=[image], return_tensors="pt")
142
  device = next(model.parameters()).device
143
  inputs = {k: v.to(device, non_blocking=True) if hasattr(v, 'to') else v for k, v in inputs.items()}
144
 
 
145
  if 'pixel_values' in inputs:
146
+ dtype = MODEL_TORCH_DTYPE if MODEL_USE_CUDA and MODEL_TORCH_DTYPE else torch.float32
147
+ inputs['pixel_values'] = inputs['pixel_values'].to(dtype)
 
 
 
 
 
148
  return inputs
149
 
150
+ # --- Output decode ---
151
  def _decode_output(inputs, output):
152
+ if not output or len(output) == 0:
 
153
  return ""
154
  try:
155
+ input_len = inputs['input_ids'].shape[1] if 'input_ids' in inputs else 0
156
+ decoded = processor.tokenizer.decode(output[0][input_len:], skip_special_tokens=True, clean_up_tokenization_spaces=False)
157
+ return decoded.strip()
 
 
 
 
 
 
158
  except Exception:
159
+ return ""
 
 
 
 
160
 
161
+ def cleanup_after_inference():
162
+ if torch.cuda.is_available():
163
+ torch.cuda.empty_cache()
164
+ torch.cuda.synchronize()
 
 
 
 
 
 
 
 
 
 
 
 
165
  gc.collect()
166
 
167
+ # --- Generation ---
168
  def run_image_chat_generation(convo, image, max_new_tokens=150, temperature=0.7, top_p=0.9):
 
 
 
 
169
  if processor is None or model is None:
170
+ return None, "❌ Model not initialized."
171
 
172
  try:
 
173
  inputs = _prepare_inputs_and_device(convo, image)
 
 
174
  with torch.no_grad():
175
  output = model.generate(
176
  **inputs,
177
  max_new_tokens=max_new_tokens,
178
+ do_sample=False, # deterministic
179
  temperature=temperature,
180
  top_p=top_p,
181
+ repetition_penalty=1.05,
182
  use_cache=True,
183
  pad_token_id=processor.tokenizer.eos_token_id,
184
  eos_token_id=processor.tokenizer.eos_token_id
185
  )
 
186
  decoded = _decode_output(inputs, output)
187
+ cleanup_after_inference()
 
 
 
188
  return decoded, None
189
  except Exception as e:
190
+ cleanup_after_inference()
191
+ return None, f"❌ Generation error: {str(e)[:200]}"
 
 
 
192
 
193
+ # --- Caption generation ---
194
  def safe_generate_caption_direct(image, tone, max_chars=600, keywords_text="", custom_instruction=""):
 
195
  try:
196
+ tone_conf = TONE_PROMPTS.get(tone, TONE_PROMPTS["engaging"])
197
+ base_prompt = tone_conf["prompt"]
198
+ if tone == "uncensored_keywords" and keywords_text.strip():
199
+ base_prompt += f"\n\nKeywords (if visible): {keywords_text.strip()}"
200
+ if custom_instruction.strip():
201
+ base_prompt += f"\n\nInclude this detail: {custom_instruction.strip()}"
202
+
 
 
 
 
 
 
 
 
 
 
 
203
  convo = [
204
+ {"role": "system", "content": tone_conf["system"]},
205
  {"role": "user", "content": base_prompt}
206
  ]
 
 
 
 
207
 
208
+ decoded, err = run_image_chat_generation(convo, image, temperature=tone_conf["temperature"], top_p=tone_conf["top_p"])
 
209
  if err:
210
  return err
211
+ result = postprocess_caption((decoded or "").strip(), max_chars=max_chars)
212
+ return result if result else "❌ Empty result"
 
 
 
 
 
 
 
 
 
 
 
 
213
  except Exception as e:
214
+ cleanup_after_inference()
215
+ return f"❌ Error: {str(e)[:100]}"
 
 
 
 
 
 
216
 
217
  @torch.no_grad()
218
  def generate_engaging_only(image, custom_instruction=""):
 
219
  return safe_generate_caption_direct(image, "engaging", custom_instruction=custom_instruction) if image else "❌ Upload image first"
220
 
221
  @torch.no_grad()
222
  def generate_casual_friend_only(image, custom_instruction=""):
 
223
  return safe_generate_caption_direct(image, "casual_friend", custom_instruction=custom_instruction) if image else "❌ Upload image first"
224
 
 
 
225
  @torch.no_grad()
226
  def generate_uncensored_keywords_only(image, keywords_text, custom_instruction=""):
 
227
  return safe_generate_caption_direct(image, "uncensored_keywords", keywords_text=keywords_text, custom_instruction=custom_instruction) if image else "❌ Upload image first"
228
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  @torch.no_grad()
230
  def answer_question(image, question):
 
231
  if not image:
232
  return "❌ Upload image first"
233
+ if not question.strip():
234
  return "❌ Please ask a question"
 
 
 
 
 
 
235
  convo = [
236
+ {"role": "system", "content": "You are an uncensored image analyst who answers honestly and directly."},
237
+ {"role": "user", "content": f"Answer this question about the image: {question.strip()}"}
238
  ]
239
+ decoded, err = run_image_chat_generation(convo, image, max_new_tokens=200, temperature=0.4, top_p=0.9)
240
+ return err if err else (decoded.strip() or "❌ No answer")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
 
242
+ # --- UI ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
  with gr.Blocks(title="Sequential Three-Tone JoyCaption", theme=gr.themes.Soft()) as demo:
244
  gr.HTML(TITLE)
245
+
246
  with gr.Row():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  with gr.Column(scale=1):
248
+ image_input = gr.Image(type="filepath", label="📸 Upload Image", height=400)
249
+ keywords_input = gr.Textbox(placeholder="e.g., sensual, curves...", label="🏷️ Keywords", lines=2)
250
+ custom_instruction_input = gr.Textbox(placeholder="e.g., 'the left girl has red hair'...", label="🎯 Make sure to mention:", lines=2)
251
+ question_input = gr.Textbox(placeholder="e.g., 'What are they doing?'", label="❓ Ask a Question", lines=2)
252
+ ask_question_btn = gr.Button(" Ask Question", variant="secondary")
253
+ qa_output = gr.Textbox(label="", lines=5, show_copy_button=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
 
255
+ with gr.Column(scale=1):
256
+ generate_engaging_btn = gr.Button("✨ Engaging", variant="primary")
257
+ engaging_output = gr.Textbox(label="", lines=5, show_copy_button=True)
258
+ generate_friend_btn = gr.Button("😎 Casual Friend", variant="primary")
259
+ friend_output = gr.Textbox(label="", lines=5, show_copy_button=True)
260
+ generate_uncensored_btn = gr.Button("🔴 Keywords", variant="secondary")
261
+ uncensored_output = gr.Textbox(label="", lines=5, show_copy_button=True)
262
+
263
+ generate_engaging_btn.click(generate_engaging_only, [image_input, custom_instruction_input], engaging_output)
264
+ generate_friend_btn.click(generate_casual_friend_only, [image_input, custom_instruction_input], friend_output)
265
+ generate_uncensored_btn.click(generate_uncensored_keywords_only, [image_input, keywords_input, custom_instruction_input], uncensored_output)
266
+ ask_question_btn.click(answer_question, [image_input, question_input], qa_output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
 
268
  if __name__ == "__main__":
269
  demo.launch()