"""Gradio front end that captions a single image through a chat-completions API.

The user uploads an image, optionally supplies tags, character names,
per-character tags and per-character descriptions, and the app sends the
image plus a prompt built by ``prompts.make_user_query`` to the configured
API endpoint.
"""

import base64
import io
import json
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

import gradio as gr
import requests
from PIL import Image

# Import prompt building functions from prompts.py
from prompts import make_user_query, system_prompt, prompts_b

# ==================== CONFIGURATION ====================
# API settings
API_URL = "http://127.0.0.1:8000/v1/chat/completions"
API_KEY = "not-needed"
DEFAULT_MODEL_NAME = "toriigate-0.5"

# Image settings
MAX_PIXELS = 1.0  # Maximum resolution in megapixels (e.g., 4.0 = 4MP)

# Request settings
MAX_TOKENS = 4096
TEMPERATURE = 0.5
REQUEST_TIMEOUT = 5  # Seconds; short because it is only used by the connection check
WORK_TIMEOUT = 300  # Seconds; used for the caption request itself

# Captioning type options (from prompts_b in prompts.py)
CAPTION_TYPES = list(prompts_b.keys())
DEFAULT_C_TYPE = CAPTION_TYPES[0] if CAPTION_TYPES else None
if not DEFAULT_C_TYPE:
    raise RuntimeError("No caption types available in prompts_b!")
# ==================== END CONFIGURATION ====================

# Values shown in the read-only "Model" box when no real model id is known.
_MODEL_NOT_AVAILABLE = "N/A"
_MODEL_UNKNOWN = "Unknown"
_PLACEHOLDER_MODEL_NAMES = frozenset({"", _MODEL_NOT_AVAILABLE, _MODEL_UNKNOWN})

# (name placeholder, tags placeholder) for each "Character N" tag slot.
# generate_caption() takes exactly five such slots, so keep five rows here.
_CHAR_TAG_PLACEHOLDERS = [
    ("e.g., albedo", "e.g., white_hair, green_eyes, horns"),
    ("e.g., hoshimi_miyabi", "e.g., blue_hair, fox_ears"),
    ("e.g., nishizono_mio", "e.g., brown_hair, glasses"),
    ("e.g.", "e.g."),
    ("e.g.", "e.g."),
]

# (name placeholder, description placeholder) for each description slot.
# generate_caption() takes exactly five such slots, so keep five rows here.
_CHAR_DESCR_PLACEHOLDERS = [
    ("e.g., albedo", "e.g., Albedo is a curvy woman with..."),
    ("e.g., hoshimi_miyabi", "e.g., Miyabi is a calm and collected..."),
    ("e.g., nishizono_mio", "e.g., Mio is a cheerful girl with..."),
    ("e.g.", "e.g."),
    ("e.g.", "e.g."),
]


def _models_url(api_url: str) -> str:
    """Derive the ``/v1/models`` URL from any URL under the same ``/v1`` root."""
    trimmed = api_url.strip().rstrip('/')
    base, separator, _ = trimmed.partition('/v1/')
    # A URL that ends in "/v1" has no "/v1/" to split on; strip it explicitly
    # so the result is not ".../v1/v1/models".
    if not separator and trimmed.endswith('/v1'):
        base = trimmed[:-len('/v1')]
    return f"{base}/v1/models"


