nickdigger's picture
v6.1: performance & stability improvements
dc9212d
"""
JoyCaption Advanced Prompting System v6.1
Optimizations over v6.0:
- Removed use_cache=False → KV-cache re-enabled, ~20-25% faster generation
- Removed random seed injection → no longer conflicts with KV-cache reuse
- Consolidated 3× redundant CUDA cache clears → 1 post-generation clear
- GPU duration: 60→30 for generate_caption, 40→20 for answer_question
  (real wall-time on H200 is 12-25s; shorter ceiling improves queue priority)
- Shortened system/user prompts by ~40% (redundant qualifiers removed)
- Stable elem_id on every interactive component (selectors won't break on layout changes)
- image_input.change() clears the three caption outputs (fixes "Error" state persistence)
"""
# Make `spaces.GPU` usable everywhere: use the real `spaces` package when it
# can be imported, otherwise substitute an empty namespace, and install a
# pass-through decorator factory whenever `GPU` is missing.
try:
    import spaces
except Exception:
    import types
    spaces = types.SimpleNamespace()

if not hasattr(spaces, 'GPU'):
    def _gpu(*a, **kw):
        """Accept any decorator arguments and return a decorator that leaves the function as-is."""
        def _w(f):
            return f
        return _w

    spaces.GPU = _gpu
import gradio as gr
import torch
from transformers import LlavaForConditionalGeneration, AutoProcessor
import tempfile, gc, os, json, time, re
from urllib.parse import urlparse
from typing import Optional
# ── Utilities ──────────────────────────────────────────────────────────────
def fix_image_url(raw: str, host: Optional[str] = None) -> str:
    """Turn a Gradio file reference into a URL on the ``/gradio_api/file=`` route.

    Args:
        raw: An absolute URL or a local file path. Falsy values are returned
            unchanged.
        host: Host name, with or without an ``http(s)://`` scheme, used to
            build a URL for temp-file paths. When omitted, the ``SPACE_HOST``
            and ``HF_SPACE_HOST`` environment variables are consulted.

    Returns:
        * Absolute URL: the same URL, with ``/file=`` rewritten to
          ``/gradio_api/file=`` when it is not already on that route.
        * Path starting with ``/tmp/`` or containing ``temp`` (any case):
          ``{host}/gradio_api/file=/{path}`` when a host is known.
        * Anything else: ``raw`` unchanged.
    """
    if not raw:
        return raw

    try:
        parsed = urlparse(raw)
    except ValueError:
        # Malformed input (e.g. an unbalanced IPv6 bracket): treat as a plain path.
        parsed = None

    if parsed and parsed.scheme and parsed.netloc:
        if "/file=" in raw and "/gradio_api/file=" not in raw:
            return raw.replace("/file=", "/gradio_api/file=")
        return raw

    if raw.startswith("/tmp/") or "temp" in raw.lower():
        host = host or os.environ.get("SPACE_HOST") or os.environ.get("HF_SPACE_HOST")
        if host:
            host = host.rstrip("/")
            # Test for a real scheme prefix: a bare startswith("http") would
            # mistake hosts such as "httpbin.org" for fully qualified URLs.
            if not host.lower().startswith(("http://", "https://")):
                host = "https://" + host
            return f"{host}/gradio_api/file=/{raw.lstrip('/')}"
    return raw
def postprocess_caption(text: str, max_chars: int = 1200) -> str:
    """Tidy a generated caption: drop boilerplate openers, cap the length, end with punctuation.

    Args:
        text: Raw model output. Falsy input yields ``""``.
        max_chars: Maximum length of the caption body; negative values are
            treated as 0. A period appended to an unterminated caption may
            bring the total to ``max_chars + 1``.

    Returns:
        The cleaned caption, or ``""`` when nothing is left.
    """
    if not text:
        return ""

    # Strip leading "a photo of" / "an image of" / ... boilerplate.
    result = re.sub(
        r'^(a photo of|an image of|a picture of|this (is a photo|shows))\s*',
        '', text.strip(), flags=re.IGNORECASE)

    limit = max(0, max_chars)
    if len(result) > limit:
        cut = limit
        # Prefer ending on a sentence boundary found in the last ~100
        # characters *before the limit*. The scan must start at limit - 1:
        # starting at the end of the text would accept a terminator beyond
        # the limit and leave the caption untruncated.
        for i in range(limit - 1, max(0, limit - 100), -1):
            if result[i] in '.!?':
                cut = i + 1
                break
        result = result[:cut].strip()

    if result and result[-1] not in '.!?':
        result += '.'
    return result
