# ToriiGate-0.5 / scripts/gradio_interface.py
# Last commit by Minthy (661b2bb): Add stop tokens to prevent repeated <|im_end|> in generation output (#2)
import json
import base64
import io
import requests
from pathlib import Path
from typing import Dict, Any, Optional, Tuple
import gradio as gr
from PIL import Image
# Import prompt building functions from prompts.py
from prompts import make_user_query, system_prompt, prompts_b
# ==================== CONFIGURATION ====================
# API settings
API_URL = "http://127.0.0.1:8000/v1/chat/completions"
API_KEY = "not-needed"  # sent as a Bearer token; the value suggests the server ignores auth
# Image settings
MAX_PIXELS = 1.0  # Maximum resolution in megapixels (e.g., 4.0 = 4MP)
# Request settings
MAX_TOKENS = 4096
TEMPERATURE = 0.5
REQUEST_TIMEOUT = 5  # seconds; kept short because it is only used for the connection check
WORK_TIMEOUT = 300  # seconds; timeout for the caption generation request
# Captioning type options (from prompts_b in prompts.py)
CAPTION_TYPES = list(prompts_b.keys())
# Fail fast at import time: the caption-type dropdown needs at least one entry.
# Test the list itself rather than the truthiness of its first key, so a valid
# but falsy key (e.g. "") is not mistaken for "no caption types".
if not CAPTION_TYPES:
    raise RuntimeError("No caption types available in prompts_b!")
DEFAULT_C_TYPE = CAPTION_TYPES[0]
# ==================== END CONFIGURATION ====================
def check_api_connection(api_url: str) -> Tuple[str, str]:
    """
    Probe the server behind ``api_url`` and report its first listed model.

    The server root is everything in ``api_url`` before the ``/v1`` path
    segment; ``GET {root}/v1/models`` is then requested.

    Args:
        api_url: URL as typed in the UI, e.g.
            ``http://127.0.0.1:8000/v1/chat/completions``. A bare server root
            (``http://host:port``) or a URL ending in ``/v1`` also works.

    Returns:
        ``(status_message, model_name)``. ``model_name`` is the ``id`` of the
        first entry of the response's ``data`` list, ``"Unknown"`` when no
        model info is available, or ``"N/A"`` when the check failed.
        Errors are reported in the status message, never raised.
    """
    try:
        # Append "/" before splitting so ".../v1" (no trailing slash) is cut
        # at the same place as ".../v1/chat/completions"; splitting the bare
        # URL on "/v1/" would miss it and probe ".../v1/v1/models".
        base_url = (api_url.strip().rstrip('/') + '/').split('/v1/')[0].rstrip('/')
        if not base_url:
            return "❌ Error: empty API URL", "N/A"
        models_url = f"{base_url}/v1/models"
        response = requests.get(models_url, timeout=REQUEST_TIMEOUT)
        response.raise_for_status()
        result = response.json()
        models = result.get('data') if isinstance(result, dict) else None
        if isinstance(models, list) and models:
            first = models[0]
            model_name = first.get('id', 'Unknown') if isinstance(first, dict) else 'Unknown'
            return "✅ Connected", model_name
        # The server answered, but without a usable model list.
        return "⚠️ Connected (no model info)", "Unknown"
    except requests.exceptions.ConnectionError:
        return "❌ Connection failed", "N/A"
    except requests.exceptions.Timeout:
        return "❌ Timeout", "N/A"
    except Exception as e:  # UI boundary: surface any other failure in the status box
        return f"❌ Error: {str(e)[:50]}", "N/A"
