import html
import re
import time
from dataclasses import dataclass
from typing import Callable

import gradio as gr

from shared import model_dropdowns
from shared.deepy.config import deepy_available
from shared.utils.model_unload import model_unload_guard

# Upper bound on the number of entries a single search returns.
MAX_SEARCH_RESULTS = 24
# When True, each search result shows a secondary "scope · path" line and
# results are separated by a bottom border (see get_css).
SHOW_SEARCH_RESULT_TYPE_LINE = False


@dataclass
class ModelSelectorToolbar:
    """Handles to the Gradio components that make up the model selector toolbar.

    ``create_toolbar`` fills every field except the three hidden search controls
    (``search_target``, ``search_apply_button``, ``search_close_button``), which
    are attached afterwards by ``create_search_panel``.
    """

    search_button: gr.Button
    refresh_button: gr.Button
    unload_button: gr.Button
    finetune_button: gr.Button | None = None
    # Row holding the tool buttons; hidden while the search row is shown.
    tool_row: gr.Row | None = None
    # Row holding the search box; hidden until the search button is pressed.
    search_row: gr.Row | None = None
    search_query: gr.Textbox | None = None
    search_results: gr.HTML | None = None
    # Hidden textbox the client-side JS writes the chosen model type into.
    search_target: gr.Textbox | None = None
    # Hidden buttons the client-side JS clicks to apply / close the search.
    search_apply_button: gr.Button | None = None
    search_close_button: gr.Button | None = None


def create_toolbar(is_finetune_editor=False):
    """Build the toolbar buttons and the (initially hidden) search row.

    Args:
        is_finetune_editor: when True the finetune button shows "✎" (edit)
            instead of "+" (create).

    Returns:
        A ``ModelSelectorToolbar`` without the hidden search controls; call
        ``create_search_panel`` to add them.
    """
    with gr.Column(scale=2, min_width=210, elem_classes=["wangp-model-selector-tools"]):
        with gr.Row(elem_classes=["wangp-model-selector-tool-row"]) as tool_row:
            search_button = gr.Button("⌕", elem_id="wangp_model_tool_search", elem_classes=["wangp-model-selector-tool", "wangp-model-selector-tool-search"], size="sm", scale=0)
            finetune_button = gr.Button("✎" if is_finetune_editor else "+", elem_id="wangp_model_tool_finetune", elem_classes=["wangp-model-selector-tool", "wangp-model-selector-tool-finetune"], size="sm", scale=0)
            refresh_button = gr.Button("↻", elem_id="wangp_model_tool_refresh", elem_classes=["wangp-model-selector-tool", "wangp-model-selector-tool-refresh"], size="sm", scale=0)
            unload_button = gr.Button("⏏", elem_id="wangp_model_tool_unload", elem_classes=["wangp-model-selector-tool", "wangp-model-selector-tool-unload"], size="sm", scale=0)
        with gr.Row(visible=False, elem_classes=["wangp-model-selector-search-row"]) as search_row:
            # The search box column is position: relative in get_css, so the
            # absolutely positioned results popup anchors beneath the query box.
            with gr.Column(scale=1, min_width=0, elem_classes=["wangp-model-selector-search-box"]):
                search_query = gr.Textbox(value="", show_label=False, placeholder="Search models", elem_id="wangp_model_search_query", elem_classes=["wangp-model-selector-search-input"])
                search_results = gr.HTML(value="", visible=False, elem_id="wangp_model_search_results")
    return ModelSelectorToolbar(search_button, refresh_button, unload_button, finetune_button=finetune_button, tool_row=tool_row, search_row=search_row, search_query=search_query, search_results=search_results)


def create_search_panel(toolbar: ModelSelectorToolbar):
    """Create the hidden controls the client-side search JS drives and attach them to *toolbar*.

    Returns the same (mutated) toolbar for convenience.
    """
    with gr.Row(visible=False, elem_classes=["wangp-model-selector-hidden-controls"]):
        toolbar.search_target = gr.Textbox(value="", show_label=False, elem_id="wangp_model_search_target")
        toolbar.search_apply_button = gr.Button("Apply model search", elem_id="wangp_model_search_apply")
        toolbar.search_close_button = gr.Button("Close model search", elem_id="wangp_model_search_close")
    return toolbar


def show_search_panel():
    """Updates for (tool_row, search_row, search_query, search_results): open an empty search."""
    return gr.update(visible=False), gr.update(visible=True), gr.update(value=""), gr.update(value="", visible=False)


