| import json |
| import base64 |
| import io |
| import requests |
| from pathlib import Path |
| from typing import Dict, Any, Optional, Tuple |
|
|
| import gradio as gr |
| from PIL import Image |
|
|
| |
| from prompts import make_user_query, system_prompt, prompts_b |
|
|
| |
|
|
| |
# OpenAI-compatible chat-completions endpoint, plus the value sent as the
# Bearer token with each request (appears to be a placeholder for servers that
# do not check it).
API_URL: str = "http://127.0.0.1:8000/v1/chat/completions"
API_KEY: str = "not-needed"

# Images whose area exceeds this many megapixels are downscaled before upload.
MAX_PIXELS: float = 1.0

# Sampling settings sent with every caption request.
MAX_TOKENS: int = 4096
TEMPERATURE: float = 0.5

# HTTP timeouts in seconds: a short one for the /v1/models connectivity probe,
# a long one for the caption generation request itself.
REQUEST_TIMEOUT: int = 5
WORK_TIMEOUT: int = 300
|
|
| |
# Caption styles offered in the UI come from the keys of prompts_b; the first
# key is the default selection.
CAPTION_TYPES = list(prompts_b.keys())

# Fail fast at import time: the UI cannot work without at least one caption
# type. Check the list itself rather than the first key, so a falsy key
# (e.g. "") is not misreported as "no caption types".
if not CAPTION_TYPES:
    raise RuntimeError("No caption types available in prompts_b!")
DEFAULT_C_TYPE = CAPTION_TYPES[0]
|
|
| |
|
|
|
|
def check_api_connection(api_url: str) -> Tuple[str, str]:
    """Probe an OpenAI-compatible server and report the model it serves.

    The ``/v1/models`` endpoint is derived from *api_url* by keeping everything
    before its ``/v1`` path segment. Both ``http://host:8000/v1/chat/completions``
    and a bare ``http://host:8000/v1`` therefore map to
    ``http://host:8000/v1/models``.

    Args:
        api_url: Chat-completions URL, or any URL under the server's ``/v1`` root.

    Returns:
        ``(status_message, model_name)``. The model name is one of:
        - the ``id`` of the first entry in the response's ``data`` list;
        - ``"Unknown"`` when that list is missing or empty;
        - ``"N/A"`` when the request fails.
        Errors are never raised; they are reported in the status message.
    """
    try:
        trimmed = api_url.strip().rstrip('/')
        if trimmed.endswith('/v1'):
            # split('/v1/') would not match a bare ".../v1", and the probe
            # would then hit ".../v1/v1/models".
            base_url = trimmed[:-len('/v1')]
        else:
            base_url = trimmed.split('/v1/')[0]
        models_url = f"{base_url}/v1/models"

        response = requests.get(models_url, timeout=REQUEST_TIMEOUT)
        response.raise_for_status()

        result = response.json()
        models = result.get('data') if isinstance(result, dict) else None
        if models:
            model_name = models[0].get('id', 'Unknown')
            return "✅ Connected", model_name
        return "⚠️ Connected (no model info)", "Unknown"

    # requests' ConnectTimeout subclasses both Timeout and ConnectionError.
    # Test Timeout first so connect timeouts are reported as timeouts.
    except requests.exceptions.Timeout:
        return "❌ Timeout", "N/A"
    except requests.exceptions.ConnectionError:
        return "❌ Connection failed", "N/A"
    except Exception as e:
        # Any other failure (malformed URL, HTTP error status, non-JSON body)
        # is shown in the status box, truncated to keep it short.
        return f"❌ Error: {str(e)[:50]}", "N/A"