def encode_image_base64(image: Image.Image, max_pixels: float = MAX_PIXELS) -> str:
    """
    Encode ``image`` as a base64 JPEG string, downscaling it if it is too large.

    Args:
        image: Source PIL image; non-RGB modes are converted with
            ``convert('RGB')``. The caller's object is not modified.
        max_pixels: Pixel budget in megapixels (1.0 == 1,000,000 pixels).
            Larger images are resized with LANCZOS, preserving aspect ratio.

    Returns:
        Base64 text of the JPEG bytes (quality 100), suitable for a
        ``data:image/jpeg;base64,...`` URL.
    """
    img = image if image.mode == 'RGB' else image.convert('RGB')
    current_pixels = img.width * img.height
    max_pixels_count = max_pixels * 1_000_000
    # Only shrink when strictly over budget; an image exactly at the limit fits.
    if current_pixels > max_pixels_count:
        scale = (max_pixels_count / current_pixels) ** 0.5
        # int() truncates so the result stays within budget; max(1, ...) keeps
        # a very thin image from collapsing to a 0-pixel side, which
        # Image.resize rejects.
        new_width = max(1, int(img.width * scale))
        new_height = max(1, int(img.height * scale))
        img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
    buffer = io.BytesIO()
    img.save(buffer, format='JPEG', quality=100)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
def call_caption_api(messages: list, api_url: str = API_URL, model_name: str = "toriigate-0.5") -> Optional[str]:
    """
    Send a non-streaming chat-completions request and return the reply text.

    Failures are returned as text (prefixed ``"API Error:"`` or
    ``"Parse Error:"``) instead of raised, so the UI can show them directly.

    Args:
        messages: Chat messages list sent verbatim as the ``messages`` field.
        api_url: Full chat-completions endpoint URL.
        model_name: Value sent as the request's ``model`` field.

    Returns:
        ``choices[0].message.content`` from the response (``None`` if the
        server sent a null content), or an error description string.
    """
    payload = {
        "model": model_name,
        "messages": messages,
        "max_tokens": MAX_TOKENS,
        "temperature": TEMPERATURE,
        "stream": False,
        # Explicit stop strings prevent repeated <|im_end|> markers from
        # leaking into the generated caption.
        "stop": ["<|im_end|>", "<|endoftext|>"],
    }
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}",
    }
    try:
        response = requests.post(
            api_url,
            headers=headers,
            json=payload,
            timeout=WORK_TIMEOUT,
        )
        response.raise_for_status()
        result = response.json()
    except requests.exceptions.HTTPError as e:
        # raise_for_status() only reports the status line; the server's body
        # usually explains the rejection (e.g. an unknown model), so show a
        # bounded slice of it too.
        detail = e.response.text.strip()[:500] if e.response is not None else ""
        return f"API Error: {e}\n{detail}" if detail else f"API Error: {e}"
    except requests.exceptions.RequestException as e:
        return f"API Error: {e}"
    except ValueError as e:
        # Non-JSON body; depending on the requests version this is a plain
        # ValueError rather than a RequestException subclass.
        return f"Parse Error: {e}"
    try:
        return result['choices'][0]['message']['content']
    except (KeyError, IndexError, TypeError) as e:
        # TypeError covers unexpected shapes, e.g. "choices" being null.
        return f"Parse Error: {e}"
def empty_template() -> Dict[str, Any]:
    """
    Build a fresh, empty item dict in the shape ``make_user_query`` is given.

    Every call returns brand-new containers, so callers may mutate the result
    without affecting later calls.
    """
    def blank_per_character() -> Dict[str, dict]:
        # Per-character sections share this two-bucket layout.
        return {"chars": {}, "skins": {}}

    return dict(
        tags=[],
        characters=[],
        char_p_tags=blank_per_character(),
        char_descr=blank_per_character(),
    )
def _split_csv(text: Optional[str]) -> list:
    """Split comma-separated ``text`` into stripped, non-empty items (``None``/"" -> [])."""
    if not text:
        return []
    return [part.strip() for part in text.split(',') if part.strip()]


def _clean_names(names) -> list:
    """Return the stripped, non-empty entries of ``names`` (``None`` skipped), each name once, in first-seen order."""
    stripped = [name.strip() for name in names if name]
    return list(dict.fromkeys(name for name in stripped if name))