def clear_search_panel():
    """Updates for (tool_row, search_row, search_query, search_results): close and reset the search."""
    return gr.update(visible=True), gr.update(visible=False), gr.update(value=""), gr.update(value="", visible=False)


def _normalize_search_text(value):
    """Casefold *value* and drop every character outside ``[a-z0-9]`` so matching ignores spacing and punctuation."""
    return re.sub(r"[^a-z0-9]+", "", str(value or "").casefold())


def _display_join(*parts):
    """Join the non-blank *parts* with single spaces, collapsing internal whitespace runs."""
    cleaned = (str(part or "").strip() for part in parts)
    text = " ".join(part for part in cleaned if part)
    return re.sub(r"\s+", " ", text).strip()


def _family_model_types(deps, dropdown_types, family):
    """Return the entries of *dropdown_types* whose UI family is *family*, preserving order."""
    return [model_type for model_type in dropdown_types if deps.get_model_family(model_type, for_ui=True) == family]


def _default_model_for_family(deps, state, dropdown_types, family):
    """Pick the model type to select when a whole family is chosen.

    Prefers ``state["last_model_per_family"][family]`` if it is still offered,
    otherwise the first model by casefolded compact label, otherwise "".
    """
    family_name = deps.families_infos[family][1]
    rows = sorted(
        [(model_dropdowns.compact_name(family_name, deps.get_model_name(model_type)), model_type) for model_type in _family_model_types(deps, dropdown_types, family)],
        key=lambda row: row[0].casefold(),
    )
    values = [model_type for _label, model_type in rows]
    model_type = (state or {}).get("last_model_per_family", {}).get(family, "")
    return model_type if model_type in values else (values[0] if values else "")


def _family_hierarchy(deps, dropdown_types, family):
    """Build the parent/child hierarchy for *family* via ``model_dropdowns.create_models_hierarchy``.

    Rows are ``(compact label, model type, parent model type)`` sorted by
    casefolded label. Callers unpack the result as
    ``(parent_choices, children_by_parent)``.
    """
    family_name = deps.families_infos[family][1]
    rows = [
        (model_dropdowns.compact_name(family_name, deps.get_model_name(model_type)), model_type, deps.get_parent_model_type(model_type))
        for model_type in _family_model_types(deps, dropdown_types, family)
    ]
    rows.sort(key=lambda row: row[0].casefold())
    return model_dropdowns.create_models_hierarchy(rows)


def _default_model_for_parent(deps, state, children_by_parent, parent_model_type):
    """Pick the child model to select for a parent entry.

    Prefers ``state["last_model_per_type"][parent_model_type]`` if it is one of
    the parent's children, otherwise the first child, otherwise "".
    """
    children = children_by_parent.get(parent_model_type, [])
    values = [model_type for _label, model_type in children]
    model_type = (state or {}).get("last_model_per_type", {}).get(parent_model_type, "")
    return model_type if model_type in values else (values[0] if values else "")


def _append_result(results, seen_targets, scope, label, model_type, path):
    """Append a result dict unless *model_type* is empty or already present (dedupe by target)."""
    if not model_type or model_type in seen_targets:
        return
    seen_targets.add(model_type)
    results.append({"scope": scope, "label": label, "model_type": model_type, "path": path})


def _iter_search_matches(deps, state, needle):
    """Yield ``(scope, label, model_type, path)`` candidates matching the normalized *needle*.

    Families are visited in ``families_infos`` order. A family-name hit yields a
    single "Family" entry and skips that family's models; a parent hit yields a
    "Model" entry and skips its children; otherwise matching children yield
    "Finetune" entries. Candidates are produced lazily so the caller can stop early.
    """
    dropdown_types = model_dropdowns.get_dropdown_model_types(deps)
    family_ids = sorted(
        {deps.get_model_family(model_type, for_ui=True) for model_type in dropdown_types},
        key=lambda family: deps.families_infos.get(family, (999, family))[0],
    )
    for family in family_ids:
        if family not in deps.families_infos:
            continue
        family_name = deps.families_infos[family][1]
        if needle in _normalize_search_text(family_name):
            yield "Family", family_name, _default_model_for_family(deps, state, dropdown_types, family), family_name
            continue
        parent_choices, children_by_parent = _family_hierarchy(deps, dropdown_types, family)
        for parent_name, parent_model_type in parent_choices:
            parent_label = _display_join(family_name, parent_name)
            if needle in _normalize_search_text(parent_label) or needle in _normalize_search_text(deps.get_model_name(parent_model_type)):
                yield "Model", parent_label, _default_model_for_parent(deps, state, children_by_parent, parent_model_type), family_name
                continue
            for child_name, child_model_type in children_by_parent.get(parent_model_type, []):
                full_name = deps.get_model_name(child_model_type)
                full_name_matches = needle in _normalize_search_text(full_name)
                # Show the full model name when it is what matched; otherwise
                # give the child enough context to be recognisable.
                child_label = full_name if full_name_matches else _display_join(parent_label, child_name)
                if full_name_matches or needle in _normalize_search_text(child_label):
                    yield "Finetune", child_label, child_model_type, parent_label


