# Spaces: Running on Zero
"""
JoyCaption Advanced Prompting System v6.1

Optimizations over v6.0:
- Removed use_cache=False → KV-cache re-enabled, ~20-25% faster generation
- Removed random seed injection → no longer conflicts with KV-cache reuse
- Consolidated 3× redundant CUDA cache clears → 1 post-generation clear
- GPU duration: 60→30 for generate_caption, 40→20 for answer_question
  (real wall-time on H200 is 12-25s; shorter ceiling improves queue priority)
- Shortened system/user prompts by ~40% (redundant qualifiers removed)
- Stable elem_id on every interactive component (selectors won't break on layout changes)
- image_input.change() clears the three caption outputs (fixes "Error" state persistence)
"""
def _gpu(*args, **kwargs):
    """No-op stand-in for ``spaces.GPU`` when the real decorator is unavailable.

    Supports both decorator spellings:

    * ``@spaces.GPU`` (bare) - called with the function itself, which is
      returned unchanged.
    * ``@spaces.GPU(duration=...)`` - called with options, returns a
      pass-through decorator.
    """
    # Bare usage: the only argument is the function being decorated.
    if len(args) == 1 and callable(args[0]) and not kwargs:
        return args[0]

    def _wrap(func):
        return func

    return _wrap


try:
    import spaces
except Exception:
    # Any failure importing the package falls back to an empty namespace so
    # the rest of the module can still reference ``spaces.GPU``.
    import types
    spaces = types.SimpleNamespace()

if not hasattr(spaces, "GPU"):
    spaces.GPU = _gpu
| import gradio as gr | |
| import torch | |
| from transformers import LlavaForConditionalGeneration, AutoProcessor | |
| import tempfile, gc, os, json, time, re | |
| from urllib.parse import urlparse | |
| from typing import Optional | |
# ── Utilities ──────────────────────────────────────────────────────────────
def fix_image_url(raw: str, host: Optional[str] = None) -> str:
    """Turn a Gradio file reference into a URL that uses the ``/gradio_api/file=`` route.

    Args:
        raw: An absolute URL or a local file path. Falsy values are returned as-is.
        host: Host (with or without scheme) to prefix temp-file paths with.
            When omitted, ``SPACE_HOST`` / ``HF_SPACE_HOST`` from the
            environment are used.

    Returns:
        * For absolute URLs: the same URL, with ``/file=`` rewritten to
          ``/gradio_api/file=`` unless it is already in that form.
        * For paths under ``/tmp/`` or containing "temp": a full
          ``{host}/gradio_api/file=/...`` URL when a host is known.
        * Otherwise ``raw`` unchanged.
    """
    if not raw:
        return raw

    try:
        parsed = urlparse(raw)
    except ValueError:
        # Malformed URL (e.g. bad IPv6 literal): treat it as a plain path.
        parsed = None

    if parsed and parsed.scheme and parsed.netloc:
        if "/file=" in raw and "/gradio_api/file=" not in raw:
            return raw.replace("/file=", "/gradio_api/file=")
        return raw

    if raw.startswith("/tmp/") or "temp" in raw.lower():
        if not host:
            host = os.environ.get("SPACE_HOST") or os.environ.get("HF_SPACE_HOST")
        if host:
            host = host.rstrip("/")
            # Check for a real scheme prefix; a bare "http" prefix test would
            # wrongly skip hostnames such as "httpbin.example".
            if not host.startswith(("http://", "https://")):
                host = "https://" + host
            return f"{host}/gradio_api/file=/{raw.lstrip('/')}"

    return raw
