# r3gm's picture
# Update app.py
# 1241456 verified
import json
import csv
import io
import base64
import gzip
import os
import numpy as np
import pandas as pd
from PIL import Image
from io import BytesIO
import onnxruntime as rt
import huggingface_hub
import gradio as gr
from huggingface_hub import hf_hub_download
# Dataset repository and file names for the character name/MD5 table and thumbnails.
# NOTE: "CHARATER" is a misspelling of "CHARACTER"; the name is kept as-is.
CHARATER_REPO = "flagrantia/character_select_stand_alone_app"
CHARACTER_MD5 = "wai_character_md5_v160.csv"
CHARACTER_THUMBS = "wai_character_thumbs_v160.json"

# Fetch every required dataset file into the working directory at import time.
# (An alternative classifier file, classified_tags_danbooru_alpha.csv, is not fetched.)
for repo_id, filename in (
    ("r3gm/classified_tags", "classified_tags_danbooru_beta.csv"),
    (CHARATER_REPO, CHARACTER_MD5),
    (CHARATER_REPO, CHARACTER_THUMBS),
):
    hf_hub_download(repo_id=repo_id, repo_type="dataset", filename=filename, local_dir=".")
# --- Global variables ---
# Lower-cased character name -> {'original_name': ..., 'md5': ...}; filled by
# load_character_data_for_app().
_character_md5_map = {}
# Parsed thumbnails JSON; the search code looks entries up by MD5 hash.
_character_thumbs_data = {}
_danbooru_tag_classifier_df = None # Global for the classifier dataframe
_preprocessed_allowed_tags_set = set() # Stores tags allowed by current category filter
_last_selected_tag_categories = None # To track if categories changed
# --- WD14 Tagger Globals ---
# Shared Predictor instance; created at startup and reused by prompt generation.
_wd14_predictor_instance = None
# Available models: UI display name -> Hugging Face model repository id.
WD14_MODELS = {
    "WD-SwinV2-V3 (lite)": "SmilingWolf/wd-swinv2-tagger-v3",
    "WD-ViT-L-V3": "SmilingWolf/wd-vit-large-tagger-v3",
}
_wd14_selected_model_repo = WD14_MODELS["WD-SwinV2-V3 (lite)"] # Default selected model
# Default thresholds and MCut enabled status (can be changed by Gradio inputs)
_wd14_general_threshold = 0.35
_wd14_character_threshold = 0.85
_wd14_general_mcut_enabled = False
_wd14_character_mcut_enabled = False
# --- Tagger model specific file names ---
# Files downloaded from the selected model repository by Predictor.download_model_files().
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"
# Tag names that keep their underscores when other tags have "_" replaced by " ".
kaomojis = [
    "0_0", "(o)_(o)", "+_+", "+_-", "._.", "<o>_<o>", "<|>_<|>", "=_=", ">_<",
    "3_3", "6_9", ">_o", "@_@", "^_^", "o_o", "u_u", "x_x", "|_|", "||_||",
]
# --- Configuration ---
# Initial value of the "Max Search Results to Display" slider.
DEFAULT_MAX_DISPLAY_RESULTS = 6
# Tag classification table: read with the 'name' column as index.
CLASSIFIED_TAGS_CSV = "classified_tags_danbooru_beta.csv"
# --- Utility Functions ---
def _resolve_data_path(path):
    """Return ``path`` if it exists, else the same relative path under the script's parent directory when that exists."""
    if os.path.exists(path):
        return path
    parent_candidate = os.path.join(os.path.dirname(__file__), '..', path)
    return parent_candidate if os.path.exists(parent_candidate) else path


