# Hugging Face Space: Character Prompt Generator (character DB search + WD14 tagger)
| import json | |
| import csv | |
| import io | |
| import base64 | |
| import gzip | |
| import os | |
| import numpy as np | |
| import pandas as pd | |
| from PIL import Image | |
| from io import BytesIO | |
| import onnxruntime as rt | |
| import huggingface_hub | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
# --- Dataset locations for the character thumbnail database ---
CHARATER_REPO = "flagrantia/character_select_stand_alone_app"  # NOTE(review): name typo ("CHARATER") kept as-is for compatibility
CHARACTER_MD5 = "wai_character_md5_v160.csv"
CHARACTER_THUMBS = "wai_character_thumbs_v160.json"

# Download the tag-classification CSV and the character database files at startup.
for filename in ["classified_tags_danbooru_beta.csv"]:  # classified_tags_danbooru_alpha.csv
    hf_hub_download(repo_id="r3gm/classified_tags", repo_type="dataset", filename=filename, local_dir=".")
for filename in [CHARACTER_MD5, CHARACTER_THUMBS]:
    hf_hub_download(repo_id=CHARATER_REPO, repo_type="dataset", filename=filename, local_dir=".")

# --- Global variables ---
_character_md5_map = {}  # lower-cased character name -> {'original_name', 'md5'}
_character_thumbs_data = {}  # md5 hash -> gzip+base64-encoded WEBP thumbnail
_danbooru_tag_classifier_df = None  # Global for the classifier dataframe
_preprocessed_allowed_tags_set = set()  # Stores tags allowed by current category filter
_last_selected_tag_categories = None  # To track if categories changed

# --- WD14 Tagger Globals ---
_wd14_predictor_instance = None
# Available models
WD14_MODELS = {
    "WD-SwinV2-V3 (lite)": "SmilingWolf/wd-swinv2-tagger-v3",
    "WD-ViT-L-V3": "SmilingWolf/wd-vit-large-tagger-v3",
}
_wd14_selected_model_repo = WD14_MODELS["WD-SwinV2-V3 (lite)"]  # Default selected model
# Default thresholds and MCut enabled status (can be changed by Gradio inputs)
_wd14_general_threshold = 0.35
_wd14_character_threshold = 0.85
_wd14_general_mcut_enabled = False
_wd14_character_mcut_enabled = False

# --- Tagger model specific file names ---
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# Tags that must keep their underscores (they are emoticons, not word tags).
kaomojis = [
    "0_0", "(o)_(o)", "+_+", "+_-", "._.", "<o>_<o>", "<|>_<|>", "=_=", ">_<",
    "3_3", "6_9", ">_o", "@_@", "^_^", "o_o", "u_u", "x_x", "|_|", "||_||",
]

# --- Configuration ---
DEFAULT_MAX_DISPLAY_RESULTS = 6  # default number of search results shown in the gallery
CLASSIFIED_TAGS_CSV = "classified_tags_danbooru_beta.csv"
| # --- Utility Functions --- | |
def _resolve_data_path(path):
    """Return *path* if it exists, otherwise fall back to the sibling directory.

    When neither location exists the original path is returned so the caller's
    open()/read raises a normal FileNotFoundError.
    """
    if not os.path.exists(path):
        sibling = os.path.join(os.path.dirname(__file__), '..', path)
        if os.path.exists(sibling):
            return sibling
    return path