def check_api_connection(api_url: str) -> Tuple[str, str]:
    """Check the API connection and report the first model it lists.

    Args:
        api_url: URL of the chat-completions endpoint (or any URL under the
            same ``/v1`` root).

    Returns:
        ``(status_message, model_name)``. ``model_name`` is ``"Unknown"``
        when the server answers without model data and ``"N/A"`` on any
        failure. This function never raises; failures are reported in
        ``status_message``.
    """
    if not api_url or not api_url.strip():
        return "❌ Error: API URL is empty", _MODEL_NOT_AVAILABLE

    try:
        response = requests.get(
            _models_url(api_url),
            # Same credentials as the caption request, so the status shown
            # matches what the caption request will encounter.
            headers={"Authorization": f"Bearer {API_KEY}"},
            timeout=REQUEST_TIMEOUT,
        )
        response.raise_for_status()
        result = response.json()

        models = result.get('data') if isinstance(result, dict) else None
        if models:
            return "✅ Connected", models[0].get('id', _MODEL_UNKNOWN)
        return "⚠️ Connected (no model info)", _MODEL_UNKNOWN
    except requests.exceptions.ConnectionError:
        return "❌ Connection failed", _MODEL_NOT_AVAILABLE
    except requests.exceptions.Timeout:
        return "❌ Timeout", _MODEL_NOT_AVAILABLE
    except Exception as e:
        # UI boundary: show a short message instead of crashing the handler.
        return f"❌ Error: {str(e)[:50]}", _MODEL_NOT_AVAILABLE


def encode_image_base64(image: Image.Image, max_pixels: float = MAX_PIXELS) -> str:
    """Encode an image as a base64 JPEG string, downscaling if it is too large.

    Args:
        image: Source image; converted to RGB when it is in another mode.
        max_pixels: Pixel budget in megapixels. Must be positive.

    Returns:
        Base64 text of the JPEG-encoded (possibly resized) image.

    Raises:
        ValueError: If ``max_pixels`` is not positive.
    """
    if max_pixels <= 0:
        raise ValueError(f"max_pixels must be positive, got {max_pixels!r}")

    img = image if image.mode == 'RGB' else image.convert('RGB')

    pixel_budget = max_pixels * 1_000_000
    current_pixels = img.width * img.height

    # Strictly greater: an image exactly at the budget needs no resampling.
    if current_pixels > pixel_budget:
        # Uniform scale factor, so the aspect ratio is preserved.
        scale = (pixel_budget / current_pixels) ** 0.5
        # Clamp to 1 px: truncation can yield 0 for extreme aspect ratios,
        # and a zero-sized target makes resize() fail.
        new_size = (
            max(1, int(img.width * scale)),
            max(1, int(img.height * scale)),
        )
        img = img.resize(new_size, Image.Resampling.LANCZOS)

    buffer = io.BytesIO()
    img.save(buffer, format='JPEG', quality=100)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def call_caption_api(
    messages: list,
    api_url: str = API_URL,
    model_name: str = DEFAULT_MODEL_NAME,
) -> Optional[str]:
    """Send a chat-completions request and return the generated text.

    Args:
        messages: Chat messages in the API's ``messages`` format.
        api_url: Chat-completions endpoint to POST to.
        model_name: Value sent as the ``model`` field.

    Returns:
        The ``content`` of the first choice, or a string starting with
        ``"API Error:"`` / ``"Parse Error:"`` when the request fails or the
        response does not have the expected shape. Never raises for those
        failures, so the result can be shown directly in the UI.
    """
    payload = {
        "model": model_name,
        "messages": messages,
        "max_tokens": MAX_TOKENS,
        "temperature": TEMPERATURE,
        "stream": False,
        "stop": ["<|im_end|>", "<|endoftext|>"],
    }
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}",
    }

    try:
        response = requests.post(
            api_url,
            headers=headers,
            json=payload,
            timeout=WORK_TIMEOUT,
        )
        response.raise_for_status()
        result = response.json()
        return result['choices'][0]['message']['content']
    except requests.exceptions.RequestException as e:
        return f"API Error: {e}"
    except (KeyError, IndexError, TypeError, ValueError) as e:
        # TypeError: a field has an unexpected type (e.g. null "choices").
        # ValueError: the body was not valid JSON and requests raised a
        # plain ValueError rather than a RequestException subclass.
        return f"Parse Error: {e}"