def load_character_data_for_app(md5_csv_path=CHARACTER_MD5, thumbs_json_path=CHARACTER_THUMBS):
    """Populate the module-level character and tag-classifier caches.

    Each of the three data sets is loaded only while its cache is still empty,
    so repeated calls (for example one per search keystroke) are cheap once
    loading has succeeded, and a failed load is retried on the next call.

    Args:
        md5_csv_path: CSV whose rows are ``name, md5``.
        thumbs_json_path: JSON file holding the thumbnail data.

    Side effects:
        Rebinds/fills ``_character_md5_map``, ``_character_thumbs_data`` and
        ``_danbooru_tag_classifier_df``. Errors are printed, never raised; on
        failure the affected cache is left empty.
    """
    global _character_md5_map, _character_thumbs_data, _danbooru_tag_classifier_df

    if not _character_md5_map:  # Load only if not already loaded
        try:
            md5_csv_path = _resolve_data_path(md5_csv_path)
            # newline='' is what the csv module expects for files it parses.
            with open(md5_csv_path, 'r', encoding='utf-8', newline='') as csvfile:
                for row in csv.reader(csvfile):
                    if len(row) >= 2:
                        original_name = row[0].strip()
                        md5_hash = row[1].strip()
                        # Keyed by lower-cased name for case-insensitive search.
                        _character_md5_map[original_name.lower()] = {
                            'original_name': original_name,
                            'md5': md5_hash,
                        }
        except FileNotFoundError:
            print(f"Error: {md5_csv_path} not found. Please ensure it's in the same directory as the script or a sibling directory.")
            _character_md5_map = {}  # Clear to prevent partial data issues
        except Exception as e:
            print(f"Error loading MD5 data: {e}")
            _character_md5_map = {}

    if not _character_thumbs_data:  # Load only if not already loaded
        try:
            thumbs_json_path = _resolve_data_path(thumbs_json_path)
            with open(thumbs_json_path, 'r', encoding='utf-8') as jsonfile:
                _character_thumbs_data = json.load(jsonfile)
            print(f"Loaded {len(_character_thumbs_data)} thumbnail entries from {thumbs_json_path}.")
        except FileNotFoundError:
            print(f"Error: {thumbs_json_path} not found. Please ensure it's in the same directory as the script or a sibling directory.")
            _character_thumbs_data = {}  # Clear to prevent partial data issues
        except Exception as e:
            print(f"Error loading thumbnail data: {e}")
            _character_thumbs_data = {}

    # Load the tag classifier only while it is missing or empty, instead of
    # re-reading the CSV on every call.
    if _danbooru_tag_classifier_df is None or _danbooru_tag_classifier_df.empty:
        try:
            classifier_csv_path = _resolve_data_path(CLASSIFIED_TAGS_CSV)
            classifier_df = pd.read_csv(classifier_csv_path, index_col='name')
            # 'tag_id' is not a category column, so drop it when present.
            if 'tag_id' in classifier_df.columns:
                classifier_df = classifier_df.drop(columns=['tag_id'])
            _danbooru_tag_classifier_df = classifier_df
        except Exception as e:
            print(f"CRITICAL ERROR: Failed to load {CLASSIFIED_TAGS_CSV}. Tag filtering by classification will be unavailable and likely cause errors. ({e})")
            # An empty DataFrame lets callers test `.empty` without crashing.
            _danbooru_tag_classifier_df = pd.DataFrame()
def base64_to_pil_image(base64_data):
    """Decode base64 text wrapping gzip-compressed image bytes into an RGBA PIL image.

    Returns the decoded image, or None (after printing the error) when any
    stage -- base64 decoding, gzip decompression or image parsing -- fails.
    """
    try:
        # base64 -> gzip stream -> raw image bytes
        raw_image_bytes = gzip.decompress(base64.b64decode(base64_data))
        return Image.open(BytesIO(raw_image_bytes)).convert("RGBA")
    except Exception as e:
        print(f"Error decoding base64 image: {e}")
        return None
def escape_parentheses_for_prompt(text: str) -> str:
    """Return ``text`` with every '(' and ')' preceded by a backslash, for use in stable diffusion prompts."""
    paren_escapes = str.maketrans({'(': r'\(', ')': r'\)'})
    return text.translate(paren_escapes)
def search_character_data_partial_for_app(character_name_partial, max_results):
    """Find characters whose name contains the given text (case-insensitive).

    Args:
        character_name_partial: Substring to look for. ``None``, an empty
            string, or whitespace only yields no results; surrounding
            whitespace is ignored.
        max_results: Stop once this many characters have been collected.

    Returns:
        A list of dicts with keys ``'name'`` (original spelling), ``'image'``
        (decoded PIL image) and ``'md5'``. Characters without a thumbnail, or
        whose thumbnail fails to decode, are skipped.
    """
    search_query_lower = (character_name_partial or "").strip().lower()
    if not search_query_lower:
        return []

    results = []
    for lower_name, char_info in _character_md5_map.items():
        if search_query_lower not in lower_name:
            continue
        md5_hash = char_info['md5']
        if md5_hash not in _character_thumbs_data:
            continue
        pil_image = base64_to_pil_image(_character_thumbs_data[md5_hash])
        if pil_image is None:
            continue
        results.append({
            'name': char_info['original_name'],
            'image': pil_image,
            'md5': md5_hash,  # Kept for potential future use or debugging
        })
        # Images are decoded lazily, so stop as soon as enough were collected.
        if len(results) >= max_results:
            break
    return results