def load_character_data_for_app(md5_csv_path=CHARACTER_MD5, thumbs_json_path=CHARACTER_THUMBS):
    """Load the character MD5 map, thumbnail data and tag classifier (idempotent).

    Each data source is loaded only once; subsequent calls are cheap no-ops.
    On failure the corresponding global is reset to an empty container so the
    rest of the app degrades gracefully instead of crashing.
    """
    global _character_md5_map, _character_thumbs_data, _danbooru_tag_classifier_df

    if not _character_md5_map:  # Load only if not already loaded
        try:
            md5_csv_path = _resolve_data_path(md5_csv_path)
            with open(md5_csv_path, 'r', encoding='utf-8') as csvfile:
                reader = csv.reader(csvfile)
                for row in reader:
                    if row and len(row) >= 2:
                        original_name = row[0].strip()
                        md5_hash = row[1].strip()
                        _character_md5_map[original_name.lower()] = {'original_name': original_name, 'md5': md5_hash}
        except FileNotFoundError:
            print(f"Error: {md5_csv_path} not found. Please ensure it's in the same directory as the script or a sibling directory.")
            _character_md5_map = {}  # Clear to prevent partial data issues
        except Exception as e:
            print(f"Error loading MD5 data: {e}")
            _character_md5_map = {}

    if not _character_thumbs_data:  # Load only if not already loaded
        try:
            thumbs_json_path = _resolve_data_path(thumbs_json_path)
            with open(thumbs_json_path, 'r', encoding='utf-8') as jsonfile:
                _character_thumbs_data = json.load(jsonfile)
            print(f"Loaded {len(_character_thumbs_data)} thumbnail entries from {thumbs_json_path}.")
        except FileNotFoundError:
            print(f"Error: {thumbs_json_path} not found. Please ensure it's in the same directory as the script or a sibling directory.")
            _character_thumbs_data = {}  # Clear to prevent partial data issues
        except Exception as e:
            print(f"Error loading thumbnail data: {e}")
            _character_thumbs_data = {}

    # Load tag classifier data only once (previously it was re-read from disk
    # on every call, even though the other two sources were cached).
    if _danbooru_tag_classifier_df is None:
        try:
            classifier_csv_path = _resolve_data_path(CLASSIFIED_TAGS_CSV)
            _danbooru_tag_classifier_df = pd.read_csv(classifier_csv_path, index_col='name')
            # Drop 'tag_id' if it exists and is not needed for filtering
            if 'tag_id' in _danbooru_tag_classifier_df.columns:
                _danbooru_tag_classifier_df = _danbooru_tag_classifier_df.drop(columns=['tag_id'])
        except Exception as e:
            # Include the actual error (the caught exception was previously unused).
            print(f"CRITICAL ERROR: Failed to load {CLASSIFIED_TAGS_CSV} ({e}). Tag filtering by classification will be unavailable and likely cause errors.")
            _danbooru_tag_classifier_df = pd.DataFrame()  # Ensure it's an empty DataFrame to avoid further crashes
def base64_to_pil_image(base64_data):
    """Decode a gzip-compressed, base64-encoded WEBP blob into an RGBA PIL image.

    Returns None (after logging the error) when the payload cannot be decoded.
    """
    try:
        raw_webp = gzip.decompress(base64.b64decode(base64_data))
        return Image.open(BytesIO(raw_webp)).convert("RGBA")
    except Exception as e:
        print(f"Error decoding base64 image: {e}")
        return None
def escape_parentheses_for_prompt(text: str) -> str:
    """Escapes parentheses for use in stable diffusion prompts."""
    for bracket in "()":
        text = text.replace(bracket, "\\" + bracket)
    return text
def search_character_data_partial_for_app(character_name_partial, max_results):
    """Case-insensitive substring search over the character database.

    Returns up to *max_results* dicts with 'name', 'image' (PIL) and 'md5',
    skipping entries whose thumbnail is missing or cannot be decoded.
    """
    query = character_name_partial.lower()
    if not query:
        return []

    matches = []
    for lower_name, char_info in _character_md5_map.items():
        if query not in lower_name:
            continue
        md5_hash = char_info['md5']
        encoded_thumb = _character_thumbs_data.get(md5_hash)
        if encoded_thumb is None:
            continue
        pil_image = base64_to_pil_image(encoded_thumb)
        if pil_image:
            matches.append({
                'name': char_info['original_name'],
                'image': pil_image,
                'md5': md5_hash,  # kept for potential future use or debugging
            })
            if len(matches) >= max_results:
                break
    return matches