def _search_results(deps, state, query):
    """Return up to ``MAX_SEARCH_RESULTS`` deduplicated result dicts for *query*.

    Each dict has keys ``scope``, ``label``, ``model_type`` and ``path``. An
    empty (after normalization) query returns ``[]``.
    """
    needle = _normalize_search_text(query)
    if not needle:
        return []
    results, seen_targets = [], set()
    for scope, label, model_type, path in _iter_search_matches(deps, state, needle):
        _append_result(results, seen_targets, scope, label, model_type, path)
        # Enforce the cap for every scope, not only for finetune hits.
        if len(results) >= MAX_SEARCH_RESULTS:
            break
    return results


def _render_search_result(result):
    """Return the HTML button for one search result; all dynamic text is HTML-escaped.

    The ``data-model-type`` attribute is what the client-side JS reads to select
    the model.
    """
    type_line = ""
    if SHOW_SEARCH_RESULT_TYPE_LINE:
        type_line = "<span class='wangp-model-search-result-meta'>{scope} · {path}</span>".format(
            scope=html.escape(result["scope"]),
            path=html.escape(result["path"]),
        )
    return (
        "<button type='button' class='wangp-model-search-result' data-model-type='{model_type}'>"
        "<span class='wangp-model-search-result-title'>{label}</span>{type_line}"
        "</button>"
    ).format(
        model_type=html.escape(result["model_type"], quote=True),
        label=html.escape(result["label"]),
        type_line=type_line,
    )


def render_search_results(deps, state, query):
    """Return a ``gr.update`` for the results popup matching *query*.

    A blank query hides the popup; no matches shows a "No matching models"
    message; otherwise one button per result is rendered inside the popup.
    """
    query = str(query or "")
    if not query.strip():
        return gr.update(value="", visible=False)
    results = _search_results(deps, state, query)
    if not results:
        return gr.update(value="<div class='wangp-model-search-popup'><div class='wangp-model-search-empty'>No matching models</div></div>", visible=True)
    items = "".join(_render_search_result(result) for result in results)
    return gr.update(value="<div class='wangp-model-search-popup'>" + items + "</div>", visible=True)


def apply_search_selection(model_type):
    """Return updates for (model_choice_target, tool_row, search_row, search_query, search_results).

    The target value is suffixed with a timestamp so re-selecting the same model
    still changes the value; an empty selection leaves the target untouched.
    The remaining four updates close the search panel.
    """
    model_type = str(model_type or "").strip()
    return (f"{model_type}|{time.time()}" if model_type else gr.update()), *clear_search_panel()


def _prune_orphan_model_settings(state, deps):
    """Remove entries of ``state["all_settings"]`` whose model type has no definition.

    Returns the number of entries removed (0 when ``all_settings`` is not a dict).
    """
    all_settings = (state or {}).get("all_settings", None)
    if not isinstance(all_settings, dict):
        return 0
    # Materialize the keys first: the dict is mutated below.
    orphan_model_types = [model_type for model_type in all_settings if deps.get_model_def(model_type) is None]
    for model_type in orphan_model_types:
        all_settings.pop(model_type, None)
    return len(orphan_model_types)