def _collect_char_tags(pairs) -> Dict[str, list]:
    """
    Build ``{name: [tags]}`` from ``(name, comma-separated tags)`` pairs.

    Entries without a name are skipped; a named entry without tags maps to
    ``[]``. A name entered twice has its tag lists concatenated rather than
    the later entry silently replacing the earlier one.
    """
    chars: Dict[str, list] = {}
    for raw_name, tags_str in pairs:
        name = (raw_name or "").strip()
        if name:
            chars.setdefault(name, []).extend(_split_csv(tags_str))
    return chars


def _collect_char_descr(pairs) -> Dict[str, str]:
    """
    Build ``{name: description}`` from ``(name, description)`` pairs.

    Entries missing either a name or a description are skipped. A name
    entered twice has its descriptions joined with a space rather than the
    later one silently replacing the earlier one.
    """
    descr: Dict[str, str] = {}
    for raw_name, raw_text in pairs:
        name = (raw_name or "").strip()
        text = (raw_text or "").strip()
        if name and text:
            descr[name] = f"{descr[name]} {text}" if name in descr else text
    return descr


def generate_caption(
    image: Image.Image,
    api_url: str,
    model_name: str,
    c_type: str,
    use_names: bool,
    add_tags: bool,
    add_char_list: bool,
    add_chars_tags: bool,
    add_chars_descr: bool,
    tags_text: str,
    characters_text: str,
    char1_name: str,
    char1_tags: str,
    char2_name: str,
    char2_tags: str,
    char3_name: str,
    char3_tags: str,
    char4_name: str,
    char4_tags: str,
    char5_name: str,
    char5_tags: str,
    char_descr1_name: str,
    char_descr1_text: str,
    char_descr2_name: str,
    char_descr2_text: str,
    char_descr3_name: str,
    char_descr3_text: str,
    char_descr4_name: str,
    char_descr4_text: str,
    char_descr5_name: str,
    char_descr5_text: str
) -> str:
    """
    Generate a caption for one image from the UI's current inputs.

    The enabled options are turned into an item dict (see ``empty_template``)
    and passed to ``make_user_query``. The resulting text is sent together
    with the base64 JPEG of ``image`` and the system prompt to the
    chat-completions API.

    Returns:
        The caption text, an error string from ``call_caption_api``, or
        ``"Please upload an image first."`` when ``image`` is ``None``.
    """
    if image is None:
        return "Please upload an image first."

    char_tag_pairs = [
        (char1_name, char1_tags),
        (char2_name, char2_tags),
        (char3_name, char3_tags),
        (char4_name, char4_tags),
        (char5_name, char5_tags),
    ]
    char_descr_pairs = [
        (char_descr1_name, char_descr1_text),
        (char_descr2_name, char_descr2_text),
        (char_descr3_name, char_descr3_text),
        (char_descr4_name, char_descr4_text),
        (char_descr5_name, char_descr5_text),
    ]

    item = empty_template()
    if add_tags:
        item["tags"] = _split_csv(tags_text)
    if add_char_list:
        item["characters"] = _split_csv(characters_text)

    # Derive the character list from the names entered for per-character
    # tags (first) and descriptions (second). item["characters"] is non-empty
    # only when the user enabled and filled the list, so a manual list wins.
    candidate_names = []
    if add_chars_tags:
        candidate_names.extend(name for name, _ in char_tag_pairs)
    if add_chars_descr:
        candidate_names.extend(name for name, _ in char_descr_pairs)
    auto_chars = _clean_names(candidate_names)
    if auto_chars and not item["characters"]:
        item["characters"] = auto_chars
        add_char_list = True

    if add_chars_tags:
        chars_dict = _collect_char_tags(char_tag_pairs)
        if chars_dict:
            item["char_p_tags"] = {"chars": chars_dict, "skins": {}}
    if add_chars_descr:
        descr_dict = _collect_char_descr(char_descr_pairs)
        if descr_dict:
            item["char_descr"] = {"chars": descr_dict, "skins": {}}

    image_data = encode_image_base64(image)
    user_query = make_user_query(
        item,
        c_type=c_type,
        use_names=use_names,
        add_tags=add_tags,
        add_characters=add_char_list,
        add_char_tags=add_chars_tags,
        add_description=add_chars_descr,
        underscores_replace=False,
    )
    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_prompt}],
        },
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_data}"}},
                {"type": "text", "text": user_query},
            ],
        },
    ]

    # "N/A" and "Unknown" are the placeholders the Model box shows when no
    # model id was discovered (its initial value, a failed connection check,
    # or a server without model info). Servers that check the model name
    # would likely reject them, so use call_caption_api's default instead.
    model = (model_name or "").strip()
    if model and model not in ("N/A", "Unknown"):
        return call_caption_api(messages, api_url, model)
    return call_caption_api(messages, api_url)
