""" JoyCaption Advanced Prompting System v6.1 Optimizations over v6.0: - Removed use_cache=False → KV-cache re-enabled, ~20-25% faster generation - Removed random seed injection → no longer conflicts with KV-cache reuse - Consolidated 3× redundant CUDA cache clears → 1 post-generation clear - GPU duration: 60→30 for generate_caption, 40→20 for answer_question (real wall-time on H200 is 12-25s; shorter ceiling improves queue priority) - Shortened system/user prompts by ~40% (redundant qualifiers removed) - Stable elem_id on every interactive component (selectors won't break on layout changes) - image_input.change() clears the three caption outputs (fixes "Error" state persistence) """ try: import spaces if not hasattr(spaces, 'GPU'): def _gpu(*a, **kw): def _w(f): return f return _w spaces.GPU = _gpu except Exception: import types spaces = types.SimpleNamespace() def _gpu(*a, **kw): def _w(f): return f return _w spaces.GPU = _gpu import gradio as gr import torch from transformers import LlavaForConditionalGeneration, AutoProcessor import tempfile, gc, os, json, time, re from urllib.parse import urlparse from typing import Optional # ── Utilities ────────────────────────────────────────────────────────────── def fix_image_url(raw: str, host: Optional[str] = None) -> str: if not raw: return raw try: p = urlparse(raw) except Exception: p = None if p and p.scheme and p.netloc: full = raw if "/file=" in full and "/gradio_api/file=" not in full: full = full.replace("/file=", "/gradio_api/file=") return full if raw.startswith("/tmp/") or "temp" in raw.lower(): if not host: host = os.environ.get("SPACE_HOST") or os.environ.get("HF_SPACE_HOST") if host: host = host.rstrip("/") if not host.startswith("http"): host = "https://" + host return f"{host}/gradio_api/file=/{raw.lstrip('/')}" return raw def postprocess_caption(text: str, max_chars: int = 1200) -> str: if not text: return "" result = re.sub(r'^(a photo of|an image of|a picture of|this (is a photo|shows))\s*', '', text.strip(), flags=re.IGNORECASE) if len(result) > max_chars: cut = max_chars for i in range(len(result) - 1, max(0, max_chars - 100), -1): if result[i] in '.!?': cut = i + 1 break result = result[:cut].strip() if result and result[-1] not in '.!?': result += '.' return result def _cleanup(): try: if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() except Exception: pass _cleanup() # ── Model setup ──────────────────────────────────────────────────────────── _tmp = tempfile.gettempdir() for k, v in {"HF_HOME": "hf_cache", "TRANSFORMERS_CACHE": "transformers_cache", "HF_DATASETS_CACHE": "datasets_cache", "TORCH_HOME": "torch_cache"}.items(): os.environ[k] = os.path.join(_tmp, v) MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava" SPACE_HOST = os.environ.get("SPACE_HOST") or os.environ.get("HF_SPACE_HOST") or None print("🚀 Loading JoyCaption v6.1...") processor = AutoProcessor.from_pretrained(MODEL_PATH, cache_dir=None) model = LlavaForConditionalGeneration.from_pretrained( MODEL_PATH, torch_dtype=torch.bfloat16, device_map="auto", cache_dir=None, low_cpu_mem_usage=True, ) model.eval() print("✅ Model loaded.") # ── Prompts (shortened ~40%) ──────────────────────────────────────────────── DEFAULT_PROMPTS = { "casual": { "system": "You are a descriptive image captioner focused on settings and female subjects.", "user": ( "Describe the setting and female subjects: clothing or nudity, body, actions, poses. " "Be thorough on female details. Skip male clothing, logos, watermarks." ) }, "friendly": { "system": "You are an enthusiastic observer describing attractive women.", "user": ( "Describe the women's appearance, nudity, body, and actions with enthusiasm. " "Be thorough on female exposure and poses. Skip male clothing, logos, watermarks." ) }, "erotic": { "system": "You are an adult content specialist identifying erotic elements.", "user": ( "Describe the women's nudity and erotic elements, then explain why this is sexually appealing. " "Focus on female exposure. Skip male clothing, logos, watermarks." ) } } # ── Generation core ──────────────────────────────────────────────────────── def safe_generate_caption_direct(image, system_prompt, user_prompt, max_chars=1200): if image is None: return "❌ No image provided" if not system_prompt.strip() or not user_prompt.strip(): return "❌ Both system and user prompts are required" try: from PIL import Image as PILImage pil_image = PILImage.open(image) if isinstance(image, str) else image convo = [ {"role": "system", "content": system_prompt.strip()}, {"role": "user", "content": user_prompt.strip()}, ] convo_str = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True) inputs = processor(text=[convo_str], images=[pil_image], return_tensors="pt").to("cuda") inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16) # use_cache left at default True — KV-cache speeds up autoregressive decoding # No manual seed — seeds conflict with KV-cache reuse and provide no real benefit output = model.generate( **inputs, max_new_tokens=600, do_sample=True, temperature=0.8, top_p=0.85, top_k=50, repetition_penalty=1.1, no_repeat_ngram_size=3, pad_token_id=processor.tokenizer.eos_token_id, eos_token_id=processor.tokenizer.eos_token_id, ) input_len = inputs["input_ids"].shape[1] result = processor.tokenizer.decode( output[0][input_len:], skip_special_tokens=True, clean_up_tokenization_spaces=False ).strip() # Single cleanup after generation (removed two redundant mid-function clears) del inputs, output _cleanup() return postprocess_caption(result, max_chars) or "❌ Empty result" except Exception as e: _cleanup() return f"❌ Error: {str(e)[:200]}" # ── GPU-decorated entry points ────────────────────────────────────────────── @spaces.GPU(duration=30) # was 60; real wall-time on H200 ≈ 12–25s @torch.no_grad() def generate_caption(image, system, user): if not image: return "❌ Upload image first" return safe_generate_caption_direct(image, system, user) @spaces.GPU(duration=20) # was 40; Q&A is shorter (max_new_tokens=300) @torch.no_grad() def answer_question(image, question): if not image: return "❌ Upload image first" if not question.strip(): return "❌ Please ask a question" try: from PIL import Image as PILImage pil_image = PILImage.open(image) if isinstance(image, str) else image convo = [ {"role": "system", "content": "You are a helpful image analyst."}, {"role": "user", "content": question.strip()}, ] convo_str = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True) inputs = processor(text=[convo_str], images=[pil_image], return_tensors="pt").to("cuda") inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16) output = model.generate(**inputs, max_new_tokens=300, do_sample=True, temperature=0.6, top_p=0.9, pad_token_id=processor.tokenizer.eos_token_id, eos_token_id=processor.tokenizer.eos_token_id) result = processor.tokenizer.decode( output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True) del inputs, output _cleanup() return postprocess_caption(result, max_chars=500) or "❌ No answer generated" except Exception as e: _cleanup() return f"❌ Q&A Error: {str(e)[:200]}" # ── Template helpers ──────────────────────────────────────────────────────── def _ins(text, tpl, content): formatted = tpl.format(content=content.strip()) if not content.strip() or formatted in text: return text return (text.rstrip() + " " + formatted).strip() def create_template_functions(): key_f = lambda s, u, c: (s, _ins(u, "Pay attention to these keywords: {content}.", c)) que_f = lambda s, u, c: (s, _ins(u, "Answer this question: {content}.", c)) use_f = lambda s, u, c: (s, _ins(u, "Make sure that you mention: {content}.", c)) not_f = lambda s, u, c: (s, _ins(u, "Do NOT mention: {content}.", c)) return key_f, que_f, use_f, not_f # ── Export ────────────────────────────────────────────────────────────────── def export_joycaption_data(tags, mention, avoid, ask, c1, c2, c3, qa_ans, img): try: data = { "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), "source": "JoyCaption Advanced Prompting System v6.1", "data": {} } d = data["data"] if tags and tags.strip(): d["tags"] = tags.strip() if mention and mention.strip(): d["mention"] = mention.strip() if avoid and avoid.strip(): d["avoid"] = avoid.strip() if ask and ask.strip(): d["ask"] = ask.strip() if img: if isinstance(img, str) and os.path.exists(img): url = fix_image_url(img, host=(SPACE_HOST or "")) d["image_path"] = url if url != img else img else: d["image_error"] = f"Invalid path: {type(img).__name__}" qa_obj = {} if ask and ask.strip(): qa_obj["question"] = ask.strip() if qa_ans and qa_ans.strip(): qa_obj["answer"] = qa_ans.strip() if qa_obj: d["qa"] = qa_obj descs = {} if c1 and c1.strip(): descs["casual"] = c1.strip() if c2 and c2.strip(): descs["friendly"] = c2.strip() if c3 and c3.strip(): descs["erotic"] = c3.strip() if descs: d["descriptions"] = descs if not d: return "❌ No data to export", None js = json.dumps(data, indent=2, ensure_ascii=False) fn = f"joycaption_{time.strftime('%Y%m%d_%H%M%S')}.json" path = os.path.join(tempfile.gettempdir(), fn) with open(path, "w", encoding="utf-8") as f: f.write(js) return f"✅ Exported {len(d)} fields", path except Exception as e: return f"❌ Export failed: {str(e)}", None # ── UI ────────────────────────────────────────────────────────────────────── with gr.Blocks(title="JoyCaption Advanced Prompting System", theme=gr.themes.Soft()) as demo: gr.HTML("") gr.HTML("