def refresh_models_with_info(refresh_model_defs, refresh_model_dropdowns, state, deps_factory):
    """Reload model definitions, prune orphaned settings and refresh the dropdowns.

    ``refresh_model_defs`` may return a list of parse error messages; the first
    one is surfaced via ``gr.Info``. If it raises, the error is reported and the
    dropdowns are still refreshed. Returns ``refresh_model_dropdowns(state)``.
    """
    try:
        parse_errors = refresh_model_defs() or []
    except Exception as e:  # UI boundary: report and keep the UI usable.
        gr.Info(f"Unable to refresh model list: {e}")
        return refresh_model_dropdowns(state)
    pruned_count = _prune_orphan_model_settings(state, deps_factory())
    prune_text = f" Removed {pruned_count} orphan model setting{'s' if pruned_count > 1 else ''}." if pruned_count > 0 else ""
    if len(parse_errors) > 0:
        gr.Info("Model list refreshed, but parsing errors were found: " + str(parse_errors[0]) + prune_text)
    else:
        gr.Info("Model list refreshed." + prune_text)
    return refresh_model_dropdowns(state)


def unload_models_from_ram(state, *, server_config, any_GPU_process_running, release_deepy_vram, reset_prompt_enhancer, reset_prompt_enhancer_if_requested, release_flashvsr_vram, release_seedvc_vram, release_model):
    """Release loaded models and enabled extensions, refusing while a GPU job is running.

    All work happens inside ``model_unload_guard()``. Extensions are released
    only when ``_unload_targets`` reports them enabled in *server_config*.
    """
    with model_unload_guard():
        targets = _unload_targets(server_config)
        unload_targets = _join_targets(targets)
        if any_GPU_process_running(state, "configuration"):
            gr.Info(f"Unable to unload {unload_targets} while GPU resources are allocated.")
            return
        if deepy_available(server_config):
            release_deepy_vram(state, clear_session_state=False, discard_runtime_snapshot=True)
        if "Prompt Enhancer" in targets:
            reset_prompt_enhancer()
        reset_prompt_enhancer_if_requested()
        if "FlashVSR" in targets:
            release_flashvsr_vram()
        if "SeedVC" in targets:
            release_seedvc_vram()
        release_model()
        gr.Info(f"{unload_targets} unloaded from RAM.")


def _config_flag(server_config, key):
    """Return True when ``server_config[key]`` parses as a positive integer.

    Missing, empty or unparseable values count as disabled instead of raising,
    so building the unload message can never abort the unload itself.
    """
    try:
        return int(server_config.get(key, 0) or 0) > 0
    except (TypeError, ValueError, OverflowError):
        return False


def _unload_targets(server_config):
    """Return the display names of everything an unload releases, "Models" first."""
    targets = ["Models"]
    if _config_flag(server_config, "enhancer_enabled"):
        targets.append("Prompt Enhancer")
    if _config_flag(server_config, "seedvc_mode"):
        targets.append("SeedVC")
    if _config_flag(server_config, "flashvsr_mode"):
        targets.append("FlashVSR")
    if deepy_available(server_config):
        targets.append("Deepy")
    return targets


def _join_targets(targets):
    """Join names as English prose: "A", "A and B", or "A, B, and C"."""
    if len(targets) == 1:
        return targets[0]
    if len(targets) > 2:
        return ", ".join(targets[:-1]) + f", and {targets[-1]}"
    return " and ".join(targets)


def _unload_targets_text(server_config):
    """Return the human-readable list of unload targets for *server_config*."""
    return _join_targets(_unload_targets(server_config))


def bind_toolbar(toolbar: ModelSelectorToolbar, *, deps_factory: Callable, state, model_family, model_base_type_choice, model_choice, model_choice_target, refresh_form_trigger, refresh_model_defs: Callable, refresh_model_dropdowns: Callable, unload_handler: Callable):
    """Wire the toolbar's Gradio events.

    Requires ``create_search_panel(toolbar)`` to have been called, since the
    hidden apply/close buttons are bound here. ``deps_factory`` is called on each
    event to obtain fresh dependencies.
    """
    search_panel_outputs = [toolbar.tool_row, toolbar.search_row, toolbar.search_query, toolbar.search_results]
    toolbar.search_button.click(
        fn=show_search_panel,
        outputs=search_panel_outputs,
        show_progress="hidden",
    ).then(fn=None, js=focus_search_javascript(), inputs=None, outputs=None)
    toolbar.search_query.input(
        fn=lambda state_value, query: render_search_results(deps_factory(), state_value, query),
        inputs=[state, toolbar.search_query],
        outputs=[toolbar.search_results],
        show_progress="hidden",
    )
    toolbar.search_apply_button.click(
        fn=apply_search_selection,
        inputs=[toolbar.search_target],
        outputs=[model_choice_target, *search_panel_outputs],
        show_progress="hidden",
    )
    toolbar.search_close_button.click(
        fn=clear_search_panel,
        outputs=search_panel_outputs,
        show_progress="hidden",
    )
    toolbar.refresh_button.click(
        fn=lambda state_value: refresh_models_with_info(refresh_model_defs, refresh_model_dropdowns, state_value, deps_factory),
        inputs=[state],
        outputs=[model_family, model_base_type_choice, model_choice, refresh_form_trigger],
        show_progress="hidden",
    )
    toolbar.unload_button.click(fn=unload_handler, inputs=[state], outputs=None, show_progress="hidden")