|
|
|
|
def encode_image_base64(
    image: Image.Image,
    max_pixels: float = MAX_PIXELS,
    quality: int = 95,
) -> str:
    """Encode an image as a base64 JPEG string, downscaling large images.

    Images with transparency are flattened onto a white background. A plain
    ``convert('RGB')`` would simply drop the alpha channel and expose whatever
    colour data is stored under transparent pixels. Images larger than
    ``max_pixels`` megapixels are resized, preserving aspect ratio, so their
    area is at most that limit.

    Args:
        image: Source PIL image; it is not modified.
        max_pixels: Area limit in megapixels (1.0 == 1,000,000 pixels).
        quality: JPEG quality. Pillow advises against values above 95; 100
            disables parts of JPEG compression and produces much larger
            payloads for little visual gain.

    Returns:
        The JPEG bytes as a base64 string, without a ``data:`` URI prefix.
    """
    img = image

    # Flatten transparency onto white instead of letting convert('RGB')
    # silently discard the alpha channel.
    has_alpha = 'A' in img.getbands() or (
        img.mode == 'P' and 'transparency' in img.info
    )
    if has_alpha:
        rgba = img.convert('RGBA')
        background = Image.new('RGBA', rgba.size, (255, 255, 255, 255))
        img = Image.alpha_composite(background, rgba).convert('RGB')
    elif img.mode != 'RGB':
        img = img.convert('RGB')

    width, height = img.size
    limit = max_pixels * 1_000_000
    if width * height > limit:
        # Scale both sides by the same factor so the area drops to the limit.
        scale = (limit / (width * height)) ** 0.5
        # Clamp to at least 1px: for extreme aspect ratios, int() can truncate
        # a side to 0, and resize() would raise.
        new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
        img = img.resize(new_size, Image.Resampling.LANCZOS)

    buffer = io.BytesIO()
    img.save(buffer, format='JPEG', quality=quality)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
|
|
|
def call_caption_api(messages: list, api_url: str = API_URL, model_name: str = "toriigate-0.5") -> Optional[str]:
    """Send a chat-completions request and return the generated text.

    Failures are returned as text rather than raised, so the UI can display
    them:
    - transport and HTTP failures become ``"API Error: ..."``; for HTTP error
      statuses, the start of the server's response body is appended because
      it usually explains what was rejected;
    - responses that do not have the expected shape become
      ``"Parse Error: ..."``.

    Args:
        messages: Chat messages list, as built by generate_caption
            (system prompt + user message with image and text).
        api_url: Full chat-completions endpoint URL.
        model_name: Model id placed in the request payload.

    Returns:
        The first choice's ``message.content`` (may be ``None`` if the server
        sends null content), or an error string as described above.
    """
    payload = {
        "model": model_name,
        "messages": messages,
        "max_tokens": MAX_TOKENS,
        "temperature": TEMPERATURE,
        "stream": False,
        "stop": ["<|im_end|>", "<|endoftext|>"]
    }

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}"
    }

    try:
        response = requests.post(
            api_url,
            headers=headers,
            json=payload,
            timeout=WORK_TIMEOUT
        )
        response.raise_for_status()
    except requests.exceptions.HTTPError as e:
        # The status line alone ("400 Client Error ...") rarely says why the
        # request failed; the server's body usually does.
        body = e.response.text.strip() if e.response is not None else ""
        detail = f"\n{body[:500]}" if body else ""
        return f"API Error: {e}{detail}"
    except requests.exceptions.RequestException as e:
        return f"API Error: {e}"

    try:
        result = response.json()
        return result['choices'][0]['message']['content']
    except (KeyError, IndexError, TypeError, ValueError) as e:
        # TypeError: body is JSON but not the expected dict/list nesting.
        # ValueError: body is not JSON at all (JSONDecodeError subclasses it).
        return f"Parse Error: {e}"
|
|
|
|
def empty_template() -> Dict[str, Any]:
    """Build a fresh, empty caption item.

    The item has empty ``tags`` and ``characters`` lists, plus
    ``char_p_tags`` / ``char_descr`` sections, each holding empty ``chars``
    and ``skins`` mappings. A new object is built on every call, so callers
    may mutate the result freely.
    """
    def blank_section() -> Dict[str, Dict[str, Any]]:
        return {"chars": {}, "skins": {}}

    return dict(
        tags=[],
        characters=[],
        char_p_tags=blank_section(),
        char_descr=blank_section(),
    )