def create_ui():
    """
    Build the ToriiGate captioning Gradio app.

    Layout:
    - an API URL row with read-only connection status and model name;
    - an input column with the image, caption type and prompt options, plus
      up to five characters' tags and five characters' descriptions;
    - an output column showing the caption.

    The server is probed when the page loads and again whenever the API URL
    changes.

    Returns:
        The ``gr.Blocks`` app. It is built here but not launched.
    """
    # (name placeholder, value placeholder) for each of the five character slots.
    tag_examples = [
        ("e.g., albedo", "e.g., white_hair, green_eyes, horns"),
        ("e.g., hoshimi_miyabi", "e.g., blue_hair, fox_ears"),
        ("e.g., nishizono_mio", "e.g., brown_hair, glasses"),
        ("e.g.", "e.g."),
        ("e.g.", "e.g."),
    ]
    descr_examples = [
        ("e.g., albedo", "e.g., Albedo is a curvy woman with..."),
        ("e.g., hoshimi_miyabi", "e.g., Miyabi is a calm and collected..."),
        ("e.g., nishizono_mio", "e.g., Mio is a cheerful girl with..."),
        ("e.g.", "e.g."),
        ("e.g.", "e.g."),
    ]

    def character_slots(examples, value_label, value_lines=1):
        """Create one "Character N" accordion per example pair; only the first starts open.

        Returns a list of ``(name_textbox, value_textbox)`` tuples in slot order.
        """
        slots = []
        for index, (name_example, value_example) in enumerate(examples, start=1):
            with gr.Accordion(f"Character {index}", open=index == 1):
                name_box = gr.Textbox(
                    label="Name",
                    placeholder=name_example,
                    interactive=True
                )
                value_box = gr.Textbox(
                    label=value_label,
                    placeholder=value_example,
                    lines=value_lines,
                    interactive=True
                )
            slots.append((name_box, value_box))
        return slots

    with gr.Blocks(title="ToriiGate Captioner", theme=gr.themes.Soft()) as app:
        gr.Markdown("# 🖼️ ToriiGate Captioner")
        # API URL row with status
        with gr.Row():
            api_url_input = gr.Textbox(
                label="API URL",
                value=API_URL,
                interactive=True,
                scale=4
            )
            api_status = gr.Textbox(
                label="Status",
                value="⏳ Waiting for input...",
                interactive=False,
                scale=1
            )
            model_name_display = gr.Textbox(
                label="Model",
                value="N/A",
                interactive=False,
                scale=1
            )
        with gr.Row():
            # Left column - Image input and configuration
            with gr.Column(scale=1):
                image_input = gr.Image(
                    label="Upload Image",
                    type="pil",
                    height=400
                )
                gr.Markdown("### Configuration")
                c_type = gr.Dropdown(
                    choices=CAPTION_TYPES,
                    value=DEFAULT_C_TYPE,
                    label="Caption Type",
                    interactive=True
                )
                # Boolean options with conditional text inputs
                with gr.Group():
                    use_names = gr.Checkbox(
                        value=True,
                        label="Use Names (enable character names)"
                    )
                    add_tags = gr.Checkbox(
                        value=False,
                        label="Add Tags"
                    )
                    tags_text = gr.Textbox(
                        label="Tags (comma-separated)",
                        placeholder="e.g., 1girl, blue_hair, school_uniform",
                        interactive=False
                    )
                    add_char_list = gr.Checkbox(
                        value=False,
                        label="Add Character List"
                    )
                    characters_text = gr.Textbox(
                        label="Character Names (comma-separated)",
                        placeholder="e.g., nishizono_mio, hoshimi_miyabi",
                        interactive=False
                    )
                    add_chars_tags = gr.Checkbox(
                        value=False,
                        label="Add Character Tags"
                    )
                    with gr.Group(visible=False) as char_tags_group:
                        gr.Markdown("**Add character names and their tags**")
                        char_tag_slots = character_slots(tag_examples, "Tags (comma-separated)")
                        char_tags_clear_btn = gr.Button(
                            "🗑️ Clear All",
                            variant="secondary",
                            size="sm"
                        )
                    add_chars_descr = gr.Checkbox(
                        value=False,
                        label="Add Character Descriptions"
                    )
                    with gr.Group(visible=False) as char_descr_group:
                        gr.Markdown("**Add character descriptions**")
                        char_descr_slots = character_slots(descr_examples, "Description", value_lines=3)
                        char_descr_clear_btn = gr.Button(
                            "🗑️ Clear All",
                            variant="secondary",
                            size="sm"
                        )
                generate_btn = gr.Button("🚀 Generate Caption", variant="primary", size="lg")
            # Right column - Output
            with gr.Column(scale=1):
                output_text = gr.Textbox(
                    label="Caption Output",
                    lines=20,
                    max_lines=50,
                    interactive=False
                )

        # Checkboxes enable their free-text inputs / reveal their groups.
        add_tags.change(
            fn=lambda checked: gr.update(interactive=checked),
            inputs=add_tags,
            outputs=tags_text
        )
        add_char_list.change(
            fn=lambda checked: gr.update(interactive=checked),
            inputs=add_char_list,
            outputs=characters_text
        )
        add_chars_tags.change(
            fn=lambda checked: gr.update(visible=checked),
            inputs=add_chars_tags,
            outputs=char_tags_group
        )
        add_chars_descr.change(
            fn=lambda checked: gr.update(visible=checked),
            inputs=add_chars_descr,
            outputs=char_descr_group
        )

        # Probe the server on page load too; otherwise the Model box keeps its
        # "N/A" placeholder until the user edits the URL.
        connection_outputs = [api_status, model_name_display]
        app.load(
            fn=check_api_connection,
            inputs=api_url_input,
            outputs=connection_outputs
        )
        api_url_input.change(
            fn=check_api_connection,
            inputs=api_url_input,
            outputs=connection_outputs
        )

        # Flatten to name1, value1, name2, value2, ... to match the
        # generate_caption parameter order.
        char_tag_fields = [box for slot in char_tag_slots for box in slot]
        char_descr_fields = [box for slot in char_descr_slots for box in slot]

        generate_btn.click(
            fn=generate_caption,
            inputs=[
                image_input,
                api_url_input,
                model_name_display,
                c_type,
                use_names,
                add_tags,
                add_char_list,
                add_chars_tags,
                add_chars_descr,
                tags_text,
                characters_text,
                *char_tag_fields,
                *char_descr_fields,
            ],
            outputs=output_text
        )

        # "Clear All" buttons blank every textbox of their group.
        char_tags_clear_btn.click(
            fn=lambda: [""] * len(char_tag_fields),
            inputs=[],
            outputs=char_tag_fields
        )
        char_descr_clear_btn.click(
            fn=lambda: [""] * len(char_descr_fields),
            inputs=[],
            outputs=char_descr_fields
        )
    return app
if __name__ == "__main__":
    # Script entry point: build the UI and serve it on 127.0.0.1:7860.
    demo = create_ui()
    demo.launch(server_port=7860, server_name="127.0.0.1")