def empty_template() -> Dict[str, Any]:
    """Return a fresh item dict with every field empty."""
    return {
        "tags": [],
        "characters": [],
        "char_p_tags": {"chars": {}, "skins": {}},
        "char_descr": {"chars": {}, "skins": {}},
    }


def _split_csv(text: Optional[str]) -> List[str]:
    """Split comma-separated text into stripped, non-empty items."""
    if not text:
        return []
    return [part.strip() for part in text.split(',') if part.strip()]


def _clean(text: Optional[str]) -> str:
    """Return ``text`` stripped, treating ``None`` as an empty string."""
    return (text or "").strip()


def _unique_names(names: Iterable[Optional[str]]) -> List[str]:
    """Return stripped, non-empty names without duplicates, first-seen order."""
    cleaned = (_clean(name) for name in names)
    return list(dict.fromkeys(name for name in cleaned if name))


def generate_caption(
    image: Image.Image,
    api_url: str,
    model_name: str,
    c_type: str,
    use_names: bool,
    add_tags: bool,
    add_char_list: bool,
    add_chars_tags: bool,
    add_chars_descr: bool,
    tags_text: str,
    characters_text: str,
    char1_name: str,
    char1_tags: str,
    char2_name: str,
    char2_tags: str,
    char3_name: str,
    char3_tags: str,
    char4_name: str,
    char4_tags: str,
    char5_name: str,
    char5_tags: str,
    char_descr1_name: str,
    char_descr1_text: str,
    char_descr2_name: str,
    char_descr2_text: str,
    char_descr3_name: str,
    char_descr3_text: str,
    char_descr4_name: str,
    char_descr4_text: str,
    char_descr5_name: str,
    char_descr5_text: str
) -> str:
    """Generate a caption for a single image.

    Args:
        image: Uploaded image, or ``None`` when nothing was uploaded.
        api_url: Chat-completions endpoint to call.
        model_name: Model id to request. Blank, ``"N/A"`` and ``"Unknown"``
            (the placeholders shown when no model id is known) fall back to
            ``DEFAULT_MODEL_NAME``.
        c_type: Caption type, forwarded to ``make_user_query``.
        use_names: Forwarded to ``make_user_query``.
        add_tags: Include ``tags_text`` in the item.
        add_char_list: Include ``characters_text`` in the item.
        add_chars_tags: Include the ``charN_name`` / ``charN_tags`` slots.
        add_chars_descr: Include the ``char_descrN_name`` /
            ``char_descrN_text`` slots.
        tags_text: Comma-separated tags.
        characters_text: Comma-separated character names.
        char1_name ... char5_tags: Character name plus comma-separated tags.
        char_descr1_name ... char_descr5_text: Character name plus free-text
            description.

    Returns:
        The caption text, an ``"API Error: ..."`` / ``"Parse Error: ..."``
        string from ``call_caption_api``, or a prompt to upload an image.
    """
    if image is None:
        return "Please upload an image first."

    tag_slots = [
        (char1_name, char1_tags),
        (char2_name, char2_tags),
        (char3_name, char3_tags),
        (char4_name, char4_tags),
        (char5_name, char5_tags),
    ]
    descr_slots = [
        (char_descr1_name, char_descr1_text),
        (char_descr2_name, char_descr2_text),
        (char_descr3_name, char_descr3_text),
        (char_descr4_name, char_descr4_text),
        (char_descr5_name, char_descr5_text),
    ]

    # Build item dict from inputs
    item = empty_template()

    if add_tags:
        item["tags"] = _split_csv(tags_text)

    if add_char_list:
        item["characters"] = _split_csv(characters_text)

    # Auto-populate the characters list from the tag/description slots, but
    # only when the user did not provide an explicit list.
    slot_names: List[Optional[str]] = []
    if add_chars_tags:
        slot_names.extend(name for name, _ in tag_slots)
    if add_chars_descr:
        slot_names.extend(name for name, _ in descr_slots)
    auto_chars = _unique_names(slot_names)
    if auto_chars and not item["characters"]:
        item["characters"] = auto_chars
        # The list is now populated, so ask the prompt builder to use it.
        add_char_list = True

    # Character tags: every named slot is kept, even with an empty tag list.
    if add_chars_tags:
        chars_dict = {}
        for name, tags_str in tag_slots:
            name = _clean(name)
            if name:
                chars_dict[name] = _split_csv(tags_str)
        if chars_dict:
            item["char_p_tags"] = {"chars": chars_dict, "skins": {}}

    # Character descriptions: a slot needs both a name and a description.
    if add_chars_descr:
        descr_dict = {}
        for name, descr in descr_slots:
            name, descr = _clean(name), _clean(descr)
            if name and descr:
                descr_dict[name] = descr
        if descr_dict:
            item["char_descr"] = {"chars": descr_dict, "skins": {}}

    # Encode image
    image_data = encode_image_base64(image)

    # Prepare messages
    user_query = make_user_query(
        item,
        c_type=c_type,
        use_names=use_names,
        add_tags=add_tags,
        add_characters=add_char_list,
        add_char_tags=add_chars_tags,
        add_description=add_chars_descr,
        underscores_replace=False,
    )
    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_prompt}],
        },
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_data}"}},
                {"type": "text", "text": user_query},
            ],
        },
    ]

    # The "Model" box holds a placeholder until a connection check succeeds;
    # never send that placeholder to the server as a model id.
    model = _clean(model_name)
    if model in _PLACEHOLDER_MODEL_NAMES:
        model = DEFAULT_MODEL_NAME

    # Call API
    return call_caption_api(messages, api_url, model)