def focus_search_javascript():
    """Return a JS arrow function that focuses the search input 50 ms after it is shown."""
    return """
() => {
    const root = window.gradioApp ? window.gradioApp() : (document.querySelector("gradio-app")?.shadowRoot || document);
    setTimeout(() => {
        const input = root.querySelector("#wangp_model_search_query textarea, #wangp_model_search_query input");
        if (input) input.focus();
    }, 50);
}
"""


def get_javascript():
    """Return the page-level JS driving the model search and toolbar shortcuts.

    It selects a result on click, supports ArrowUp/ArrowDown/Enter/Escape in the
    query box, closes the search on blur (unless a result is being clicked),
    binds Alt+S/F/R/U to the toolbar buttons, and every 400 ms (re)binds the
    keyboard handlers and refreshes the finetune button tooltip.
    """
    return r"""
(function () {
    let searchResultPointerDown = false;
    let activeSearchIndex = -1;

    function root() {
        if (window.gradioApp) return window.gradioApp();
        const app = document.querySelector("gradio-app");
        return app ? (app.shadowRoot || app) : document;
    }

    function queryInput() {
        return root().querySelector("#wangp_model_search_query textarea, #wangp_model_search_query input");
    }

    function targetInput() {
        return root().querySelector("#wangp_model_search_target textarea, #wangp_model_search_target input");
    }

    function applyButton() {
        const el = root().querySelector("#wangp_model_search_apply");
        return el?.matches("button") ? el : el?.querySelector("button");
    }

    function closeButton() {
        const el = root().querySelector("#wangp_model_search_close");
        return el?.matches("button") ? el : el?.querySelector("button");
    }

    function toolButton(id) {
        const el = root().querySelector(id);
        return el?.matches("button") ? el : el?.querySelector("button");
    }

    function updateFinetuneTooltip() {
        const button = toolButton("#wangp_model_tool_finetune");
        if (!button) return;
        const text = (button.textContent || "").trim();
        button.dataset.wangpTooltip = text.includes("✎") ? "Edit finetune [Alt+F]" : "Create finetune [Alt+F]";
    }

    function resultItems() {
        return Array.from(root().querySelectorAll("#wangp_model_search_results [data-model-type]"));
    }

    function setActive(index) {
        const items = resultItems();
        if (!items.length) return;
        const bounded = Math.max(0, Math.min(index, items.length - 1));
        activeSearchIndex = bounded;
        items.forEach((item, itemIndex) => item.classList.toggle("wangp-model-search-result-active", itemIndex === bounded));
        items[bounded].scrollIntoView({ block: "nearest" });
    }

    function activeIndex() {
        const items = resultItems();
        if (activeSearchIndex >= 0 && activeSearchIndex < items.length) return activeSearchIndex;
        const index = items.findIndex((item) => item.classList.contains("wangp-model-search-result-active"));
        return index;
    }

    function selectModel(modelType) {
        const target = targetInput();
        const button = applyButton();
        if (!target || !button || !modelType) return;
        target.value = modelType;
        target.dispatchEvent(new Event("input", { bubbles: true }));
        button.click();
    }

    function eventSearchItem(event) {
        const path = event.composedPath ? event.composedPath() : [];
        for (const node of path) {
            if (node?.matches?.("#wangp_model_search_results [data-model-type]")) return node;
            const closest = node?.closest?.("#wangp_model_search_results [data-model-type]");
            if (closest) return closest;
        }
        return event.target?.closest?.("#wangp_model_search_results [data-model-type]");
    }

    document.addEventListener("click", (event) => {
        const item = eventSearchItem(event);
        if (!item) return;
        event.preventDefault();
        selectModel(item.dataset.modelType || "");
    });

    document.addEventListener("pointerdown", (event) => {
        if (!eventSearchItem(event)) return;
        searchResultPointerDown = true;
        setTimeout(() => searchResultPointerDown = false, 250);
    });

    function bindSearchKeyboard() {
        const input = queryInput();
        if (!input || input.dataset.wangpModelSearchBound === "1") return;
        input.dataset.wangpModelSearchBound = "1";
        input.addEventListener("input", () => activeSearchIndex = -1);
        input.addEventListener("keydown", (event) => {
            if (event.key === "Escape") {
                event.preventDefault();
                closeButton()?.click();
                return;
            }
            const items = resultItems();
            if (!items.length) return;
            if (event.key === "ArrowDown" || event.key === "Down") {
                event.preventDefault();
                const index = activeIndex();
                setActive(index < 0 ? 0 : index + 1);
            } else if (event.key === "ArrowUp" || event.key === "Up") {
                event.preventDefault();
                const index = activeIndex();
                setActive(index < 0 ? items.length - 1 : index - 1);
            } else if (event.key === "Enter") {
                event.preventDefault();
                const index = activeIndex();
                const item = items[index >= 0 ? index : 0] || items[0];
                selectModel(item?.dataset?.modelType || "");
            }
        });
        input.addEventListener("blur", () => {
            setTimeout(() => {
                if (searchResultPointerDown) return;
                const active = root().activeElement || document.activeElement;
                if (active === queryInput()) return;
                closeButton()?.click();
            }, 120);
        });
    }

    document.addEventListener("keydown", (event) => {
        if (!event.altKey || event.ctrlKey || event.metaKey || event.shiftKey || event.repeat) return;
        const key = event.key.toLowerCase();
        const target = key === "s" ? "#wangp_model_tool_search" : key === "r" ? "#wangp_model_tool_refresh" : key === "u" ? "#wangp_model_tool_unload" : "";
        if (key === "f") {
            event.preventDefault();
            toolButton("#wangp_model_tool_finetune")?.click();
            return;
        }
        if (!target) return;
        event.preventDefault();
        toolButton(target)?.click();
    });

    setInterval(() => {
        bindSearchKeyboard();
        updateFinetuneTooltip();
    }, 400);
})();
"""