def postprocess_caption(text: str, max_chars: int = 1200) -> str:
    """Clean up a generated caption and cap its length.

    Steps:
        1. Strip surrounding whitespace and a leading boilerplate phrase
           ("a photo of", "an image of", "a picture of", "this is a photo [of]",
           "this shows"), case-insensitively.
        2. If the text is longer than ``max_chars``, cut it at the last
           sentence terminator found within the final 100 characters of the
           allowed window; if there is none, hard-cut at ``max_chars``.
        3. Make sure the result ends with ``.``, ``!`` or ``?``.

    Args:
        text: Raw model output. Falsy input yields ``""``.
        max_chars: Length cap applied before the final punctuation is added
            (a hard-cut result may therefore be one character longer).

    Returns:
        The cleaned caption, or ``""`` when nothing remains.
    """
    if not text:
        return ""

    # ``\b`` keeps "a photo often ..." intact; the optional " of" avoids
    # leaving a dangling "of" after "this is a photo of ...".
    result = re.sub(
        r'^(?:a photo of|an image of|a picture of|this is a photo(?: of)?|this shows)\b\s*',
        '', text.strip(), flags=re.IGNORECASE)

    limit = max(0, max_chars)
    if len(result) > limit:
        cut = limit
        # Scan backwards from the cap (not from the end of the full text) so
        # the chosen cut point can never exceed ``limit``.
        for i in range(limit - 1, max(0, limit - 100), -1):
            if result[i] in '.!?':
                cut = i + 1
                break
        result = result[:cut].strip()

    if result and result[-1] not in '.!?':
        result += '.'
    return result
| def _cleanup(): | |
| try: | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| gc.collect() | |
| except Exception: | |
| pass | |
| _cleanup() | |
# ── Model setup ────────────────────────────────────────────────────────────
# Point the Hugging Face / torch cache directories at sub-folders of the
# system temp directory.
# NOTE(review): these variables are assigned after `torch` and `transformers`
# have already been imported at the top of the file; if those libraries read
# the variables at import time the assignments may have no effect - confirm.
_tmp = tempfile.gettempdir()
for k, v in {"HF_HOME": "hf_cache", "TRANSFORMERS_CACHE": "transformers_cache",
             "HF_DATASETS_CACHE": "datasets_cache", "TORCH_HOME": "torch_cache"}.items():
    os.environ[k] = os.path.join(_tmp, v)

# Hub identifier of the captioning model.
MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"
# Public host of the Space (None when neither variable is set); used when
# building file URLs for exports.
SPACE_HOST = os.environ.get("SPACE_HOST") or os.environ.get("HF_SPACE_HOST") or None