# --- WD14 Tagger Specific Functions ---
def load_labels(dataframe) -> tuple[list[str], list, list, list]:
    """Split a tagger label table into display names and per-category row indexes.

    Args:
        dataframe: Table with a ``name`` column and a numeric ``category`` column.

    Returns:
        ``(tag_names, rating_indexes, general_indexes, character_indexes)``.
        ``tag_names`` has underscores replaced by spaces, except for names
        listed in ``kaomojis``. The three index lists hold the row positions
        whose ``category`` equals 9, 0 and 4 respectively.
    """
    kaomoji_set = set(kaomojis)

    def _display_name(name):
        # Leave kaomoji (and any non-string value) untouched.
        if not isinstance(name, str) or name in kaomoji_set:
            return name
        return name.replace("_", " ")

    tag_names = dataframe["name"].map(_display_name).tolist()
    rating_indexes = list(np.where(dataframe["category"] == 9)[0])
    general_indexes = list(np.where(dataframe["category"] == 0)[0])
    character_indexes = list(np.where(dataframe["category"] == 4)[0])
    return tag_names, rating_indexes, general_indexes, character_indexes
def mcut_threshold(probs):
    """Return the midpoint of the widest gap between consecutive sorted probabilities.

    ``probs`` is a 1-D NumPy array and must hold at least two values: with
    fewer there is no gap to measure and ``argmax`` raises on the empty array.
    """
    descending = np.sort(probs)[::-1]
    # Drop from each score to the next lower one.
    gaps = descending[:-1] - descending[1:]
    widest = gaps.argmax()
    return (descending[widest] + descending[widest + 1]) / 2