def get_css():
    """Return the CSS for the toolbar, tooltips, search box and results popup.

    The border between results depends on ``SHOW_SEARCH_RESULT_TYPE_LINE``.
    """
    return """
.wangp-model-selector-tools {
    position: relative;
    align-self: center;
    background: transparent !important;
    border: 0 !important;
    box-shadow: none !important;
    padding: 0 !important;
}
.wangp-model-selector-tools::before {
    content: "";
    position: absolute;
    left: 0;
    right: 0;
    top: calc(50% - 0.5px);
    height: 1px;
    background: #333;
    opacity: 1;
}
.wangp-model-selector-tool-row {
    position: relative;
    z-index: 2;
    width: max-content;
    margin: 0 auto !important;
    gap: 6px !important;
    padding: 0 4px;
    background: var(--body-background-fill);
}
.wangp-model-selector-tool {
    position: relative;
    display: flex !important;
    align-items: center !important;
    justify-content: center !important;
    min-width: 42px !important;
    width: 42px !important;
    max-width: 42px !important;
    height: 32px !important;
    padding: 0 !important;
    font-size: 0 !important;
    line-height: 1 !important;
    border-radius: 6px !important;
    color: var(--button-secondary-text-color, var(--body-text-color)) !important;
}
.wangp-model-selector-tool::before {
    position: absolute;
    inset: 0;
    display: flex;
    align-items: center;
    justify-content: center;
    color: var(--button-secondary-text-color, var(--body-text-color));
    font-family: "Segoe UI Symbol", "Arial Unicode MS", sans-serif;
    font-size: 32px;
    line-height: 1;
}
.wangp-model-selector-tool-search::before { content: "⌕"; font-size: 22px; transform: translateY(-1px); }
.wangp-model-selector-tool-refresh::before { content: "↻"; font-size: 23px; transform: translateY(-1px); }
.wangp-model-selector-tool-unload::before { content: "⏏"; font-size: 22px; transform: translateY(-1px); }
.wangp-model-selector-tool-finetune {
    font-size: 24px !important;
    font-family: "Segoe UI Symbol", "Arial Unicode MS", sans-serif !important;
    transform: translateY(-1px);
}
.wangp-model-selector-tool-finetune::before { content: ""; }
.wangp-model-selector-tool::after {
    position: absolute;
    left: 50%;
    bottom: calc(100% + 8px);
    transform: translateX(-50%);
    white-space: nowrap;
    padding: 4px 7px;
    border-radius: 4px;
    background: var(--body-text-color);
    color: var(--body-background-fill);
    font-size: 12px;
    line-height: 1.2;
    opacity: 0;
    pointer-events: none;
    transition: opacity 120ms ease;
    transition-delay: 0s;
    z-index: 50;
}
.wangp-model-selector-tool:hover::after { opacity: 1; transition-delay: 500ms; }
.wangp-model-selector-tool-search::after { content: "Search models [Alt+S]"; }
.wangp-model-selector-tool-refresh::after { content: "Refresh model list [Alt+R]"; }
.wangp-model-selector-tool-finetune::after { content: attr(data-wangp-tooltip); }
.wangp-model-selector-tool-unload::after { content: "Unload models and extensions [Alt+U]"; }
.wangp-model-selector-search-row {
    --wangp-model-selector-gap: calc(16px * var(--wangp-ui-scale, 0.9));
    --wangp-model-selector-search-margin: 8px;
    position: relative;
    z-index: 3;
    box-sizing: border-box;
    width: calc(100% + var(--wangp-model-selector-gap) - var(--wangp-model-selector-search-margin));
    margin: 0 0 0 calc(var(--wangp-model-selector-search-margin) - var(--wangp-model-selector-gap)) !important;
    padding: 0;
    background: var(--body-background-fill);
}
.wangp-model-selector-search-box { position: relative; width: 100%; min-width: 0 !important; }
.wangp-model-selector-search-input textarea,
.wangp-model-selector-search-input input {
    text-align: left;
    min-height: 27px !important;
    height: 27px !important;
    padding: 2px 7px !important;
    font-size: 13px !important;
    border: 0 !important;
    box-shadow: none !important;
}
.wangp-model-selector-search-input {
    border: 1px solid var(--border-color-primary) !important;
    border-radius: 5px !important;
    background: var(--input-background-fill, var(--body-background-fill)) !important;
    box-sizing: border-box;
    width: 100%;
    padding: 1px !important;
    margin: 0 !important;
}
.wangp-model-selector-search-input label,
.wangp-model-selector-search-input > div { margin: 0 !important; padding: 0 !important; }
#wangp_model_search_results {
    position: absolute !important;
    left: auto;
    right: 0;
    top: calc(100% - 1px);
    width: min(560px, 90vw);
    z-index: 1000;
}
.wangp-model-search-popup {
    margin-top: 0;
    border: 1px solid var(--border-color-primary);
    border-radius: 0 0 6px 6px;
    background: var(--body-background-fill);
    overflow: hidden;
    max-height: 280px;
    overflow-y: auto;
}
.wangp-model-search-result {
    width: 100%;
    border: 0;
    border-bottom: """ + ("1px solid var(--border-color-primary)" if SHOW_SEARCH_RESULT_TYPE_LINE else "0") + """;
    background: transparent;
    color: var(--body-text-color);
    display: block;
    text-align: left;
    padding: 8px 10px;
    cursor: pointer;
}
.wangp-model-search-result:last-child { border-bottom: 0; }
.wangp-model-search-result:hover,
.wangp-model-search-result-active {
    background: var(--background-fill-secondary);
    background: color-mix(in srgb, var(--button-primary-background-fill) 18%, var(--body-background-fill));
    box-shadow: inset 3px 0 0 var(--button-primary-background-fill);
    color: var(--body-text-color);
}
.wangp-model-search-result:hover .wangp-model-search-result-title,
.wangp-model-search-result-active .wangp-model-search-result-title { font-weight: 700; }
.wangp-model-search-result-title,
.wangp-model-search-result-meta { display: block; overflow: hidden; text-overflow: ellipsis; white-space: nowrap; }
.wangp-model-search-result-title { font-weight: 600; font-size: 13px; }
.wangp-model-search-result-meta { opacity: 0.75; font-size: 11px; margin-top: 2px; }
.wangp-model-search-empty { padding: 8px 10px; font-size: 12px; opacity: 0.75; }
.wangp-model-selector-hidden-controls { display: none !important; }
"""