|
|
|
|
def _split_csv(text: Optional[str]) -> list:
    """Split a comma-separated string into stripped, non-empty items ('' or None -> [])."""
    if not text:
        return []
    return [part.strip() for part in text.split(',') if part.strip()]


def _unique(names) -> list:
    """Return *names* without duplicates, keeping first-seen order."""
    return list(dict.fromkeys(names))


def generate_caption(
    image: Image.Image,
    api_url: str,
    model_name: str,
    c_type: str,
    use_names: bool,
    add_tags: bool,
    add_char_list: bool,
    add_chars_tags: bool,
    add_chars_descr: bool,
    tags_text: str,
    characters_text: str,
    char1_name: str,
    char1_tags: str,
    char2_name: str,
    char2_tags: str,
    char3_name: str,
    char3_tags: str,
    char4_name: str,
    char4_tags: str,
    char5_name: str,
    char5_tags: str,
    char_descr1_name: str,
    char_descr1_text: str,
    char_descr2_name: str,
    char_descr2_text: str,
    char_descr3_name: str,
    char_descr3_text: str,
    char_descr4_name: str,
    char_descr4_text: str,
    char_descr5_name: str,
    char_descr5_text: str
) -> str:
    """Generate a caption for a single image from the UI inputs.

    Builds an item dict (see empty_template) from the enabled optional inputs,
    renders the user prompt with make_user_query, and sends the image plus
    prompt to the chat-completions API.

    Args:
        image: Uploaded PIL image, or None if nothing was uploaded.
        api_url: Chat-completions endpoint URL.
        model_name: Model id shown in the UI. The placeholders "N/A" and
            "Unknown" (shown before or after an unsuccessful connection check)
            are not sent; call_caption_api's default model id is used instead.
        c_type: Caption type key; falls back to DEFAULT_C_TYPE when empty.
        use_names, add_tags, add_char_list, add_chars_tags, add_chars_descr:
            Feature toggles forwarded to make_user_query.
        tags_text, characters_text: Comma-separated tags / character names.
        charN_name, charN_tags: Per-character name and comma-separated tags.
        char_descrN_name, char_descrN_text: Per-character name and free-text
            description.

    Returns:
        The caption text, an error message string from call_caption_api, or a
        prompt to upload an image.
    """
    if image is None:
        return "Please upload an image first."

    # A cleared dropdown yields None/empty; use the default caption type.
    c_type = c_type or DEFAULT_C_TYPE

    char_tag_pairs = [
        (char1_name, char1_tags),
        (char2_name, char2_tags),
        (char3_name, char3_tags),
        (char4_name, char4_tags),
        (char5_name, char5_tags),
    ]
    char_descr_pairs = [
        (char_descr1_name, char_descr1_text),
        (char_descr2_name, char_descr2_text),
        (char_descr3_name, char_descr3_text),
        (char_descr4_name, char_descr4_text),
        (char_descr5_name, char_descr5_text),
    ]

    item = empty_template()

    if add_tags:
        item["tags"] = _split_csv(tags_text)

    if add_char_list:
        item["characters"] = _split_csv(characters_text)

    # Names typed into the per-character panels imply a character list. Use
    # them when no explicit list was requested or the explicit list is empty,
    # and turn the character list on so the prompt includes it.
    auto_chars = []
    if add_chars_tags:
        auto_chars += [name.strip() for name, _ in char_tag_pairs if name and name.strip()]
    if add_chars_descr:
        auto_chars += [name.strip() for name, _ in char_descr_pairs if name and name.strip()]
    auto_chars = _unique(auto_chars)
    if auto_chars and (not add_char_list or not item["characters"]):
        item["characters"] = auto_chars
        add_char_list = True

    if add_chars_tags:
        chars_dict = {}
        for name, tags_str in char_tag_pairs:
            name = (name or "").strip()
            if name:
                chars_dict[name] = _split_csv(tags_str)
        if chars_dict:
            item["char_p_tags"] = {"chars": chars_dict, "skins": {}}

    if add_chars_descr:
        descr_dict = {}
        for name, descr in char_descr_pairs:
            name = (name or "").strip()
            descr = (descr or "").strip()
            # A description entry needs both a name and some text.
            if name and descr:
                descr_dict[name] = descr
        if descr_dict:
            item["char_descr"] = {"chars": descr_dict, "skins": {}}

    image_data = encode_image_base64(image)

    user_query = make_user_query(
        item,
        c_type=c_type,
        use_names=use_names,
        add_tags=add_tags,
        add_characters=add_char_list,
        add_char_tags=add_chars_tags,
        add_description=add_chars_descr,
        underscores_replace=False
    )

    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_prompt}]
        },
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_data}"}},
                {"type": "text", "text": user_query}
            ]
        }
    ]

    # The Model box holds "N/A"/"Unknown" until a connection check finds a
    # model. Servers that validate the model id may reject such a value, so
    # let call_caption_api fall back to its own default instead.
    if model_name and model_name not in ("N/A", "Unknown"):
        return call_caption_api(messages, api_url, model_name)
    return call_caption_api(messages, api_url)