print("π Loading JoyCaption v6.1...")
# The processor and model are loaded once at import time and shared by every
# request handler below.
processor = AutoProcessor.from_pretrained(MODEL_PATH, cache_dir=None)
model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    cache_dir=None,
    low_cpu_mem_usage=True,
)
model.eval()  # inference only
print("β Model loaded.")
# ── Prompts (shortened ~40%) ────────────────────────────────────────────────
# Default prompt pair for each caption tab. The top-level keys ("casual",
# "friendly", "erotic") correspond to the three tabs built in the UI; each
# entry holds the initial "system" and "user" textbox values for that tab.
DEFAULT_PROMPTS = {
    "casual": {
        "system": "You are a descriptive image captioner focused on settings and female subjects.",
        "user": (
            "Describe the setting and female subjects: clothing or nudity, body, actions, poses. "
            "Be thorough on female details. Skip male clothing, logos, watermarks."
        )
    },
    "friendly": {
        "system": "You are an enthusiastic observer describing attractive women.",
        "user": (
            "Describe the women's appearance, nudity, body, and actions with enthusiasm. "
            "Be thorough on female exposure and poses. Skip male clothing, logos, watermarks."
        )
    },
    "erotic": {
        "system": "You are an adult content specialist identifying erotic elements.",
        "user": (
            "Describe the women's nudity and erotic elements, then explain why this is sexually appealing. "
            "Focus on female exposure. Skip male clothing, logos, watermarks."
        )
    }
}
# ── Generation core ────────────────────────────────────────────────────────
def safe_generate_caption_direct(image, system_prompt, user_prompt, max_chars=1200):
    """Generate one caption for ``image`` and return it as display-ready text.

    This function never raises: every failure is converted into a short
    message string so the UI can show it in the output box.

    Args:
        image: File path (``str``) of the image, or an already-loaded image
            object that the processor accepts.
        system_prompt: System message for the chat template.
        user_prompt: User message for the chat template.
        max_chars: Length cap forwarded to ``postprocess_caption``.

    Returns:
        The post-processed caption, or a message describing what went wrong
        (missing image, missing prompt, empty result, or the first 200
        characters of the exception text).
    """
    if image is None:
        return "β No image provided"

    # Tolerate None prompts instead of crashing on ``.strip()`` before the
    # error handling below is even reached.
    system_text = (system_prompt or "").strip()
    user_text = (user_prompt or "").strip()
    if not system_text or not user_text:
        return "β Both system and user prompts are required"

    try:
        from PIL import Image as PILImage

        if isinstance(image, str):
            # ``Image.open`` is lazy and keeps the file open; copy the pixels
            # inside the context manager so the handle is released right away.
            with PILImage.open(image) as opened:
                pil_image = opened.copy()
        else:
            pil_image = image

        convo = [
            {"role": "system", "content": system_text},
            {"role": "user", "content": user_text},
        ]
        convo_str = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
        inputs = processor(text=[convo_str], images=[pil_image], return_tensors="pt").to("cuda")
        inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)

        # ``use_cache`` is left at its default and no manual seed is set
        # (see the module docstring for the rationale).
        output = model.generate(
            **inputs,
            max_new_tokens=600,
            do_sample=True,
            temperature=0.8,
            top_p=0.85,
            top_k=50,
            repetition_penalty=1.1,
            no_repeat_ngram_size=3,
            pad_token_id=processor.tokenizer.eos_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
        )

        # Decode only the newly generated tokens, not the echoed prompt.
        input_len = inputs["input_ids"].shape[1]
        result = processor.tokenizer.decode(
            output[0][input_len:], skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        ).strip()

        # Drop the tensor references before the cleanup in ``finally`` runs.
        del inputs, output
        return postprocess_caption(result, max_chars) or "β Empty result"
    except Exception as e:
        return f"β Error: {str(e)[:200]}"
    finally:
        # Single cleanup point for both the success and the failure path.
        _cleanup()
# ── GPU-decorated entry points ──────────────────────────────────────────────
@spaces.GPU(duration=30)  # was 60; real wall-time on H200 ≈ 12–25s
def generate_caption(image, system, user):
    """GPU entry point for the three "Generate" buttons.

    Args:
        image: Value of the image input (a file path, or empty/None when no
            image has been uploaded).
        system: System prompt text from the active tab.
        user: User prompt text from the active tab.

    Returns:
        The caption text, or a message string when no image is present or
        generation fails.
    """
    if not image:
        return "β Upload image first"
    return safe_generate_caption_direct(image, system, user)
@spaces.GPU(duration=20)  # was 40; Q&A is shorter (max_new_tokens=300)
def answer_question(image, question):
    """GPU entry point for the "Ask" button: answer a free-form question about the image.

    This function never raises: failures are returned as a message string.

    Args:
        image: File path (``str``) of the image, or an already-loaded image
            object; empty/None when nothing has been uploaded.
        question: The user's question.

    Returns:
        The post-processed answer (capped at 500 characters), or a message
        describing what went wrong.
    """
    if not image:
        return "β Upload image first"

    # Tolerate a None question instead of crashing on ``.strip()``.
    question_text = (question or "").strip()
    if not question_text:
        return "β Please ask a question"

    try:
        from PIL import Image as PILImage

        if isinstance(image, str):
            # ``Image.open`` is lazy and keeps the file open; copy the pixels
            # inside the context manager so the handle is released right away.
            with PILImage.open(image) as opened:
                pil_image = opened.copy()
        else:
            pil_image = image

        convo = [
            {"role": "system", "content": "You are a helpful image analyst."},
            {"role": "user", "content": question_text},
        ]
        convo_str = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
        inputs = processor(text=[convo_str], images=[pil_image], return_tensors="pt").to("cuda")
        inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)

        output = model.generate(**inputs, max_new_tokens=300, do_sample=True,
                                temperature=0.6, top_p=0.9,
                                pad_token_id=processor.tokenizer.eos_token_id,
                                eos_token_id=processor.tokenizer.eos_token_id)

        # Decode only the newly generated tokens, not the echoed prompt.
        result = processor.tokenizer.decode(
            output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)

        # Drop the tensor references before the cleanup in ``finally`` runs.
        del inputs, output
        return postprocess_caption(result, max_chars=500) or "β No answer generated"
    except Exception as e:
        return f"β Q&A Error: {str(e)[:200]}"
    finally:
        # Single cleanup point for both the success and the failure path.
        _cleanup()