class Predictor:
    """Wraps an ONNX tagger model and its label table, and tags PIL images.

    One model is loaded at a time; asking for a different repository in
    ``predict`` swaps the loaded model.
    """

    def __init__(self, model_repo_default: str):  # Initialize with a default model repo
        self.model_target_size = None
        self.last_loaded_repo = None
        self.model = None
        self.tag_names = None
        self.rating_indexes = None
        self.general_indexes = None
        self.character_indexes = None
        self.load_model(model_repo_default)  # Load the default model at init

    def download_model_files(self, model_repo):
        """Fetch the label CSV and ONNX model of ``model_repo``; return ``(csv_path, model_path)``."""
        csv_path = huggingface_hub.hf_hub_download(model_repo, LABEL_FILENAME)
        model_path = huggingface_hub.hf_hub_download(model_repo, MODEL_FILENAME)
        return csv_path, model_path

    def load_model(self, model_repo):
        """Load the model and labels of ``model_repo`` unless they are already loaded.

        If loading fails, the exception propagates and the instance is left
        with ``model`` and ``last_loaded_repo`` set to None, so the next call
        simply retries.
        """
        if model_repo == self.last_loaded_repo and self.model is not None:
            return

        csv_path, model_path = self.download_model_files(model_repo)
        labels = load_labels(pd.read_csv(csv_path))

        # Drop the previous session before building the new one so both are
        # not held at once. Assign None rather than `del`: the attribute must
        # keep existing even if session creation below raises.
        self.model = None
        self.last_loaded_repo = None

        session = rt.InferenceSession(model_path)
        # Only the second dimension of the input shape is read; it is used as
        # both sides of the square model input.
        _, height, _, _ = session.get_inputs()[0].shape

        # Publish the new state only after everything succeeded.
        self.tag_names, self.rating_indexes, self.general_indexes, self.character_indexes = labels
        self.model = session
        self.model_target_size = height
        print(f"WD14 model loaded successfully from {model_repo}")
        self.last_loaded_repo = model_repo

    def prepare_image(self, image: Image.Image):  # Explicitly type-hint as PIL.Image.Image
        """
        Prepares a PIL Image for the WD14 tagger model.

        The image is flattened onto a white background, padded to a square
        with white, resized to the model's input size and returned as a
        float32 array with a leading batch dimension and reversed channel order.

        Raises:
            ValueError: if ``image`` is not a PIL image.
        """
        if not isinstance(image, Image.Image):
            raise ValueError("Input to prepare_image must be a PIL.Image.Image object.")

        target_size = self.model_target_size

        # Convert to RGBA first so any alpha channel blends onto white.
        image = image.convert("RGBA")
        canvas = Image.new("RGBA", image.size, (255, 255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")

        # Centre the image on a white square whose side is the longer edge.
        image_shape = image.size
        max_dim = max(image_shape)
        pad_left = (max_dim - image_shape[0]) // 2
        pad_top = (max_dim - image_shape[1]) // 2
        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        padded_image.paste(image, (pad_left, pad_top))

        if max_dim != target_size:
            padded_image = padded_image.resize(
                (target_size, target_size),
                Image.Resampling.BICUBIC,
            )

        image_array = np.asarray(padded_image, dtype=np.float32)
        # The model expects BGR, so reverse the channels
        image_array = image_array[:, :, ::-1]
        return np.expand_dims(image_array, axis=0)  # Add batch dimension

    def predict(
        self,
        image: Image.Image,
        model_repo: str,
        general_thresh: float,
        general_mcut_enabled: bool,
        character_thresh: float,
        character_mcut_enabled: bool,
    ):
        """Tag ``image`` and return ``(general_tag_names, character_tag_names)``.

        Tags scoring strictly above their threshold are kept and each list is
        ordered by descending score. When the MCut flag for a group is set and
        the group has more than one candidate, the threshold is replaced by
        ``mcut_threshold`` of that group's scores; for character tags with
        MCut enabled the threshold is additionally floored at 0.15.
        """
        # Load model if it's different from the currently loaded one
        self.load_model(model_repo)  # This will skip if already loaded

        image_input = self.prepare_image(image)
        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        preds = self.model.run([label_name], {input_name: image_input})[0]

        labels = list(zip(self.tag_names, preds[0].astype(float)))

        general_names = [labels[i] for i in self.general_indexes]
        if general_mcut_enabled:
            general_probs = np.array([x[1] for x in general_names])
            if len(general_probs) > 1:
                general_thresh = mcut_threshold(general_probs)
        general_res_dict = dict(x for x in general_names if x[1] > general_thresh)

        character_names = [labels[i] for i in self.character_indexes]
        if character_mcut_enabled:
            character_probs = np.array([x[1] for x in character_names])
            if len(character_probs) > 1:
                character_thresh = mcut_threshold(character_probs)
            character_thresh = max(0.15, character_thresh)
        character_res_dict = dict(x for x in character_names if x[1] > character_thresh)

        # Prepare tag names for prompt (without confidence, no escaping yet)
        general_tag_names = [
            x[0] if x[0] in kaomojis else x[0].replace('_', ' ')
            for x in sorted(general_res_dict.items(), key=lambda x: x[1], reverse=True)
        ]
        character_tag_names = [
            x[0].replace('_', ' ')
            for x in sorted(character_res_dict.items(), key=lambda x: x[1], reverse=True)
        ]
        return general_tag_names, character_tag_names
# --- Gradio Interface Functions ---
# Store search results temporarily for dropdown/gallery sync
# Results of the most recent search; the dropdown and gallery handlers index into it.
# NOTE(review): this is module-level state, so it is presumably shared by all
# concurrent users of the app -- confirm whether per-session state is needed.
_last_search_results = []


def search_characters_gradio(character_name_partial, max_results):
    """Run a character search and build the updates for the search tab.

    Returns a 6-tuple: gallery update, dropdown update, selected character
    name, selected character image, selected character MD5 and a status
    message. When something is found the first hit is pre-selected; otherwise
    every selection-related value is cleared.
    """
    global _last_search_results

    load_character_data_for_app()  # Ensure data is loaded
    matches = search_character_data_partial_for_app(character_name_partial, max_results)
    _last_search_results = matches  # Remember for later index-based lookups

    if not matches:
        return (
            gr.Gallery(value=[], selected_index=None),
            gr.Dropdown(choices=[], interactive=False, label="Select Character", value=None),
            "",
            None,
            "",
            "No characters found.",
        )

    gallery_items = [(match['image'], match['name']) for match in matches]
    # Dropdown values are positions in the result list, labels are the names.
    dropdown_choices = [(match['name'], position) for position, match in enumerate(matches)]

    first_hit = matches[0]
    return (
        gr.Gallery(value=gallery_items, selected_index=0),
        gr.Dropdown(choices=dropdown_choices, interactive=True, label="Select Character", value=0),
        first_hit['name'],
        first_hit['image'],
        first_hit['md5'],
        "",  # Clear status message
    )
def get_selected_character_info_by_index(selected_index):
    """Look up a character of the last search by its position in the result list.

    Returns ``(name, image, md5, gallery_update)``. For a missing or
    out-of-range index the first three values are cleared and the gallery is
    emptied.
    """
    has_valid_selection = (
        bool(_last_search_results)
        and selected_index is not None
        and 0 <= selected_index < len(_last_search_results)
    )
    if not has_valid_selection:
        return "", None, "", gr.Gallery(value=[], selected_index=None)

    chosen = _last_search_results[selected_index]
    # Re-send the gallery content so its highlighted item follows the selection.
    gallery_items = [(entry['image'], entry['name']) for entry in _last_search_results]
    return (
        chosen['name'],
        chosen['image'],
        chosen['md5'],
        gr.Gallery(value=gallery_items, selected_index=selected_index),
    )
def update_dropdown_from_gallery(evt: gr.SelectData):
    """Forward the index of the clicked gallery image to the character dropdown.

    ``evt.index`` is returned unchanged; when it is None the dropdown
    receives None, which deselects it.
    """
    return evt.index
def update_allowed_tags_set(selected_categories: list):
    """
    Pre-processes the allowed tags set based on selected categories.

    Rebuilds ``_preprocessed_allowed_tags_set`` as the union of all tags whose
    column for a selected category equals 1. Tags are stored lower-cased with
    underscores replaced by spaces, except kaomoji, which keep their
    underscores. The set is emptied when no categories are selected or the
    classifier table is missing or empty, and the rebuild is skipped when the
    selection equals the previous one.
    """
    global _preprocessed_allowed_tags_set, _last_selected_tag_categories

    classifier_df = _danbooru_tag_classifier_df
    if classifier_df is None or classifier_df.empty or not selected_categories:
        _preprocessed_allowed_tags_set = set()
        _last_selected_tag_categories = selected_categories
        return

    # Check if categories have actually changed
    if selected_categories == _last_selected_tag_categories:
        return

    kaomoji_set = set(kaomojis)
    allowed_tags_set = set()
    for category in selected_categories:
        if category not in classifier_df.columns:
            continue
        # Index labels (tag names) of rows flagged with 1 for this category.
        tags_in_category = classifier_df.index[classifier_df[category] == 1]
        allowed_tags_set.update(
            tag.lower() if tag in kaomoji_set else tag.replace('_', ' ').lower()
            for tag in tags_in_category
            # A non-string index entry (e.g. a name parsed as NaN) cannot be a tag.
            if isinstance(tag, str)
        )

    _preprocessed_allowed_tags_set = allowed_tags_set
    # Keep a copy so later mutation of the caller's list cannot fake a cache hit.
    _last_selected_tag_categories = list(selected_categories)
# Unified prompt generation logic, now taking image and char_name explicitly
def _generate_prompt_logic(
    image_to_tag: Image.Image,
    char_name_from_source: str,  # Character name chosen from the DB search, if applicable
    wd14_model_choice: str,
    general_thresh: float,
    general_mcut_enabled: bool,
    character_thresh: float,
    character_mcut_enabled: bool,
    banned_words_str: str,
    selected_tag_categories: list,
    preserve_char_name_and_tags: bool,  # Covers char name AND detected character tags
):
    """Tag an image and assemble the final comma-separated prompt.

    Args:
        image_to_tag: Image to run through the tagger; None yields an error status.
        char_name_from_source: Optional character name to include in the prompt.
        wd14_model_choice: Key of ``WD14_MODELS`` selecting the tagger model.
        general_thresh, general_mcut_enabled, character_thresh,
        character_mcut_enabled: Passed through to ``Predictor.predict``.
        banned_words_str: Comma-separated tags to drop (case-insensitive). A
            word written with underscores also bans its space-separated form.
        selected_tag_categories: Categories used to filter tags.
        preserve_char_name_and_tags: When true, the character name and the
            detected character tags bypass the category filter.

    Returns:
        ``(status_message, prompt)``. Failures are reported through the status
        message instead of being raised.
    """
    global _wd14_predictor_instance

    if image_to_tag is None:
        return "Error: No image provided for tagging.", "No prompt generated."

    model_repo = WD14_MODELS.get(wd14_model_choice)
    if model_repo is None:
        return f"Error: Unknown tagger model: {wd14_model_choice}", "No prompt generated."

    try:
        # Creating the predictor downloads and loads a model, so it belongs
        # inside the same error handling as the prediction itself.
        if _wd14_predictor_instance is None:
            _wd14_predictor_instance = Predictor(model_repo)
        general_tags, character_tags = _wd14_predictor_instance.predict(
            image=image_to_tag,
            model_repo=model_repo,
            general_thresh=general_thresh,
            general_mcut_enabled=general_mcut_enabled,
            character_thresh=character_thresh,
            character_mcut_enabled=character_mcut_enabled,
        )
    except Exception as e:
        return f"Error during tag generation: {e}", "Error generating prompt."

    # Process banned words into a set for O(1) membership tests.
    banned_words = set()
    for raw_word in (banned_words_str or "").split(','):
        word = raw_word.strip().lower()
        if word:
            banned_words.add(word)
            # Tags use spaces, so "simple_background" must also ban "simple background".
            banned_words.add(word.replace('_', ' '))

    prompt_parts = []    # Added to the prompt unconditionally
    tags_to_filter = []  # Subject to category filtering
    # Character-specific items go straight to the prompt only when preserved.
    character_target = prompt_parts if preserve_char_name_and_tags else tags_to_filter

    # 1. Character name from database search (if applicable)
    if char_name_from_source and char_name_from_source.lower() not in banned_words:
        character_target.append(char_name_from_source)

    # 2. WD14 detected character tags. A tag equal to the character name is
    #    removed by the case-insensitive de-duplication further down.
    character_target.extend(tag for tag in character_tags if tag.lower() not in banned_words)

    # 3. WD14 detected general tags
    tags_to_filter.extend(tag for tag in general_tags if tag.lower() not in banned_words)

    # --- TAG CLASSIFICATION FILTERING for tags_to_filter ---
    # Ensure allowed tags set is up-to-date with current category selection
    update_allowed_tags_set(selected_tag_categories)
    if _preprocessed_allowed_tags_set:  # Active category filter
        filtered_categorized_tags = [
            tag for tag in tags_to_filter if tag.lower() in _preprocessed_allowed_tags_set
        ]
    else:
        # No categories selected or no classifier data: keep all raw tags
        filtered_categorized_tags = tags_to_filter
    # --- END TAG CLASSIFICATION FILTERING ---

    prompt_parts.extend(filtered_categorized_tags)

    # Case-insensitive de-duplication that keeps the first occurrence and order.
    final_prompt_list = []
    seen_lower = set()
    for item in prompt_parts:
        item_lower = item.lower()
        if item_lower not in seen_lower:
            seen_lower.add(item_lower)
            final_prompt_list.append(item)

    final_prompt_escaped = escape_parentheses_for_prompt(", ".join(final_prompt_list))
    return "Prompt generated successfully!", final_prompt_escaped
# Wrapper function for the "Search Character Database" tab
def generate_prompt_from_search_tab(
    character_original_name: str,
    selected_pil_image_from_search: Image.Image,
    wd14_model_choice: str,
    general_thresh: float,
    general_mcut_enabled: bool,
    character_thresh: float,
    character_mcut_enabled: bool,
    banned_words_str: str,
    selected_tag_categories: list,
    preserve_char_name_and_tags: bool,
):
    """Generate a prompt for the character picked on the search tab.

    The selected character's thumbnail is the image that gets tagged and its
    name is passed as the explicit character name; everything else is handed
    to ``_generate_prompt_logic`` unchanged.
    """
    return _generate_prompt_logic(
        image_to_tag=selected_pil_image_from_search,
        char_name_from_source=character_original_name,
        wd14_model_choice=wd14_model_choice,
        general_thresh=general_thresh,
        general_mcut_enabled=general_mcut_enabled,
        character_thresh=character_thresh,
        character_mcut_enabled=character_mcut_enabled,
        banned_words_str=banned_words_str,
        selected_tag_categories=selected_tag_categories,
        preserve_char_name_and_tags=preserve_char_name_and_tags,
    )
# Wrapper function for the "Upload Your Own Image" tab
def generate_prompt_from_upload_tab(
    input_image_upload: Image.Image,
    wd14_model_choice: str,
    general_thresh: float,
    general_mcut_enabled: bool,
    character_thresh: float,
    character_mcut_enabled: bool,
    banned_words_str: str,
    selected_tag_categories: list,
    preserve_char_name_and_tags: bool,
):
    """Generate a prompt for an image uploaded on the upload tab.

    The uploaded image is tagged as-is. No character name comes from a
    database search here, so an empty name is passed to
    ``_generate_prompt_logic``.
    """
    return _generate_prompt_logic(
        image_to_tag=input_image_upload,
        char_name_from_source="",
        wd14_model_choice=wd14_model_choice,
        general_thresh=general_thresh,
        general_mcut_enabled=general_mcut_enabled,
        character_thresh=character_thresh,
        character_mcut_enabled=character_mcut_enabled,
        banned_words_str=banned_words_str,
        selected_tag_categories=selected_tag_categories,
        preserve_char_name_and_tags=preserve_char_name_and_tags,
    )
# Create the shared predictor up front so the default model is loaded at startup.
_wd14_predictor_instance = Predictor(_wd14_selected_model_repo)

# --- Gradio Interface Layout ---
# Character and classifier data are loaded once, before the UI is built.
load_character_data_for_app()

# Category names offered by the tag filter dropdown: the classifier columns.
if _danbooru_tag_classifier_df is None or _danbooru_tag_classifier_df.empty:
    tag_classification_categories = []
    print("tag classifier data not loaded or empty. Tag category filtering will be disabled.")
else:
    tag_classification_categories = [
        col for col in _danbooru_tag_classifier_df.columns if col != 'tag_id'
    ]

# Styling for the search result gallery (elem_id="gallery").
CSS = """
#gallery {
height: 300px;
max-height: 520px;
margin-left: 0; /* no left margin */
flex-grow: 1;
}
"""
# NOTE(review): the nesting of the layout containers below was reconstructed
# from the order of the components; confirm it against the intended UI.
with gr.Blocks(title="Character Prompt Generator", css=CSS) as demo:
    gr.Markdown(
        "\n# Character Prompt Generator\n"
        "Generate prompts for character images either by searching the database or uploading your own!\n"
    )

    # --- Shared Tagger Settings (apply to both tabs) ---
    gr.Markdown("#### ⚙️ Tagger Settings (Apply to both methods)")
    with gr.Accordion("Adjust Tagging Parameters", open=False):
        wd14_model_dropdown = gr.Dropdown(
            label="WD14 Tagger Model",
            choices=list(WD14_MODELS.keys()),
            value=list(WD14_MODELS.keys())[0],
            interactive=True,
            info="Select which WD14 tagger model to use. Different models may yield slightly different tags."
        )
        with gr.Row():
            general_threshold_slider = gr.Slider(
                minimum=0.01, maximum=0.99, value=_wd14_general_threshold, step=0.01,
                label="General Tags Threshold", interactive=True,
                info="Confidence score required for general tags to be included."
            )
            general_mcut_checkbox = gr.Checkbox(
                value=_wd14_general_mcut_enabled, label="Use MCut for General Tags", interactive=True,
                info="MCut dynamically sets the threshold based on tag score distribution."
            )
        with gr.Row():
            character_threshold_slider = gr.Slider(
                minimum=0.01, maximum=0.99, value=_wd14_character_threshold, step=0.01,
                label="Character Tags Threshold", interactive=True,
                info="Confidence score required for character tags to be included."
            )
            character_mcut_checkbox = gr.Checkbox(
                value=_wd14_character_mcut_enabled, label="Use MCut for Character Tags", interactive=True,
                info="MCut dynamically sets the threshold based on tag score distribution."
            )
        banned_words_input = gr.Textbox(
            label="Exclude Tags (comma-separated, case-insensitive)",
            value="simple background, lowres, text",
            placeholder="e.g., text, watermark",
            interactive=True,
            info="Tags listed here will be removed from the final prompt."
        )
        # Pre-select only categories that actually exist in the loaded
        # classifier table, so the initial value is always one of the choices.
        default_tag_categories = [
            category
            for category in ["body", "hair", "eyes_face", "species_traits", "head_accessories", "subjects_relationship"]
            if category in tag_classification_categories
        ]
        tag_category_filter = gr.Dropdown(
            label="Filter General Tags by Category",
            choices=tag_classification_categories,
            multiselect=True,
            interactive=bool(tag_classification_categories),
            value=default_tag_categories,
            info="Only general tags belonging to selected categories will be included. Character tags are not affected by this filter unless 'Preserve Characters' is unchecked."
        )
        preserve_char_name_and_tags_checkbox = gr.Checkbox(
            label="Preserve Character-Specific Tags (Bypass Category Filter)",
            value=True,
            interactive=True,
            info="If checked, the selected character's name (from search) and any detected character tags will *always* be included (unless banned), bypassing the 'Filter Tags by Category' setting. General tags are still filtered."
        )

    # Hidden states carrying the selected character to prompt generation.
    selected_pil_image_from_search = gr.State(None)
    selected_char_original_name = gr.State("")
    selected_char_md5_hash = gr.State("")

    with gr.Tabs():
        with gr.TabItem("1. Search Character Database"):
            gr.Markdown("### 🔍 Find a Character")
            with gr.Row():
                with gr.Column(scale=1):
                    search_input = gr.Textbox(
                        label="Enter Character Name",
                        placeholder="e.g., Hatsune, zhongli",
                        interactive=True
                    )
                    max_results_slider = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=DEFAULT_MAX_DISPLAY_RESULTS,
                        step=1,
                        label="Max Search Results to Display",
                        interactive=True
                    )
                    search_output_message = gr.Markdown("Enter a character name to search.")
                with gr.Column(scale=2):
                    character_gallery = gr.Gallery(
                        label="Matching Characters (Click to select)",
                        show_label=True,
                        elem_id="gallery",
                        columns=3,
                        rows=15,
                        object_fit="contain",
                        allow_preview=False,
                    )
            gr.Markdown("---")
            gr.Markdown("### ➡️ Selected Character")
            selected_character_dropdown = gr.Dropdown(
                label="Selected Character",
                choices=[],
                interactive=False,
                info="The character chosen from the search results above."
            )
            generate_prompt_button_search = gr.Button("Generate Prompt for Selected Character", variant="primary")

        with gr.TabItem("2. Upload Your Own Image"):
            gr.Markdown("### ⬆️ Upload an Image to Tag")
            input_image_upload = gr.Image(
                label="Upload Image (JPG, PNG, WEBP)",
                type="pil",
                height=200,
                interactive=True,
            )
            generate_prompt_button_upload = gr.Button("Generate Prompt for Uploaded Image", variant="primary")

    gr.Markdown("---")
    gr.Markdown("#### ✨ Generated Prompt")
    # Outputs shared by both tabs.
    generate_prompt_status = gr.Markdown(
        "",
        elem_id="shared_prompt_status"
    )
    prompt_output = gr.Textbox(
        label="",
        placeholder="Your generated prompt will appear here...",
        lines=5,
        interactive=True,
        show_copy_button=True,
        elem_id="shared_prompt_output"
    )

    # --- Event Handlers ---
    # Both the search box and the result-count slider re-run the search.
    search_outputs = [
        character_gallery,
        selected_character_dropdown,
        selected_char_original_name,
        selected_pil_image_from_search,
        selected_char_md5_hash,
        search_output_message,
    ]
    for search_trigger in (search_input, max_results_slider):
        search_trigger.change(
            fn=search_characters_gradio,
            inputs=[search_input, max_results_slider],
            outputs=search_outputs
        )

    # Dropdown selection -> update the selected-character states and gallery highlight.
    selected_character_dropdown.change(
        fn=get_selected_character_info_by_index,
        inputs=[selected_character_dropdown],
        outputs=[selected_char_original_name, selected_pil_image_from_search, selected_char_md5_hash, character_gallery]
    )
    # Gallery click -> set the dropdown, which in turn fires the handler above.
    character_gallery.select(
        fn=update_dropdown_from_gallery,
        inputs=None,
        outputs=[selected_character_dropdown]
    )

    # Tag Category Filter (Shared Setting)
    tag_category_filter.change(
        fn=update_allowed_tags_set,
        inputs=[tag_category_filter],
        outputs=None,  # Only updates a module-level set
        api_name=False  # Don't expose this helper function via API
    )
    # Also populate the allowed-tag set once at page load from the initial selection.
    demo.load(
        fn=lambda: update_allowed_tags_set(tag_category_filter.value),
        inputs=[],
        outputs=[]
    )

    # Tagger settings passed, in this order, to both prompt generation handlers.
    tagger_setting_inputs = [
        wd14_model_dropdown,
        general_threshold_slider,
        general_mcut_checkbox,
        character_threshold_slider,
        character_mcut_checkbox,
        banned_words_input,
        tag_category_filter,
        preserve_char_name_and_tags_checkbox,
    ]

    # Generate Prompt Button for "Search Character Database" tab
    generate_prompt_button_search.click(
        fn=generate_prompt_from_search_tab,
        inputs=[selected_char_original_name, selected_pil_image_from_search] + tagger_setting_inputs,
        outputs=[generate_prompt_status, prompt_output]
    )
    # Generate Prompt Button for "Upload Your Own Image" tab
    generate_prompt_button_upload.click(
        fn=generate_prompt_from_upload_tab,
        inputs=[input_image_upload] + tagger_setting_inputs,
        outputs=[generate_prompt_status, prompt_output]
    )

demo.launch(debug=True, show_error=True)