| # --- WD14 Tagger Specific Functions --- | |
def load_labels(dataframe) -> tuple[list[str], list[int], list[int], list[int]]:
    """Extract tag names and per-category index lists from a WD14 tag dataframe.

    Underscores in tag names are replaced with spaces, except for kaomoji tags
    which must keep their literal form.

    Returns:
        (tag_names, rating_indexes, general_indexes, character_indexes) where
        the three index lists point into *tag_names*.

    Note: the original annotation claimed ``-> list[str]`` but the function
    has always returned this 4-tuple; the annotation is corrected here.
    """
    name_series = dataframe["name"]
    name_series = name_series.map(
        lambda x: x.replace("_", " ") if x not in kaomojis else x
    )
    tag_names = name_series.tolist()
    # WD14 selected_tags.csv category codes: 9 = rating, 0 = general, 4 = character.
    rating_indexes = list(np.where(dataframe["category"] == 9)[0])
    general_indexes = list(np.where(dataframe["category"] == 0)[0])
    character_indexes = list(np.where(dataframe["category"] == 4)[0])
    return tag_names, rating_indexes, general_indexes, character_indexes
def mcut_threshold(probs):
    """Maximum Cut Thresholding (MCut).

    Sorts the probabilities in descending order and places the threshold at
    the midpoint of the largest gap between consecutive scores.
    Assumes *probs* contains at least two values.
    """
    descending = np.sort(probs)[::-1]
    gaps = descending[:-1] - descending[1:]
    cut = int(np.argmax(gaps))
    return (descending[cut] + descending[cut + 1]) / 2