# ── Template helpers ────────────────────────────────────────────────────────
| def _ins(text, tpl, content): | |
| formatted = tpl.format(content=content.strip()) | |
| if not content.strip() or formatted in text: | |
| return text | |
| return (text.rstrip() + " " + formatted).strip() | |
def create_template_functions():
    """Build the four prompt-template inserters used by the template buttons.

    Each returned callable takes ``(s, u, c)`` - system prompt, user prompt,
    content - and returns ``(s, new_user_prompt)`` where the user prompt has
    the corresponding sentence appended via ``_ins``.

    Returns:
        A tuple ``(key_f, que_f, use_f, not_f)`` for keywords, question,
        must-mention and must-not-mention respectively.
    """
    def _make(tpl):
        # One closure per template; the system prompt passes through untouched.
        def _apply(s, u, c):
            return (s, _ins(u, tpl, c))
        return _apply

    key_f = _make("Pay attention to these keywords: {content}.")
    que_f = _make("Answer this question: {content}.")
    use_f = _make("Make sure that you mention: {content}.")
    not_f = _make("Do NOT mention: {content}.")
    return key_f, que_f, use_f, not_f
# ── Export ──────────────────────────────────────────────────────────────────
def export_joycaption_data(tags, mention, avoid, ask, c1, c2, c3, qa_ans, img):
    """Write the current session's non-empty fields to a JSON file in the temp directory.

    Args:
        tags, mention, avoid, ask: Text of the four input boxes.
        c1, c2, c3: Caption outputs, stored under "casual", "friendly" and
            "erotic".
        qa_ans: Answer text from the Q&A box.
        img: Image file path; a value that is not an existing path is
            recorded as ``image_error``.

    Returns:
        ``(message, path)``. ``path`` is the written file, or ``None`` when
        there was nothing to export or the export failed (the message then
        says why). This function never raises.
    """
    def _clean(value):
        # Blank, None and non-string values all count as "not provided".
        return value.strip() if isinstance(value, str) else ""

    try:
        now = time.localtime()  # one clock reading for timestamp and file name
        d = {}

        for key, value in (("tags", tags), ("mention", mention),
                           ("avoid", avoid), ("ask", ask)):
            text = _clean(value)
            if text:
                d[key] = text

        if img:
            if isinstance(img, str) and os.path.exists(img):
                d["image_path"] = fix_image_url(img, host=(SPACE_HOST or ""))
            else:
                d["image_error"] = f"Invalid path: {type(img).__name__}"

        qa_obj = {}
        if _clean(ask):
            qa_obj["question"] = _clean(ask)
        if _clean(qa_ans):
            qa_obj["answer"] = _clean(qa_ans)
        if qa_obj:
            d["qa"] = qa_obj

        descs = {}
        for key, value in (("casual", c1), ("friendly", c2), ("erotic", c3)):
            text = _clean(value)
            if text:
                descs[key] = text
        if descs:
            d["descriptions"] = descs

        if not d:
            return "β No data to export", None

        data = {
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S", now),
            "source": "JoyCaption Advanced Prompting System v6.1",
            "data": d,
        }
        js = json.dumps(data, indent=2, ensure_ascii=False)

        # mkstemp adds a random suffix, so two exports within the same second
        # (e.g. from concurrent users) cannot overwrite each other.
        fd, path = tempfile.mkstemp(
            prefix=f"joycaption_{time.strftime('%Y%m%d_%H%M%S', now)}_",
            suffix=".json",
        )
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(js)
        return f"β Exported {len(d)} fields", path
    except Exception as e:
        return f"β Export failed: {str(e)}", None