def _cleanup():
try:
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
except Exception:
pass
# Run one cleanup pass at import time, before the model is loaded.
_cleanup()
# ── Model setup ────────────────────────────────────────────────────────────
# Point the Hugging Face / torch cache variables at sub-directories of the
# system temp dir.
# NOTE(review): `transformers` is already imported above this point; confirm
# these variables are still honoured when set after the import.
_tmp = tempfile.gettempdir()
for k, v in {"HF_HOME": "hf_cache", "TRANSFORMERS_CACHE": "transformers_cache",
             "HF_DATASETS_CACHE": "datasets_cache", "TORCH_HOME": "torch_cache"}.items():
    os.environ[k] = os.path.join(_tmp, v)
MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"
# Public host of the Space, if the environment provides one; None otherwise.
SPACE_HOST = os.environ.get("SPACE_HOST") or os.environ.get("HF_SPACE_HOST") or None
print("πŸš€ Loading JoyCaption v6.1...")
# Processor and model are loaded once at import time and shared by every request.
processor = AutoProcessor.from_pretrained(MODEL_PATH, cache_dir=None)
model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    cache_dir=None,
    low_cpu_mem_usage=True,
)
# Switch to inference mode.
model.eval()
print("βœ… Model loaded.")
# ── Prompts (shortened ~40%) ────────────────────────────────────────────────
# Default prompt pair for each caption style. Every entry maps a style key to
# a dict with a "system" and a "user" prompt string; the UI uses these values
# to pre-fill the prompt text boxes of the matching tab.
DEFAULT_PROMPTS = {
    "casual": {
        "system": "You are a descriptive image captioner focused on settings and female subjects.",
        "user": (
            "Describe the setting and female subjects: clothing or nudity, body, actions, poses. "
            "Be thorough on female details. Skip male clothing, logos, watermarks."
        )
    },
    "friendly": {
        "system": "You are an enthusiastic observer describing attractive women.",
        "user": (
            "Describe the women's appearance, nudity, body, and actions with enthusiasm. "
            "Be thorough on female exposure and poses. Skip male clothing, logos, watermarks."
        )
    },
    "erotic": {
        "system": "You are an adult content specialist identifying erotic elements.",
        "user": (
            "Describe the women's nudity and erotic elements, then explain why this is sexually appealing. "
            "Focus on female exposure. Skip male clothing, logos, watermarks."
        )
    }
}
# ── Generation core ────────────────────────────────────────────────────────
def safe_generate_caption_direct(image, system_prompt, user_prompt, max_chars=1200):
    """Caption ``image`` with the module-level model, reporting failures as text.

    Args:
        image: Path to an image file (opened with PIL) or an already-loaded
            image object, which is passed to the processor as-is.
        system_prompt: System message for the chat template. ``None`` is
            treated as empty.
        user_prompt: User message for the chat template. ``None`` is treated
            as empty.
        max_chars: Length cap forwarded to ``postprocess_caption``.

    Returns:
        The post-processed caption, or a human-readable error message when an
        input is missing, the model returns nothing, or generation raises.
    """
    if image is None:
        return "❌ No image provided"

    # Normalise up front so a None prompt yields the validation message
    # instead of an AttributeError outside the try block.
    system_prompt = (system_prompt or "").strip()
    user_prompt = (user_prompt or "").strip()
    if not system_prompt or not user_prompt:
        return "❌ Both system and user prompts are required"

    try:
        from PIL import Image as PILImage
        pil_image = PILImage.open(image) if isinstance(image, str) else image
        convo = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        convo_str = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
        # Follow the model's own device instead of assuming "cuda": with
        # device_map="auto" the weights are not guaranteed to be on a GPU.
        inputs = processor(text=[convo_str], images=[pil_image], return_tensors="pt").to(model.device)
        inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
        # use_cache is left at its default and no manual seed is set.
        output = model.generate(
            **inputs,
            max_new_tokens=600,
            do_sample=True,
            temperature=0.8,
            top_p=0.85,
            top_k=50,
            repetition_penalty=1.1,
            no_repeat_ngram_size=3,
            pad_token_id=processor.tokenizer.eos_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
        )
        # Decode only the newly generated tokens, not the echoed prompt.
        input_len = inputs["input_ids"].shape[1]
        result = processor.tokenizer.decode(
            output[0][input_len:], skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        ).strip()
        # Drop the tensors before the single post-generation cleanup so their
        # memory can actually be released.
        del inputs, output
        _cleanup()
        return postprocess_caption(result, max_chars) or "❌ Empty result"
    except Exception as e:
        _cleanup()
        return f"❌ Error: {str(e)[:200]}"