class Predictor:
    """Wraps a WD14 ONNX tagger: model download, image preprocessing and inference."""

    def __init__(self, model_repo_default: str):  # Initialize with a default model repo
        # All model/label state is populated by load_model().
        self.model_target_size = None  # square input edge length expected by the model
        self.last_loaded_repo = None   # repo id of the currently loaded model
        self.model = None              # onnxruntime InferenceSession
        self.tag_names = None
        self.rating_indexes = None
        self.general_indexes = None
        self.character_indexes = None
        self.load_model(model_repo_default)  # Load the default model at init

    def download_model_files(self, model_repo):
        # Fetch (and cache) the label CSV and the ONNX weights from the HF Hub.
        csv_path = huggingface_hub.hf_hub_download(model_repo, LABEL_FILENAME)
        model_path = huggingface_hub.hf_hub_download(model_repo, MODEL_FILENAME)
        return csv_path, model_path

    def load_model(self, model_repo):
        """Load *model_repo*; no-op when it is already the active model."""
        if model_repo == self.last_loaded_repo and self.model is not None:
            return
        csv_path, model_path = self.download_model_files(model_repo)
        tags_df = pd.read_csv(csv_path)
        self.tag_names, self.rating_indexes, self.general_indexes, self.character_indexes = load_labels(tags_df)
        # Clear previous model if any
        if hasattr(self, 'model') and self.model is not None:
            del self.model
        self.model = rt.InferenceSession(model_path)
        # Input is NHWC; these models use square inputs, so height suffices.
        _, height, width, _ = self.model.get_inputs()[0].shape
        self.model_target_size = height
        print(f"WD14 model loaded successfully from {model_repo}")
        self.last_loaded_repo = model_repo

    def prepare_image(self, image: Image.Image):  # Explicitly type-hint as PIL.Image.Image
        """
        Prepares a PIL Image for the WD14 tagger model.
        Resizes and converts the image to the model's expected input format
        (batched float32 BGR array, white background, centered square pad).
        """
        if not isinstance(image, Image.Image):
            raise ValueError("Input to prepare_image must be a PIL.Image.Image object.")
        target_size = self.model_target_size
        # Convert to RGBA first to handle potential alpha channel correctly for white background
        image = image.convert("RGBA")
        canvas = Image.new("RGBA", image.size, (255, 255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")  # Now convert to RGB after blending with white background
        # Pad to a centered square so the aspect ratio survives the resize.
        image_shape = image.size
        max_dim = max(image_shape)
        pad_left = (max_dim - image_shape[0]) // 2
        pad_top = (max_dim - image_shape[1]) // 2
        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        padded_image.paste(image, (pad_left, pad_top))
        if max_dim != target_size:
            padded_image = padded_image.resize(
                (target_size, target_size),
                Image.Resampling.BICUBIC,  # Use Image.Resampling
            )
        image_array = np.asarray(padded_image, dtype=np.float32)
        # The model expects BGR, so reverse the channels
        image_array = image_array[:, :, ::-1]
        return np.expand_dims(image_array, axis=0)  # Add batch dimension

    def predict(
        self,
        image: Image.Image,
        model_repo: str,
        general_thresh: float,
        general_mcut_enabled: bool,
        character_thresh: float,
        character_mcut_enabled: bool,
    ):
        """
        Tag *image* and return (general_tag_names, character_tag_names),
        each sorted by descending confidence, underscores replaced by spaces
        (kaomojis excepted for general tags). When MCut is enabled the fixed
        thresholds are replaced by data-dependent ones.
        """
        # Load model if it's different from the currently loaded one
        self.load_model(model_repo)  # This will skip if already loaded
        image_input = self.prepare_image(image)
        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        preds = self.model.run([label_name], {input_name: image_input})[0]
        labels = list(zip(self.tag_names, preds[0].astype(float)))
        general_names = [labels[i] for i in self.general_indexes]
        if general_mcut_enabled:
            general_probs = np.array([x[1] for x in general_names])
            if len(general_probs) > 1:  # MCut needs at least two scores
                general_thresh = mcut_threshold(general_probs)
        general_res = [x for x in general_names if x[1] > general_thresh]
        general_res_dict = dict(general_res)
        character_names = [labels[i] for i in self.character_indexes]
        if character_mcut_enabled:
            character_probs = np.array([x[1] for x in character_names])
            if len(character_probs) > 1:
                character_thresh = mcut_threshold(character_probs)
            # MCut can cut very low for characters; enforce a 0.15 floor.
            character_thresh = max(0.15, character_thresh)
        character_res = [x for x in character_names if x[1] > character_thresh]
        character_res_dict = dict(character_res)
        # Prepare tag names for prompt (without confidence, no escaping yet)
        general_tag_names = [
            x[0] if x[0] in kaomojis else x[0].replace('_', ' ')
            for x in sorted(general_res_dict.items(), key=lambda x: x[1], reverse=True)
        ]
        character_tag_names = [x[0].replace('_', ' ') for x in sorted(character_res_dict.items(), key=lambda x: x[1], reverse=True)]
        return general_tag_names, character_tag_names
# --- Gradio Interface Functions ---
# Store search results temporarily for dropdown/gallery sync
_last_search_results = []


def search_characters_gradio(character_name_partial, max_results):
    """Run a database search and update the gallery/dropdown/selection states.

    Returns updates for: gallery, dropdown, selected character name, selected
    PIL image, selected md5 hash, and a status message.
    """
    global _last_search_results
    load_character_data_for_app()  # Ensure data is loaded
    _last_search_results = search_character_data_partial_for_app(character_name_partial, max_results)

    if not _last_search_results:
        # No hits: clear every selection-related output.
        return (
            gr.Gallery(value=[], selected_index=None),
            gr.Dropdown(choices=[], interactive=False, label="Select Character", value=None),
            "",    # selected_char_original_name
            None,  # selected_pil_image_from_search
            "",    # selected_char_md5_hash
            "No characters found.",
        )

    gallery_items = [(char['image'], char['name']) for char in _last_search_results]
    dropdown_choices = [(char['name'], idx) for idx, char in enumerate(_last_search_results)]

    # Auto-select the first result so the downstream states are always populated.
    first = _last_search_results[0]
    return (
        gr.Gallery(value=gallery_items, selected_index=0),
        gr.Dropdown(choices=dropdown_choices, interactive=True, label="Select Character", value=0),
        first['name'],
        first['image'],
        first['md5'],
        "",  # clear status message
    )
def get_selected_character_info_by_index(selected_index):
    """Resolve a dropdown/gallery index into (name, image, md5, gallery update)."""
    is_valid = (
        _last_search_results
        and selected_index is not None
        and 0 <= selected_index < len(_last_search_results)
    )
    if not is_valid:
        # Invalid or empty selection: clear everything.
        return "", None, "", gr.Gallery(value=[], selected_index=None)

    chosen = _last_search_results[selected_index]
    gallery_items = [(char['image'], char['name']) for char in _last_search_results]
    return (
        chosen['name'],
        chosen['image'],
        chosen['md5'],
        # Keep the gallery's highlighted item in sync with the dropdown.
        gr.Gallery(value=gallery_items, selected_index=selected_index),
    )
def update_dropdown_from_gallery(evt: gr.SelectData):
    """Mirror a gallery click into the dropdown by returning the clicked index."""
    # evt.index is the position of the clicked thumbnail; None deselects.
    return evt.index if evt.index is not None else None
def update_allowed_tags_set(selected_categories: list):
    """
    Pre-processes the allowed tags set based on selected categories.
    This function is called when the tag category filter changes.

    Rebuilds the global _preprocessed_allowed_tags_set (lower-cased,
    underscore-to-space tags; kaomojis kept verbatim) only when the
    selection actually changed.
    """
    global _preprocessed_allowed_tags_set, _last_selected_tag_categories

    # Robustness fix: the classifier dataframe is initialized to None, so a
    # call before data loading used to crash on `.empty` — treat None the
    # same as an empty classifier.
    classifier_unavailable = _danbooru_tag_classifier_df is None or _danbooru_tag_classifier_df.empty
    if classifier_unavailable or not selected_categories:
        _preprocessed_allowed_tags_set = set()
        _last_selected_tag_categories = selected_categories
        return

    # Skip the rebuild when the selection is unchanged.
    if selected_categories == _last_selected_tag_categories:
        return

    allowed_tags_set = set()
    for category in selected_categories:
        if category in _danbooru_tag_classifier_df.columns:
            # Tags flagged with a 1 in this category's column are allowed.
            tags_in_category = _danbooru_tag_classifier_df[_danbooru_tag_classifier_df[category] == 1].index.tolist()
            # Lower-case for case-insensitive matching; keep kaomojis verbatim.
            allowed_tags_set.update(
                tag.lower() if tag in kaomojis else tag.replace('_', ' ').lower()
                for tag in tags_in_category
            )
    _preprocessed_allowed_tags_set = allowed_tags_set
    _last_selected_tag_categories = selected_categories
# Unified prompt generation logic, now taking image and char_name explicitly
def _generate_prompt_logic(
    image_to_tag: Image.Image,
    char_name_from_source: str,  # This is the char name explicitly chosen from DB search, if applicable
    wd14_model_choice: str,
    general_thresh: float,
    general_mcut_enabled: bool,
    character_thresh: float,
    character_mcut_enabled: bool,
    banned_words_str: str,
    selected_tag_categories: list,
    preserve_char_name_and_tags: bool,  # Now covers char name AND detected character tags
) -> tuple[str, str]:
    """Tag an image with WD14 and assemble the final prompt string.

    Returns a (status_message, prompt) tuple. Prompt order: optional DB
    character name, detected character tags, then general tags — each step
    subject to the banned-words list; non-preserved tags additionally pass
    through the category filter. Duplicates are removed case-insensitively
    (first occurrence wins) and parentheses are escaped at the end.
    """
    if image_to_tag is None:
        return "Error: No image provided for tagging.", "No prompt generated."
    global _wd14_predictor_instance
    if _wd14_predictor_instance is None:
        # Initialize with the currently selected model from the dropdown
        _wd14_predictor_instance = Predictor(WD14_MODELS[wd14_model_choice])
    try:
        general_tags, character_tags = _wd14_predictor_instance.predict(
            image=image_to_tag,
            model_repo=WD14_MODELS[wd14_model_choice],  # Pass the selected model repo
            general_thresh=general_thresh,
            general_mcut_enabled=general_mcut_enabled,
            character_thresh=character_thresh,
            character_mcut_enabled=character_mcut_enabled,
        )
    except Exception as e:
        return f"Error during tag generation: {e}", "Error generating prompt."
    # Process banned words (comma-separated, matched case-insensitively)
    banned_words_list = [word.strip().lower() for word in banned_words_str.split(',') if word.strip()]
    prompt_parts = []
    tags_to_filter = []  # Tags that will be subject to category filtering
    # 1. Handle character name from database search (if applicable)
    if char_name_from_source:
        char_name_lower = char_name_from_source.lower()
        if char_name_lower not in banned_words_list:
            if preserve_char_name_and_tags:
                prompt_parts.append(char_name_from_source)
            else:
                # If not preserving, it will be treated like any other tag for filtering
                tags_to_filter.append(char_name_from_source)
    # 2. Handle WD14 detected character tags
    for tag in character_tags:
        tag_lower = tag.lower()
        if tag_lower not in banned_words_list:
            # If preserving, add to prompt_parts directly
            if preserve_char_name_and_tags:
                # Ensure it's not a duplicate of char_name_from_source if that was already added
                if not (char_name_from_source and tag_lower == char_name_from_source.lower()):
                    prompt_parts.append(tag)
            else:
                tags_to_filter.append(tag)
    # 3. Handle WD14 detected general tags
    for tag in general_tags:
        tag_lower = tag.lower()
        if tag_lower not in banned_words_list:
            tags_to_filter.append(tag)
    # --- TAG CLASSIFICATION FILTERING for tags_to_filter ---
    filtered_categorized_tags = []
    # Ensure allowed tags set is up-to-date with current category selection
    update_allowed_tags_set(selected_tag_categories)
    if _preprocessed_allowed_tags_set:  # If there are active category filters
        for tag in tags_to_filter:
            if tag.lower() in _preprocessed_allowed_tags_set:
                filtered_categorized_tags.append(tag)
    else:
        # If no categories are selected or classifier data is not active, use all raw tags
        filtered_categorized_tags = tags_to_filter
    # --- END TAG CLASSIFICATION FILTERING ---
    prompt_parts.extend(filtered_categorized_tags)
    # Ensure uniqueness while keeping first-seen order (case-insensitive)
    final_prompt_list = []
    seen_lower = set()  # Use a set of lowercased tags for uniqueness check
    for item in prompt_parts:
        item_lower = item.lower()
        if item_lower not in seen_lower:
            final_prompt_list.append(item)
            seen_lower.add(item_lower)
    final_prompt = ", ".join(final_prompt_list)
    final_prompt_escaped = escape_parentheses_for_prompt(final_prompt)
    return "Prompt generated successfully!", final_prompt_escaped
# Wrapper function for the "Search Character Database" tab
def generate_prompt_from_search_tab(
    character_original_name: str,
    selected_pil_image_from_search: Image.Image,
    wd14_model_choice: str,
    general_thresh: float,
    general_mcut_enabled: bool,
    character_thresh: float,
    character_mcut_enabled: bool,
    banned_words_str: str,
    selected_tag_categories: list,
    preserve_char_name_and_tags: bool,
):
    """Tab-1 wrapper: tag the character image chosen from the database search.

    The DB search supplies both the image to tag and the explicit character
    name; everything else is forwarded to the shared prompt-generation logic.
    """
    return _generate_prompt_logic(
        image_to_tag=selected_pil_image_from_search,
        char_name_from_source=character_original_name,
        wd14_model_choice=wd14_model_choice,
        general_thresh=general_thresh,
        general_mcut_enabled=general_mcut_enabled,
        character_thresh=character_thresh,
        character_mcut_enabled=character_mcut_enabled,
        banned_words_str=banned_words_str,
        selected_tag_categories=selected_tag_categories,
        preserve_char_name_and_tags=preserve_char_name_and_tags,
    )
# Wrapper function for the "Upload Your Own Image" tab
def generate_prompt_from_upload_tab(
    input_image_upload: Image.Image,
    wd14_model_choice: str,
    general_thresh: float,
    general_mcut_enabled: bool,
    character_thresh: float,
    character_mcut_enabled: bool,
    banned_words_str: str,
    selected_tag_categories: list,
    preserve_char_name_and_tags: bool,
):
    """Tab-2 wrapper: tag a user-uploaded image.

    No database character is associated with an upload, so an empty character
    name is passed to the shared prompt-generation logic.
    """
    return _generate_prompt_logic(
        image_to_tag=input_image_upload,
        char_name_from_source="",
        wd14_model_choice=wd14_model_choice,
        general_thresh=general_thresh,
        general_mcut_enabled=general_mcut_enabled,
        character_thresh=character_thresh,
        character_mcut_enabled=character_mcut_enabled,
        banned_words_str=banned_words_str,
        selected_tag_categories=selected_tag_categories,
        preserve_char_name_and_tags=preserve_char_name_and_tags,
    )
# Initialize the predictor once globally with the default model
# (downloads the ONNX model and labels on first run).
_wd14_predictor_instance = Predictor(_wd14_selected_model_repo)

# --- Gradio Interface Layout ---
# Load data once at the start of the script
load_character_data_for_app()

# Get tag classification categories for the dropdown, if available
tag_classification_categories = []
if _danbooru_tag_classifier_df is not None and not _danbooru_tag_classifier_df.empty:
    tag_classification_categories = [col for col in _danbooru_tag_classifier_df.columns if col not in ['tag_id']]
else:
    print("tag classifier data not loaded or empty. Tag category filtering will be disabled.")

# Fixed-height styling for the character-search gallery.
CSS = """
#gallery {
    height: 300px;
    max-height: 520px;
    margin-left: 0; /* no left margin */
    flex-grow: 1;
}
"""
# --- Application UI: two tabs (DB search / upload) sharing one tagger config ---
with gr.Blocks(title="Character Prompt Generator", css=CSS) as demo:
    gr.Markdown(
        """
        # Character Prompt Generator
        Generate prompts for character images either by searching the database or uploading your own!
        """
    )
    # --- Shared Tagger Settings (Moved to the top for clarity) ---
    gr.Markdown("#### ⚙️ Tagger Settings (Apply to both methods)")
    with gr.Accordion("Adjust Tagging Parameters", open=False):
        # Model Selection Dropdown
        wd14_model_dropdown = gr.Dropdown(
            label="WD14 Tagger Model",
            choices=list(WD14_MODELS.keys()),
            value=list(WD14_MODELS.keys())[0],
            interactive=True,
            info="Select which WD14 tagger model to use. Different models may yield slightly different tags."
        )
        with gr.Row():
            general_threshold_slider = gr.Slider(
                minimum=0.01, maximum=0.99, value=_wd14_general_threshold, step=0.01,
                label="General Tags Threshold", interactive=True,
                info="Confidence score required for general tags to be included."
            )
            general_mcut_checkbox = gr.Checkbox(
                value=_wd14_general_mcut_enabled, label="Use MCut for General Tags", interactive=True,
                info="MCut dynamically sets the threshold based on tag score distribution."
            )
        with gr.Row():
            character_threshold_slider = gr.Slider(
                minimum=0.01, maximum=0.99, value=_wd14_character_threshold, step=0.01,
                label="Character Tags Threshold", interactive=True,
                info="Confidence score required for character tags to be included."
            )
            character_mcut_checkbox = gr.Checkbox(
                value=_wd14_character_mcut_enabled, label="Use MCut for Character Tags", interactive=True,
                info="MCut dynamically sets the threshold based on tag score distribution."
            )
        banned_words_input = gr.Textbox(
            label="Exclude Tags (comma-separated, case-insensitive)",
            value="simple background, lowres, text",
            placeholder="e.g., text, watermark",
            interactive=True,
            info="Tags listed here will be removed from the final prompt."
        )
        tag_category_filter = gr.Dropdown(
            label="Filter General Tags by Category",
            choices=tag_classification_categories,
            multiselect=True,
            interactive=bool(tag_classification_categories),
            value=["body", "hair", "eyes_face", "species_traits", "head_accessories", "subjects_relationship"] if tag_classification_categories else [],
            info="Only general tags belonging to selected categories will be included. Character tags are not affected by this filter unless 'Preserve Characters' is unchecked."
        )
        preserve_char_name_and_tags_checkbox = gr.Checkbox(
            label="Preserve Character-Specific Tags (Bypass Category Filter)",
            value=True,
            interactive=True,
            info="If checked, the selected character's name (from search) and any detected character tags will *always* be included (unless banned), bypassing the 'Filter Tags by Category' setting. General tags are still filtered."
        )

    # Hidden states to pass character data to prompt generation (keep these global-ish)
    selected_pil_image_from_search = gr.State(None)
    selected_char_original_name = gr.State("")
    selected_char_md5_hash = gr.State("")

    with gr.Tabs():
        with gr.TabItem("1. Search Character Database"):
            gr.Markdown("### 🔍 Find a Character")
            with gr.Row():
                with gr.Column(scale=1):
                    search_input = gr.Textbox(
                        label="Enter Character Name",
                        placeholder="e.g., Hatsune, zhongli",
                        interactive=True
                    )
                    max_results_slider = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=DEFAULT_MAX_DISPLAY_RESULTS,
                        step=1,
                        label="Max Search Results to Display",
                        interactive=True
                    )
                    search_output_message = gr.Markdown("Enter a character name to search.")
                with gr.Column(scale=2):
                    character_gallery = gr.Gallery(
                        label="Matching Characters (Click to select)",
                        show_label=True,
                        elem_id="gallery",  # styled by the CSS block above
                        columns=3,
                        rows=15,
                        object_fit="contain",
                        allow_preview=False,
                    )
            gr.Markdown("---")
            gr.Markdown("### ➡️ Selected Character")
            selected_character_dropdown = gr.Dropdown(
                label="Selected Character",
                choices=[],
                interactive=False,
                info="The character chosen from the search results above."
            )
            generate_prompt_button_search = gr.Button("Generate Prompt for Selected Character", variant="primary")
        with gr.TabItem("2. Upload Your Own Image"):
            gr.Markdown("### ⬆️ Upload an Image to Tag")
            input_image_upload = gr.Image(
                label="Upload Image (JPG, PNG, WEBP)",
                type="pil",
                height=200,
                interactive=True,
            )
            generate_prompt_button_upload = gr.Button("Generate Prompt for Uploaded Image", variant="primary")

    gr.Markdown("---")
    gr.Markdown("#### ✨ Generated Prompt")
    # Outputs are consolidated here and shared by both tabs.
    generate_prompt_status = gr.Markdown(
        "",
        elem_id="shared_prompt_status"
    )
    prompt_output = gr.Textbox(
        label="",
        placeholder="Your generated prompt will appear here...",
        lines=5,
        interactive=True,
        show_copy_button=True,
        elem_id="shared_prompt_output"
    )

    # --- Event Handlers ---
    # Search Character Database Tab Interactions
    search_input.change(
        fn=search_characters_gradio,
        inputs=[search_input, max_results_slider],
        outputs=[character_gallery, selected_character_dropdown, selected_char_original_name, selected_pil_image_from_search, selected_char_md5_hash, search_output_message]
    )
    max_results_slider.change(
        fn=search_characters_gradio,
        inputs=[search_input, max_results_slider],
        outputs=[character_gallery, selected_character_dropdown, selected_char_original_name, selected_pil_image_from_search, selected_char_md5_hash, search_output_message]
    )
    selected_character_dropdown.change(
        fn=get_selected_character_info_by_index,
        inputs=[selected_character_dropdown],
        outputs=[selected_char_original_name, selected_pil_image_from_search, selected_char_md5_hash, character_gallery]
    )
    character_gallery.select(
        fn=update_dropdown_from_gallery,
        inputs=None,
        outputs=[selected_character_dropdown]
    )
    # Tag Category Filter (Shared Setting)
    tag_category_filter.change(
        fn=update_allowed_tags_set,
        inputs=[tag_category_filter],
        outputs=None,  # This function only updates a global, doesn't return a Gradio component
        api_name=False  # Don't expose this helper function via API
    )
    # Also call it once at startup with the initial value to populate the global set
    demo.load(
        fn=lambda: update_allowed_tags_set(tag_category_filter.value),
        inputs=[],
        outputs=[]
    )
    # Generate Prompt Button for "Search Character Database" tab
    generate_prompt_button_search.click(
        fn=generate_prompt_from_search_tab,
        inputs=[
            selected_char_original_name,
            selected_pil_image_from_search,
            wd14_model_dropdown,
            general_threshold_slider,
            general_mcut_checkbox,
            character_threshold_slider,
            character_mcut_checkbox,
            banned_words_input,
            tag_category_filter,
            preserve_char_name_and_tags_checkbox
        ],
        outputs=[generate_prompt_status, prompt_output]
    )
    # Generate Prompt Button for "Upload Your Own Image" tab
    generate_prompt_button_upload.click(
        fn=generate_prompt_from_upload_tab,
        inputs=[
            input_image_upload,
            wd14_model_dropdown,
            general_threshold_slider,
            general_mcut_checkbox,
            character_threshold_slider,
            character_mcut_checkbox,
            banned_words_input,
            tag_category_filter,
            preserve_char_name_and_tags_checkbox
        ],
        outputs=[generate_prompt_status, prompt_output]
    )

demo.launch(debug=True, show_error=True)