# ── UI ──────────────────────────────────────────────────────────────────────
# Gradio UI: inputs on the left, three prompt/caption tabs on the right, an
# export row, and the event wiring for all buttons.
# NOTE(review): the original indentation was lost, so the nesting below is
# reconstructed; in particular the export row is assumed to sit at the Blocks
# level underneath the two columns - confirm against the deployed layout.
# NOTE(review): several label strings look mis-encoded (mojibake); they are
# runtime text and are left untouched here.
with gr.Blocks(title="JoyCaption Advanced Prompting System", theme=gr.themes.Soft()) as demo:
    gr.HTML("<style>textarea{resize:none!important;}</style>")
    gr.HTML("<h1 style='text-align:center;margin-top:10px;'>"
            "π¨ JoyCaption Advanced Prompting System (v6.1)</h1><hr>")
    key_f, que_f, use_f, not_f = create_template_functions()

    with gr.Row():
        # ── Left column: inputs ──
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="filepath", label="πΈ Image",
                elem_id="joy_image_input"
            )
            keywords_input = gr.Textbox(label="π·οΈ Tags", lines=2,
                                        placeholder="e.g. beach, sunset",
                                        elem_id="joy_tags_input")
            custom_inst_input = gr.Textbox(label="π― Mention", lines=2,
                                           placeholder="Extra instructions",
                                           elem_id="joy_mention_input")
            avoid_input = gr.Textbox(label="π« Avoid", lines=2,
                                     placeholder="Things to avoid",
                                     elem_id="joy_avoid_input")
            question_input = gr.Textbox(label="β Ask", lines=2,
                                        placeholder="Ask about image",
                                        elem_id="joy_ask_input")
            ask_btn = gr.Button("Ask", variant="secondary", elem_id="joy_ask_btn")
            qa_output = gr.Textbox(label="Answer", lines=3, show_copy_button=True,
                                   elem_id="joy_output_qa")

        # ── Right column: tabs ──
        with gr.Column(scale=1):
            with gr.Tab("π Casual"):
                gr.Markdown("**System Prompt**")
                system1 = gr.Textbox(show_label=False,
                                     value=DEFAULT_PROMPTS["casual"]["system"], lines=3)
                gr.Markdown("**User Prompt**")
                user1 = gr.Textbox(show_label=False,
                                   value=DEFAULT_PROMPTS["casual"]["user"], lines=3)
                gr.Markdown("**Insert Template**")
                with gr.Row():
                    key_btn = gr.Button("Tags", size="sm")
                    use_btn = gr.Button("Mention", size="sm")
                    not_btn = gr.Button("Avoid", size="sm")
                    que_btn = gr.Button("Ask", size="sm")
                gen1_btn = gr.Button("Generate Casual", variant="primary",
                                     elem_id="joy_btn_casual")
                gr.Markdown("**Caption:**")
                out1 = gr.Textbox(show_label=False, lines=5, show_copy_button=True,
                                  elem_id="joy_output_casual")

            with gr.Tab("π€ Friendly"):
                gr.Markdown("**System Prompt**")
                system2 = gr.Textbox(show_label=False,
                                     value=DEFAULT_PROMPTS["friendly"]["system"], lines=3)
                gr.Markdown("**User Prompt**")
                user2 = gr.Textbox(show_label=False,
                                   value=DEFAULT_PROMPTS["friendly"]["user"], lines=3)
                gr.Markdown("**Insert Template**")
                with gr.Row():
                    key2_btn = gr.Button("Tags", size="sm")
                    use2_btn = gr.Button("Mention", size="sm")
                    not2_btn = gr.Button("Avoid", size="sm")
                    que2_btn = gr.Button("Ask", size="sm")
                gen2_btn = gr.Button("Generate Friendly", variant="primary",
                                     elem_id="joy_btn_friendly")
                gr.Markdown("**Caption:**")
                out2 = gr.Textbox(show_label=False, lines=5, show_copy_button=True,
                                  elem_id="joy_output_friendly")

            with gr.Tab("π₯ Erotic"):
                gr.Markdown("**System Prompt**")
                system3 = gr.Textbox(show_label=False,
                                     value=DEFAULT_PROMPTS["erotic"]["system"], lines=3)
                gr.Markdown("**User Prompt**")
                user3 = gr.Textbox(show_label=False,
                                   value=DEFAULT_PROMPTS["erotic"]["user"], lines=3)
                gr.Markdown("**Insert Template**")
                with gr.Row():
                    key3_btn = gr.Button("Tags", size="sm")
                    use3_btn = gr.Button("Mention", size="sm")
                    not3_btn = gr.Button("Avoid", size="sm")
                    que3_btn = gr.Button("Ask", size="sm")
                gen3_btn = gr.Button("Generate Erotic", variant="primary",
                                     elem_id="joy_btn_erotic")
                gr.Markdown("**Caption:**")
                out3 = gr.Textbox(show_label=False, lines=5, show_copy_button=True,
                                  elem_id="joy_output_erotic")

    gr.Markdown("---")
    export_btn = gr.Button("π¦ Export JSON", variant="secondary")
    export_msg = gr.Textbox(visible=False)
    export_file = gr.File(visible=False)

    # ── Clear outputs when the image changes ──
    # The callback is a trivial Python lambda registered with queue=False, so
    # it is not queued and does not go through the GPU-decorated entry points.
    # It prevents "Error" text from a previous failed generation persisting
    # into the next upload and confusing the user.
    image_input.change(
        lambda: ("", "", ""), inputs=None, outputs=[out1, out2, out3], queue=False
    )

    # ── Caption generation ──
    gen1_btn.click(generate_caption, [image_input, system1, user1], out1)
    gen2_btn.click(generate_caption, [image_input, system2, user2], out2)
    gen3_btn.click(generate_caption, [image_input, system3, user3], out3)
    ask_btn.click(answer_question, [image_input, question_input], qa_output)

    # ── Template insertion ──
    # Every template button receives the same four extra inputs, in this order.
    _common = [keywords_input, custom_inst_input, question_input, avoid_input]
    _fn_map = {"key": key_f, "use": use_f, "not": not_f, "que": que_f}

    def _make_inserter(kind):
        """Return the click handler for one template kind.

        ``kind`` is captured per handler here. Reading the loop variable from
        inside a lambda would bind late, leaving every button with the value
        of the final iteration.
        """
        apply_template = _fn_map[kind]

        def _insert(s, u, k, c, q, a):
            # k/c/q/a follow the order of ``_common``.
            content = {"key": k, "que": q, "use": c, "not": a}[kind]
            return apply_template(s, u, content)

        return _insert

    for btn, fn_type, sys_box, usr_box in [
        (key_btn, "key", system1, user1), (use_btn, "use", system1, user1),
        (not_btn, "not", system1, user1), (que_btn, "que", system1, user1),
        (key2_btn, "key", system2, user2), (use2_btn, "use", system2, user2),
        (not2_btn, "not", system2, user2), (que2_btn, "que", system2, user2),
        (key3_btn, "key", system3, user3), (use3_btn, "use", system3, user3),
        (not3_btn, "not", system3, user3), (que3_btn, "que", system3, user3),
    ]:
        btn.click(
            _make_inserter(fn_type),
            [sys_box, usr_box] + _common, [sys_box, usr_box]
        )

    # ── Export ──
    def _handle_export(k, c, a, q, c1, c2, c3, qa, img):
        """Run the export and reveal the status box (and the file, if one was written)."""
        msg, path = export_joycaption_data(k, c, a, q, c1, c2, c3, qa, img)
        if path:
            return gr.update(value=msg, visible=True), gr.update(value=path, visible=True)
        return gr.update(value=msg, visible=True), gr.update(visible=False)

    export_btn.click(
        _handle_export,
        [keywords_input, custom_inst_input, avoid_input, question_input,
         out1, out2, out3, qa_output, image_input],
        [export_msg, export_file]
    )
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()