def _set_interactive(is_checked: bool):
    """Return an update that makes the target editable only while checked."""
    return gr.update(interactive=is_checked)


def _set_visible(is_checked: bool):
    """Return an update that shows the target only while checked."""
    return gr.update(visible=is_checked)


def _build_character_slots(
    placeholders: Sequence[Tuple[str, str]],
    value_label: str,
    **value_kwargs: Any
) -> list:
    """Create one accordion per character and return its textboxes flat.

    Must be called inside an active Gradio layout context. The returned list
    is ordered name, value, name, value, ... to match the order of the
    corresponding ``generate_caption`` parameters.
    """
    fields = []
    for index, (name_hint, value_hint) in enumerate(placeholders, start=1):
        # Only the first slot starts expanded.
        with gr.Accordion(f"Character {index}", open=(index == 1)):
            fields.append(gr.Textbox(
                label="Name",
                placeholder=name_hint,
                interactive=True,
            ))
            fields.append(gr.Textbox(
                label=value_label,
                placeholder=value_hint,
                interactive=True,
                **value_kwargs,
            ))
    return fields


def create_ui():
    """Create and return the Gradio interface."""
    with gr.Blocks(title="ToriiGate Captioner", theme=gr.themes.Soft()) as app:
        gr.Markdown("# 🖼️ ToriiGate Captioner")

        # API URL row with status
        with gr.Row():
            api_url_input = gr.Textbox(
                label="API URL",
                value=API_URL,
                interactive=True,
                scale=4,
            )
            api_status = gr.Textbox(
                label="Status",
                value="⏳ Waiting for input...",
                interactive=False,
                scale=1,
            )
            model_name_display = gr.Textbox(
                label="Model",
                value="N/A",
                interactive=False,
                scale=1,
            )

        with gr.Row():
            # Left column - Image input
            with gr.Column(scale=1):
                image_input = gr.Image(
                    label="Upload Image",
                    type="pil",
                    height=400,
                )

                gr.Markdown("### Configuration")

                # Caption type selector
                c_type = gr.Dropdown(
                    choices=CAPTION_TYPES,
                    value=DEFAULT_C_TYPE,
                    label="Caption Type",
                    interactive=True,
                )

                # Boolean options with conditional text inputs
                with gr.Group():
                    use_names = gr.Checkbox(
                        value=True,
                        label="Use Names (enable character names)",
                    )

                    add_tags = gr.Checkbox(
                        value=False,
                        label="Add Tags",
                    )
                    tags_text = gr.Textbox(
                        label="Tags (comma-separated)",
                        placeholder="e.g., 1girl, blue_hair, school_uniform",
                        interactive=False,
                    )

                    add_char_list = gr.Checkbox(
                        value=False,
                        label="Add Character List",
                    )
                    characters_text = gr.Textbox(
                        label="Character Names (comma-separated)",
                        placeholder="e.g., nishizono_mio, hoshimi_miyabi",
                        interactive=False,
                    )

                    add_chars_tags = gr.Checkbox(
                        value=False,
                        label="Add Character Tags",
                    )
                    with gr.Group(visible=False) as char_tags_group:
                        gr.Markdown("**Add character names and their tags**")
                        char_tag_fields = _build_character_slots(
                            _CHAR_TAG_PLACEHOLDERS,
                            "Tags (comma-separated)",
                        )
                        char_tags_clear_btn = gr.Button(
                            "🗑️ Clear All",
                            variant="secondary",
                            size="sm",
                        )

                    add_chars_descr = gr.Checkbox(
                        value=False,
                        label="Add Character Descriptions",
                    )
                    with gr.Group(visible=False) as char_descr_group:
                        gr.Markdown("**Add character descriptions**")
                        char_descr_fields = _build_character_slots(
                            _CHAR_DESCR_PLACEHOLDERS,
                            "Description",
                            lines=3,
                        )
                        char_descr_clear_btn = gr.Button(
                            "🗑️ Clear All",
                            variant="secondary",
                            size="sm",
                        )

                generate_btn = gr.Button("🚀 Generate Caption", variant="primary", size="lg")

            # Right column - Output
            with gr.Column(scale=1):
                output_text = gr.Textbox(
                    label="Caption Output",
                    lines=20,
                    max_lines=50,
                    interactive=False,
                )

        # Text inputs are editable only while their checkbox is ticked.
        add_tags.change(
            fn=_set_interactive,
            inputs=add_tags,
            outputs=tags_text,
        )
        add_char_list.change(
            fn=_set_interactive,
            inputs=add_char_list,
            outputs=characters_text,
        )

        # Character groups are shown only while their checkbox is ticked.
        add_chars_tags.change(
            fn=_set_visible,
            inputs=add_chars_tags,
            outputs=char_tags_group,
        )
        add_chars_descr.change(
            fn=_set_visible,
            inputs=add_chars_descr,
            outputs=char_descr_group,
        )

        # Refresh status/model on page load and whenever the URL changes, so
        # the "Model" box does not keep its placeholder for the default URL.
        connection_outputs = [api_status, model_name_display]
        app.load(
            fn=check_api_connection,
            inputs=api_url_input,
            outputs=connection_outputs,
        )
        api_url_input.change(
            fn=check_api_connection,
            inputs=api_url_input,
            outputs=connection_outputs,
        )

        # Wire up generate button; the order must match generate_caption().
        generate_btn.click(
            fn=generate_caption,
            inputs=[
                image_input,
                api_url_input,
                model_name_display,
                c_type,
                use_names,
                add_tags,
                add_char_list,
                add_chars_tags,
                add_chars_descr,
                tags_text,
                characters_text,
                *char_tag_fields,
                *char_descr_fields,
            ],
            outputs=output_text,
        )

        # Clear buttons blank every textbox of their group.
        char_tags_clear_btn.click(
            fn=lambda: ("",) * len(char_tag_fields),
            inputs=[],
            outputs=char_tag_fields,
        )
        char_descr_clear_btn.click(
            fn=lambda: ("",) * len(char_descr_fields),
            inputs=[],
            outputs=char_descr_fields,
        )

    return app


if __name__ == "__main__":
    app = create_ui()
    app.launch(server_name="127.0.0.1", server_port=7860)