|
|
|
|
def create_ui():
    """Create and return the ToriiGate Captioner Gradio interface.

    Layout:
    - an API row with URL, connection status and model name;
    - a left column with the image upload, caption configuration and the
      optional per-character panels;
    - a right column with the caption output.

    Event handlers:
    - enable or show the optional inputs when their checkboxes are ticked;
    - probe the API on page load and whenever the URL changes;
    - run generate_caption;
    - clear the per-character fields.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    # (name placeholder, detail placeholder) for each "Character N" slot.
    tag_examples = [
        ("e.g., albedo", "e.g., white_hair, green_eyes, horns"),
        ("e.g., hoshimi_miyabi", "e.g., blue_hair, fox_ears"),
        ("e.g., nishizono_mio", "e.g., brown_hair, glasses"),
        ("e.g.", "e.g."),
        ("e.g.", "e.g."),
    ]
    descr_examples = [
        ("e.g., albedo", "e.g., Albedo is a curvy woman with..."),
        ("e.g., hoshimi_miyabi", "e.g., Miyabi is a calm and collected..."),
        ("e.g., nishizono_mio", "e.g., Mio is a cheerful girl with..."),
        ("e.g.", "e.g."),
        ("e.g.", "e.g."),
    ]

    def build_character_slots(examples, detail_label, **detail_kwargs):
        """Create one "Character N" accordion per example pair (only the first
        open), each with a Name box and a detail box. Return the components
        flattened as [name1, detail1, name2, detail2, ...]."""
        fields = []
        for idx, (name_hint, detail_hint) in enumerate(examples, start=1):
            with gr.Accordion(f"Character {idx}", open=(idx == 1)):
                fields.append(gr.Textbox(
                    label="Name",
                    placeholder=name_hint,
                    interactive=True
                ))
                fields.append(gr.Textbox(
                    label=detail_label,
                    placeholder=detail_hint,
                    interactive=True,
                    **detail_kwargs
                ))
        return fields

    with gr.Blocks(title="ToriiGate Captioner", theme=gr.themes.Soft()) as app:
        gr.Markdown("# 🖼️ ToriiGate Captioner")

        with gr.Row():
            api_url_input = gr.Textbox(
                label="API URL",
                value=API_URL,
                interactive=True,
                scale=4
            )
            api_status = gr.Textbox(
                label="Status",
                value="⏳ Waiting for input...",
                interactive=False,
                scale=1
            )
            model_name_display = gr.Textbox(
                label="Model",
                value="N/A",
                interactive=False,
                scale=1
            )

        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(
                    label="Upload Image",
                    type="pil",
                    height=400
                )

                gr.Markdown("### Configuration")

                c_type = gr.Dropdown(
                    choices=CAPTION_TYPES,
                    value=DEFAULT_C_TYPE,
                    label="Caption Type",
                    interactive=True
                )

                with gr.Group():
                    use_names = gr.Checkbox(
                        value=True,
                        label="Use Names (enable character names)"
                    )

                    add_tags = gr.Checkbox(
                        value=False,
                        label="Add Tags"
                    )
                    tags_text = gr.Textbox(
                        label="Tags (comma-separated)",
                        placeholder="e.g., 1girl, blue_hair, school_uniform",
                        interactive=False
                    )

                    add_char_list = gr.Checkbox(
                        value=False,
                        label="Add Character List"
                    )
                    characters_text = gr.Textbox(
                        label="Character Names (comma-separated)",
                        placeholder="e.g., nishizono_mio, hoshimi_miyabi",
                        interactive=False
                    )

                add_chars_tags = gr.Checkbox(
                    value=False,
                    label="Add Character Tags"
                )
                with gr.Group(visible=False) as char_tags_group:
                    gr.Markdown("**Add character names and their tags**")
                    char_tag_fields = build_character_slots(
                        tag_examples, "Tags (comma-separated)"
                    )
                    char_tags_clear_btn = gr.Button(
                        "🗑️ Clear All",
                        variant="secondary",
                        size="sm"
                    )

                add_chars_descr = gr.Checkbox(
                    value=False,
                    label="Add Character Descriptions"
                )
                with gr.Group(visible=False) as char_descr_group:
                    gr.Markdown("**Add character descriptions**")
                    char_descr_fields = build_character_slots(
                        descr_examples, "Description", lines=3
                    )
                    char_descr_clear_btn = gr.Button(
                        "🗑️ Clear All",
                        variant="secondary",
                        size="sm"
                    )

                generate_btn = gr.Button("🚀 Generate Caption", variant="primary", size="lg")

            with gr.Column(scale=1):
                output_text = gr.Textbox(
                    label="Caption Output",
                    lines=20,
                    max_lines=50,
                    interactive=False
                )

        def toggle_interactive(is_checked: bool):
            """Make the paired text box editable only while its checkbox is ticked."""
            return gr.update(interactive=is_checked)

        add_tags.change(fn=toggle_interactive, inputs=add_tags, outputs=tags_text)
        add_char_list.change(fn=toggle_interactive, inputs=add_char_list, outputs=characters_text)

        add_chars_tags.change(
            fn=lambda x: gr.update(visible=x),
            inputs=add_chars_tags,
            outputs=char_tags_group
        )
        add_chars_descr.change(
            fn=lambda x: gr.update(visible=x),
            inputs=add_chars_descr,
            outputs=char_descr_group
        )

        connection_outputs = [api_status, model_name_display]
        api_url_input.change(
            fn=check_api_connection,
            inputs=api_url_input,
            outputs=connection_outputs
        )
        # Also probe once when the page loads. Otherwise the default URL is never
        # checked, and the Model box keeps its "N/A" placeholder until the user
        # edits the URL.
        app.load(
            fn=check_api_connection,
            inputs=api_url_input,
            outputs=connection_outputs
        )

        # Order must match generate_caption's positional parameters.
        generate_btn.click(
            fn=generate_caption,
            inputs=[
                image_input,
                api_url_input,
                model_name_display,
                c_type,
                use_names,
                add_tags,
                add_char_list,
                add_chars_tags,
                add_chars_descr,
                tags_text,
                characters_text,
                *char_tag_fields,
                *char_descr_fields,
            ],
            outputs=output_text
        )

        char_tags_clear_btn.click(
            fn=lambda: ("",) * len(char_tag_fields),
            inputs=[],
            outputs=char_tag_fields
        )
        char_descr_clear_btn.click(
            fn=lambda: ("",) * len(char_descr_fields),
            inputs=[],
            outputs=char_descr_fields
        )

    return app
|
|
|
|
if __name__ == "__main__":
    # Bind to the loopback interface only, so the UI is reachable from this machine alone.
    create_ui().launch(server_name="127.0.0.1", server_port=7860)
|
|