# ── GPU-decorated entry points ──────────────────────────────────────────────
@spaces.GPU(duration=30)  # was 60; real wall-time on H200 ≈ 12–25s
@torch.no_grad()
def generate_caption(image, system, user):
    """GPU entry point for the caption buttons.

    Returns an "upload first" message when ``image`` is falsy; otherwise
    delegates to ``safe_generate_caption_direct`` with gradients disabled.
    """
    if image:
        return safe_generate_caption_direct(image, system, user)
    return "❌ Upload image first"
@spaces.GPU(duration=20)  # was 40; Q&A is shorter (max_new_tokens=300)
@torch.no_grad()
def answer_question(image, question):
    """Answer a free-form question about ``image``, reporting failures as text.

    Args:
        image: Path to an image file (opened with PIL) or an already-loaded
            image object. Falsy values produce an "upload first" message.
        question: The user's question. ``None`` or blank text produces a
            "please ask" message.

    Returns:
        The answer, post-processed with a 500-character cap, or a
        human-readable error message.
    """
    if not image:
        return "❌ Upload image first"

    # Normalise up front so a None question yields the validation message
    # instead of an AttributeError.
    question = (question or "").strip()
    if not question:
        return "❌ Please ask a question"

    try:
        from PIL import Image as PILImage
        pil_image = PILImage.open(image) if isinstance(image, str) else image
        convo = [
            {"role": "system", "content": "You are a helpful image analyst."},
            {"role": "user", "content": question},
        ]
        convo_str = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
        # Follow the model's own device instead of assuming "cuda": with
        # device_map="auto" the weights are not guaranteed to be on a GPU.
        inputs = processor(text=[convo_str], images=[pil_image], return_tensors="pt").to(model.device)
        inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
        output = model.generate(**inputs, max_new_tokens=300, do_sample=True,
                                temperature=0.6, top_p=0.9,
                                pad_token_id=processor.tokenizer.eos_token_id,
                                eos_token_id=processor.tokenizer.eos_token_id)
        # Decode only the newly generated tokens, not the echoed prompt.
        result = processor.tokenizer.decode(
            output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
        # Drop the tensors before cleanup so their memory can be released.
        del inputs, output
        _cleanup()
        return postprocess_caption(result, max_chars=500) or "❌ No answer generated"
    except Exception as e:
        _cleanup()
        return f"❌ Q&A Error: {str(e)[:200]}"
# ── Template helpers ────────────────────────────────────────────────────────
def _ins(text, tpl, content):
formatted = tpl.format(content=content.strip())
if not content.strip() or formatted in text:
return text
return (text.rstrip() + " " + formatted).strip()
def create_template_functions():
    """Build the four template-insertion callbacks.

    Each callback takes ``(system, user, content)`` and returns the system
    prompt unchanged together with the user prompt after ``_ins`` has appended
    the callback's sentence to it.

    Returns:
        A tuple ``(key_f, que_f, use_f, not_f)`` for the keywords, question,
        mention and avoid templates respectively.
    """
    def _make(template):
        def _apply(s, u, c):
            return (s, _ins(u, template, c))
        return _apply

    key_f = _make("Pay attention to these keywords: {content}.")
    que_f = _make("Answer this question: {content}.")
    use_f = _make("Make sure that you mention: {content}.")
    not_f = _make("Do NOT mention: {content}.")
    return key_f, que_f, use_f, not_f
# ── Export ──────────────────────────────────────────────────────────────────
def export_joycaption_data(tags, mention, avoid, ask, c1, c2, c3, qa_ans, img):
    """Write the current inputs and generated texts to a JSON file in the temp dir.

    Args:
        tags, mention, avoid, ask: Contents of the four helper text boxes.
        c1, c2, c3: Captions stored under ``descriptions`` as ``casual``,
            ``friendly`` and ``erotic``.
        qa_ans: Answer text stored under ``qa``.
        img: Image file path, or a falsy value when no image is set.

    Blank or ``None`` values are omitted from the export.

    Returns:
        ``(message, path)``. ``path`` is the written JSON file, or ``None``
        when there was nothing to export or the export failed.
    """
    def _clean(value):
        """Return the stripped text, or "" for None/blank values."""
        return (value or "").strip()

    try:
        data = {
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "source": "JoyCaption Advanced Prompting System v6.1",
            "data": {}
        }
        d = data["data"]

        for key, value in (("tags", tags), ("mention", mention),
                           ("avoid", avoid), ("ask", ask)):
            text = _clean(value)
            if text:
                d[key] = text

        if img:
            if isinstance(img, str) and os.path.exists(img):
                d["image_path"] = fix_image_url(img, host=(SPACE_HOST or ""))
            else:
                d["image_error"] = f"Invalid path: {type(img).__name__}"

        qa_obj = {}
        for key, value in (("question", ask), ("answer", qa_ans)):
            text = _clean(value)
            if text:
                qa_obj[key] = text
        if qa_obj:
            d["qa"] = qa_obj

        descs = {}
        for key, value in (("casual", c1), ("friendly", c2), ("erotic", c3)):
            text = _clean(value)
            if text:
                descs[key] = text
        if descs:
            d["descriptions"] = descs

        if not d:
            return "❌ No data to export", None

        js = json.dumps(data, indent=2, ensure_ascii=False)
        # A timestamp alone has one-second resolution, so two exports in the
        # same second would overwrite each other in the shared temp dir.
        # NamedTemporaryFile adds a unique component to the file name.
        stamp = time.strftime('%Y%m%d_%H%M%S')
        with tempfile.NamedTemporaryFile(
                "w", encoding="utf-8", prefix=f"joycaption_{stamp}_",
                suffix=".json", dir=tempfile.gettempdir(), delete=False) as f:
            f.write(js)
            path = f.name
        return f"βœ… Exported {len(d)} fields", path
    except Exception as e:
        return f"❌ Export failed: {str(e)}", None
# ── UI ──────────────────────────────────────────────────────────────────────
# Build the Gradio UI: inputs on the left, one tab per caption style on the
# right, and a JSON export row underneath.
with gr.Blocks(title="JoyCaption Advanced Prompting System", theme=gr.themes.Soft()) as demo:
    gr.HTML("<style>textarea{resize:none!important;}</style>")
    gr.HTML("<h1 style='text-align:center;margin-top:10px;'>"
            "🎨 JoyCaption Advanced Prompting System (v6.1)</h1><hr>")
    key_f, que_f, use_f, not_f = create_template_functions()

    with gr.Row():
        # ── Left column: inputs ──────────────────────────────────────────
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="filepath", label="πŸ“Έ Image",
                elem_id="joy_image_input"
            )
            keywords_input = gr.Textbox(label="🏷️ Tags", lines=2,
                                        placeholder="e.g. beach, sunset",
                                        elem_id="joy_tags_input")
            custom_inst_input = gr.Textbox(label="🎯 Mention", lines=2,
                                           placeholder="Extra instructions",
                                           elem_id="joy_mention_input")
            avoid_input = gr.Textbox(label="🚫 Avoid", lines=2,
                                     placeholder="Things to avoid",
                                     elem_id="joy_avoid_input")
            question_input = gr.Textbox(label="❓ Ask", lines=2,
                                        placeholder="Ask about image",
                                        elem_id="joy_ask_input")
            ask_btn = gr.Button("Ask", variant="secondary", elem_id="joy_ask_btn")
            qa_output = gr.Textbox(label="Answer", lines=3, show_copy_button=True,
                                   elem_id="joy_output_qa")

        # ── Right column: tabs ───────────────────────────────────────────
        with gr.Column(scale=1):
            with gr.Tab("πŸ“ Casual"):
                gr.Markdown("**System Prompt**")
                system1 = gr.Textbox(show_label=False,
                                     value=DEFAULT_PROMPTS["casual"]["system"], lines=3)
                gr.Markdown("**User Prompt**")
                user1 = gr.Textbox(show_label=False,
                                   value=DEFAULT_PROMPTS["casual"]["user"], lines=3)
                gr.Markdown("**Insert Template**")
                with gr.Row():
                    key_btn = gr.Button("Tags", size="sm")
                    use_btn = gr.Button("Mention", size="sm")
                    not_btn = gr.Button("Avoid", size="sm")
                    que_btn = gr.Button("Ask", size="sm")
                gen1_btn = gr.Button("Generate Casual", variant="primary",
                                     elem_id="joy_btn_casual")
                gr.Markdown("**Caption:**")
                out1 = gr.Textbox(show_label=False, lines=5, show_copy_button=True,
                                  elem_id="joy_output_casual")

            with gr.Tab("🀝 Friendly"):
                gr.Markdown("**System Prompt**")
                system2 = gr.Textbox(show_label=False,
                                     value=DEFAULT_PROMPTS["friendly"]["system"], lines=3)
                gr.Markdown("**User Prompt**")
                user2 = gr.Textbox(show_label=False,
                                   value=DEFAULT_PROMPTS["friendly"]["user"], lines=3)
                gr.Markdown("**Insert Template**")
                with gr.Row():
                    key2_btn = gr.Button("Tags", size="sm")
                    use2_btn = gr.Button("Mention", size="sm")
                    not2_btn = gr.Button("Avoid", size="sm")
                    que2_btn = gr.Button("Ask", size="sm")
                gen2_btn = gr.Button("Generate Friendly", variant="primary",
                                     elem_id="joy_btn_friendly")
                gr.Markdown("**Caption:**")
                out2 = gr.Textbox(show_label=False, lines=5, show_copy_button=True,
                                  elem_id="joy_output_friendly")

            with gr.Tab("πŸ”₯ Erotic"):
                gr.Markdown("**System Prompt**")
                system3 = gr.Textbox(show_label=False,
                                     value=DEFAULT_PROMPTS["erotic"]["system"], lines=3)
                gr.Markdown("**User Prompt**")
                user3 = gr.Textbox(show_label=False,
                                   value=DEFAULT_PROMPTS["erotic"]["user"], lines=3)
                gr.Markdown("**Insert Template**")
                with gr.Row():
                    key3_btn = gr.Button("Tags", size="sm")
                    use3_btn = gr.Button("Mention", size="sm")
                    not3_btn = gr.Button("Avoid", size="sm")
                    que3_btn = gr.Button("Ask", size="sm")
                gen3_btn = gr.Button("Generate Erotic", variant="primary",
                                     elem_id="joy_btn_erotic")
                gr.Markdown("**Caption:**")
                out3 = gr.Textbox(show_label=False, lines=5, show_copy_button=True,
                                  elem_id="joy_output_erotic")

    gr.Markdown("---")
    export_btn = gr.Button("πŸ“¦ Export JSON", variant="secondary")
    export_msg = gr.Textbox(visible=False)
    export_file = gr.File(visible=False)

    # ── Clear outputs when the image changes ───────────────────────────────
    # The handler is a plain lambda (not wrapped in spaces.GPU) registered
    # with queue=False. It blanks the three caption boxes so text from a
    # previous image, including error messages, does not linger.
    image_input.change(
        lambda: ("", "", ""), inputs=None, outputs=[out1, out2, out3], queue=False
    )

    # ── Caption generation ─────────────────────────────────────────────────
    gen1_btn.click(generate_caption, [image_input, system1, user1], out1)
    gen2_btn.click(generate_caption, [image_input, system2, user2], out2)
    gen3_btn.click(generate_caption, [image_input, system3, user3], out3)
    ask_btn.click(answer_question, [image_input, question_input], qa_output)

    # ── Template insertion ─────────────────────────────────────────────────
    _common = [keywords_input, custom_inst_input, question_input, avoid_input]
    _fn_map = {"key": key_f, "use": use_f, "not": not_f, "que": que_f}

    def _make_insert_handler(kind):
        """Return the click handler for one template button.

        ``kind`` is captured per handler. Reading the loop variable inside a
        lambda would resolve it at click time, when it always holds the last
        value of the loop, so every button would insert the Ask text.
        """
        insert = _fn_map[kind]

        def _handler(s, u, k, c, q, a):
            # Argument order follows [sys_box, usr_box] + _common:
            # k = tags, c = mention, q = question, a = avoid.
            return insert(s, u, {"key": k, "que": q, "use": c, "not": a}[kind])

        return _handler

    for btn, fn_type, sys_box, usr_box in [
        (key_btn, "key", system1, user1), (use_btn, "use", system1, user1),
        (not_btn, "not", system1, user1), (que_btn, "que", system1, user1),
        (key2_btn, "key", system2, user2), (use2_btn, "use", system2, user2),
        (not2_btn, "not", system2, user2), (que2_btn, "que", system2, user2),
        (key3_btn, "key", system3, user3), (use3_btn, "use", system3, user3),
        (not3_btn, "not", system3, user3), (que3_btn, "que", system3, user3),
    ]:
        btn.click(
            _make_insert_handler(fn_type),
            [sys_box, usr_box] + _common, [sys_box, usr_box]
        )

    # ── Export ─────────────────────────────────────────────────────────────
    def _handle_export(k, c, a, q, c1, c2, c3, qa, img):
        """Run the export and reveal the status box (and the file, on success)."""
        msg, path = export_joycaption_data(k, c, a, q, c1, c2, c3, qa, img)
        if path:
            return gr.update(value=msg, visible=True), gr.update(value=path, visible=True)
        return gr.update(value=msg, visible=True), gr.update(visible=False)

    export_btn.click(
        _handle_export,
        [keywords_input, custom_inst_input, avoid_input, question_input,
         out1, out2, out3, qa_output, image_input],
        [export_msg, export_file]
    )
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()