# Wan2GP / shared / model_dropdowns.py
# Uploaded by Egnalkram using huggingface_hub (commit 4689c2b, verified)
import os
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass

import gradio as gr
# Download status of a model's files. The values are ordered so that max()
# picks the most complete status when statuses are combined.
MODEL_FILE_STATUS_MISSING = 0   # nothing found locally
MODEL_FILE_STATUS_PARTIAL = 1   # some, but not all, files found locally
MODEL_FILE_STATUS_EXPECTED = 2  # every expected file found locally

# Glyph prepended to a dropdown label for each status.
MODEL_STATUS_PREFIXES = {
    MODEL_FILE_STATUS_MISSING: "\u2b1b",       # BLACK LARGE SQUARE
    MODEL_FILE_STATUS_EXPECTED: "\U0001f7e6",  # LARGE BLUE SQUARE
    MODEL_FILE_STATUS_PARTIAL: "\U0001f7e8",   # LARGE YELLOW SQUARE
}
@dataclass
class DropdownDeps:
    """Application state and callbacks used by the model-dropdown helpers.

    Every model lookup, filename resolution and configuration value that the
    helpers in this module need is read from an instance of this class.
    """

    # Model types explicitly offered in the selector; when empty,
    # ``displayed_model_types`` is used instead.
    transformer_types: list
    # Fallback list of model types to offer.
    displayed_model_types: list
    # Currently selected model type; it is always added to the offered types.
    transformer_type: str
    # Whether the base-model selector is shown between the family and finetune selectors.
    three_levels_hierarchy: bool
    # family -> (sort order key, display label).
    families_infos: dict
    # Server settings; its "transformer_quantization" / "transformer_dtype_policy"
    # entries override the two attributes below for download-status checks.
    server_config: dict
    transformer_quantization: str
    transformer_dtype_policy: str
    # Quantization used when resolving text-encoder filenames.
    text_encoder_quantization: str
    # model_type -> model definition (None when the model type is unknown).
    get_model_def: Callable
    # (model_type, prop, return_list=..., sub_prop_name=...) -> property value.
    get_model_recursive_prop: Callable
    # Resolves the file (or URL) of a model / submodel / module for a quantization.
    get_model_filename: Callable
    # (filename, extra_paths=None) -> None when the file cannot be found locally.
    get_local_model_filename: Callable
    # model_type -> directory holding that model's LoRA files.
    get_lora_dir: Callable
    # model_type -> parent model type (grouping key of the finetune selector).
    get_parent_model_type: Callable
    # model_type -> base model type.
    get_base_model_type: Callable
    # (model_type, for_ui=...) -> family key used in ``families_infos``.
    get_model_family: Callable
    # model_type -> human-readable model name.
    get_model_name: Callable
    # (model_type, dtype_policy) -> transformer dtype.
    get_transformer_dtype: Callable
def compact_name(family_name, model_name):
    """Drop a leading *family_name* from *model_name*.

    If *model_name* starts with *family_name*, the remainder is returned with
    surrounding whitespace stripped. Otherwise *model_name* is returned
    untouched. The test is a plain string-prefix match, not a word match.
    """
    if not model_name.startswith(family_name):
        return model_name
    remainder = model_name[len(family_name):]
    return remainder.strip()
def decorate_model_dropdown_label(label, status):
    """Prefix a dropdown *label* with the glyph registered for *status*.

    Non-string labels, and statuses without a (non-empty) glyph, are returned
    unchanged.
    """
    if isinstance(label, str):
        prefix = MODEL_STATUS_PREFIXES.get(status, "")
        if prefix:
            return prefix + " " + label
    return label
def decorate_dropdown_choices_with_status(choices, status_map):
    """Return a new list of *choices* with status glyphs prepended to labels.

    A tuple choice with at least two items is read as ``(label, value, *extra)``.
    Its label is decorated with ``status_map[value]``, and values absent from
    the map count as MODEL_FILE_STATUS_MISSING. Any other choice is copied
    through as-is.
    """
    def _decorate(choice):
        if isinstance(choice, tuple) and len(choice) >= 2:
            label, value, *extra = choice
            status = status_map.get(value, MODEL_FILE_STATUS_MISSING)
            return (decorate_model_dropdown_label(label, status), value, *extra)
        return choice

    return [_decorate(choice) for choice in choices]
def get_dropdown_model_types(deps):
    """Return the de-duplicated model types to offer in the dropdowns.

    Uses ``deps.transformer_types`` when it is non-empty, otherwise
    ``deps.displayed_model_types``. ``deps.transformer_type`` is added at the
    end if it is not already present. First-occurrence order is preserved.
    """
    source = deps.transformer_types if len(deps.transformer_types) > 0 else deps.displayed_model_types
    candidates = [*source, deps.transformer_type]
    return list(dict.fromkeys(candidates))
def get_family_dropdown_model_types(deps, current_model_family, dropdown_types=None):
    """Restrict *dropdown_types* to the models of *current_model_family*.

    When *dropdown_types* is None, get_dropdown_model_types(deps) is used.
    When *current_model_family* is None, the same list object is returned
    unfiltered.
    """
    if dropdown_types is None:
        dropdown_types = get_dropdown_model_types(deps)
    if current_model_family is None:
        return dropdown_types

    def _in_family(model_type):
        return deps.get_model_family(model_type, for_ui=True) == current_model_family

    return list(filter(_in_family, dropdown_types))
def _get_module_files_for_status(deps, model_type, quantization, dtype_policy):
transformer_dtype = deps.get_transformer_dtype(model_type, dtype_policy)
modules = deps.get_model_recursive_prop(model_type, "modules", return_list=True)
modules = [deps.get_model_recursive_prop(module, "modules", sub_prop_name="_list", return_list=True) if isinstance(module, str) else module for module in modules]
module_files = []
for module_type in modules:
if isinstance(module_type, dict):
URLs1 = module_type.get("URLs", None)
if URLs1 is None:
return None
module_files.append(deps.get_model_filename(model_type, quantization, transformer_dtype, URLs=URLs1))
URLs2 = module_type.get("URLs2", None)
if URLs2 is None:
return None
module_files.append(deps.get_model_filename(model_type, quantization, transformer_dtype, URLs=URLs2))
else:
module_files.append(deps.get_model_filename(model_type, quantization, transformer_dtype, module_type=module_type))
return module_files
def _get_status_quantization_and_dtype(deps):
quantization = deps.server_config.get("transformer_quantization", deps.transformer_quantization)
dtype_policy = deps.server_config.get("transformer_dtype_policy", deps.transformer_dtype_policy)
return quantization, dtype_policy
def _append_expected_file_entry(entries, seen, filename, extra_paths=None):
if not isinstance(filename, str) or len(filename) == 0:
return
if extra_paths is None:
extra_list = []
elif isinstance(extra_paths, list):
extra_list = [path for path in extra_paths if isinstance(path, str) and len(path) > 0]
else:
extra_list = [extra_paths] if isinstance(extra_paths, str) and len(extra_paths) > 0 else []
key = (filename.casefold(), tuple(path.casefold() for path in extra_list))
if key in seen:
return
seen.add(key)
entries.append({"filename": filename, "extra_paths": extra_list if len(extra_list) > 0 else None})
def _append_expected_local_path_entry(entries, seen, local_path):
if not isinstance(local_path, str) or len(local_path) == 0:
return
path_key = local_path.casefold()
if path_key in seen:
return
seen.add(path_key)
entries.append({"path": local_path})
def get_expected_core_file_entries_for_status(deps, model_type):
    """List the core files *model_type* needs, as filename entries.

    Core files are: the main model file, the second submodel file (when the
    definition has ``"URLs2"``), the module files, and the text-encoder file
    (looked up with the definition's ``"text_encoder_folder"`` as extra path).
    Returns an empty list for an unknown model type.
    """
    model_def = deps.get_model_def(model_type)
    if model_def is None:
        return []
    quantization, dtype_policy = _get_status_quantization_and_dtype(deps)
    entries, seen = [], set()

    def _add(filename, extra_paths=None):
        _append_expected_file_entry(entries, seen, filename, extra_paths=extra_paths)

    _add(deps.get_model_filename(model_type, quantization=quantization, dtype_policy=dtype_policy))
    if isinstance(model_def, dict) and "URLs2" in model_def:
        _add(deps.get_model_filename(model_type, quantization=quantization, dtype_policy=dtype_policy, submodel_no=2))

    module_files = _get_module_files_for_status(deps, model_type, quantization, dtype_policy)
    if isinstance(module_files, list):
        for module_file in module_files:
            _add(module_file)

    text_encoder_urls = deps.get_model_recursive_prop(model_type, "text_encoder_URLs", return_list=True)
    if text_encoder_urls is not None:
        _add(
            deps.get_model_filename(model_type=model_type, quantization=deps.text_encoder_quantization, dtype_policy=dtype_policy, URLs=text_encoder_urls),
            model_def.get("text_encoder_folder", None),
        )
    return entries
def get_missing_core_file_entries_for_status(deps, model_type):
    """Return the core file entries of *model_type* not found locally."""
    def _is_missing(entry):
        local = deps.get_local_model_filename(entry.get("filename", ""), extra_paths=entry.get("extra_paths", None))
        return local is None

    return [entry for entry in get_expected_core_file_entries_for_status(deps, model_type) if _is_missing(entry)]
def get_expected_secondary_file_entries_for_status(deps, model_type):
    """List the secondary files *model_type* needs.

    Preload URLs and VAE URLs become filename entries. Each LoRA URL becomes a
    path entry ``<lora dir>/<basename>``. Returns an empty list for an unknown
    model type.
    """
    model_def = deps.get_model_def(model_type)
    if model_def is None:
        return []

    def _as_list(value):
        # None -> [], scalar -> [scalar], list unchanged.
        if value is None:
            return []
        return value if isinstance(value, list) else [value]

    entries, seen = [], set()
    for url in _as_list(deps.get_model_recursive_prop(model_type, "preload_URLs", return_list=True)):
        if isinstance(url, str) and url:
            _append_expected_file_entry(entries, seen, url)
    for url in _as_list(model_def.get("VAE_URLs", [])):
        if isinstance(url, str) and url:
            _append_expected_file_entry(entries, seen, url)

    lora_urls = _as_list(deps.get_model_recursive_prop(model_type, "loras", return_list=True))
    lora_dir = deps.get_lora_dir(model_type)
    for url in lora_urls:
        if not (isinstance(url, str) and url):
            continue
        basename = os.path.basename(url)
        if basename:
            _append_expected_local_path_entry(entries, seen, os.path.join(lora_dir, basename))
    return entries
def has_secondary_model_files_for_status(deps, model_type, quantization, dtype_policy):
    """Return True when every non-main file of *model_type* is available locally.

    Checked in order: text encoder, preload URLs, VAE URLs, LoRAs (in the
    model's LoRA directory) and module files. Returns False at the first
    missing file, or when the module list cannot be resolved. An unknown model
    type counts as complete.
    """
    model_def = deps.get_model_def(model_type)
    if model_def is None:
        return True

    text_encoder_urls = deps.get_model_recursive_prop(model_type, "text_encoder_URLs", return_list=True)
    if text_encoder_urls is not None:
        text_encoder_file = deps.get_model_filename(model_type=model_type, quantization=deps.text_encoder_quantization, dtype_policy=dtype_policy, URLs=text_encoder_urls)
        if isinstance(text_encoder_file, str) and text_encoder_file:
            if deps.get_local_model_filename(text_encoder_file, extra_paths=model_def.get("text_encoder_folder", None)) is None:
                return False

    def _all_local(urls):
        # None means "nothing required"; a scalar is treated as a one-item list.
        if urls is None:
            return True
        for url in urls if isinstance(urls, list) else [urls]:
            if isinstance(url, str) and url and deps.get_local_model_filename(url) is None:
                return False
        return True

    if not _all_local(deps.get_model_recursive_prop(model_type, "preload_URLs", return_list=True)):
        return False
    if not _all_local(model_def.get("VAE_URLs", [])):
        return False

    lora_urls = deps.get_model_recursive_prop(model_type, "loras", return_list=True)
    if lora_urls is None:
        lora_urls = []
    elif not isinstance(lora_urls, list):
        lora_urls = [lora_urls]
    lora_dir = deps.get_lora_dir(model_type)
    for url in lora_urls:
        if isinstance(url, str) and url and not os.path.isfile(os.path.join(lora_dir, os.path.basename(url))):
            return False

    module_files = _get_module_files_for_status(deps, model_type, quantization, dtype_policy)
    if module_files is None:
        return False
    return all(
        deps.get_local_model_filename(module_file) is not None
        for module_file in module_files
        if isinstance(module_file, str) and module_file
    )
def get_model_download_status(deps, model_type):
    """Classify how much of *model_type* is present locally.

    Returns:
        MODEL_FILE_STATUS_EXPECTED when every expected main file and every
        secondary file is present. MODEL_FILE_STATUS_PARTIAL when only some of
        them are, or when none of the expected main files is present but
        another entry of the model's ``URLs`` / ``URLs2`` lists is.
        MODEL_FILE_STATUS_MISSING otherwise.
    """
    quantization, dtype_policy = _get_status_quantization_and_dtype(deps)
    model_def = deps.get_model_def(model_type)

    def _is_named(value):
        return isinstance(value, str) and len(value) > 0

    expected_filenames = [deps.get_model_filename(model_type, quantization=quantization, dtype_policy=dtype_policy)]
    if isinstance(model_def, dict) and "URLs2" in model_def:
        expected_filenames.append(deps.get_model_filename(model_type, quantization=quantization, dtype_policy=dtype_policy, submodel_no=2))
    expected_filenames = [name for name in expected_filenames if _is_named(name)]

    present = [deps.get_local_model_filename(name) is not None for name in expected_filenames]
    if present and all(present):
        if has_secondary_model_files_for_status(deps, model_type, quantization, dtype_policy):
            return MODEL_FILE_STATUS_EXPECTED
        return MODEL_FILE_STATUS_PARTIAL
    if any(present):
        return MODEL_FILE_STATUS_PARTIAL

    # None of the expected files exist. Any other listed variant found locally
    # (presumably another quantization / dtype of the same weights) still
    # counts as partial.
    candidates = []
    for prop in ("URLs", "URLs2"):
        urls = deps.get_model_recursive_prop(model_type, prop, return_list=True)
        if isinstance(urls, list):
            candidates.extend(urls)
        elif urls:
            candidates.append(urls)

    # Seeding with the expected names skips them; later entries dedupe repeats.
    skip = {name.casefold() for name in expected_filenames}
    for candidate in candidates:
        if not _is_named(candidate):
            continue
        key = candidate.casefold()
        if key in skip:
            continue
        skip.add(key)
        if deps.get_local_model_filename(candidate) is not None:
            return MODEL_FILE_STATUS_PARTIAL
    return MODEL_FILE_STATUS_MISSING
def get_model_download_status_maps(deps, dropdown_types=None):
    """Compute per-model and per-parent download statuses.

    Args:
        deps: DropdownDeps-like object.
        dropdown_types: model types to inspect; defaults to
            get_dropdown_model_types(deps). Types without a definition are skipped.

    Returns:
        ``(direct, aggregated)``. ``direct`` maps each model type to its own
        status. ``aggregated`` is a copy where each parent model type is raised
        to the best status among itself and its children.
    """
    if dropdown_types is None:
        dropdown_types = get_dropdown_model_types(deps)

    direct = {}
    children_by_parent = defaultdict(list)
    for model_type in dropdown_types:
        if deps.get_model_def(model_type) is None:
            continue
        direct[model_type] = get_model_download_status(deps, model_type)
        parent = deps.get_parent_model_type(model_type)
        if parent is not None:
            children_by_parent[parent].append(model_type)

    aggregated = dict(direct)
    for parent, children in children_by_parent.items():
        # Status constants are ordered MISSING < PARTIAL < EXPECTED, so max() picks the best.
        best_child = max(direct.get(child, MODEL_FILE_STATUS_MISSING) for child in children)
        aggregated[parent] = max(aggregated.get(parent, MODEL_FILE_STATUS_MISSING), best_child)
    return direct, aggregated
def get_model_download_status_map(deps, dropdown_types=None):
    """Return only the parent-aggregated status map of get_model_download_status_maps."""
    _direct, aggregated = get_model_download_status_maps(deps, dropdown_types)
    return aggregated
def create_models_hierarchy(rows):
    """
    Group model rows under their parent model and build compact display labels.

    rows: list of (model_name, model_id, parent_model_id)
    returns:
    parents_list: list[(parent_header, parent_id)]
    children_dict: dict[parent_id] -> list[(child_display_name, child_id)]

    A row whose model_id equals its parent_model_id names the parent; the first
    such row wins. For those parents, the header is built from the parent
    name's words that its children share, and each child is labelled with the
    words left over after that shared prefix. The parent's own entry comes first
    and is usually shown as "Default". Parent ids that never appear as their own
    row use their first child's name as header, and their children keep their
    full names. parents_list is sorted by header string.

    NOTE(review): raises IndexError (p_low[-1]) if a parent name has no
    whitespace-separated words. compact_name() returns "" when a model name
    equals its family label, so callers passing compacted names can hit this.
    Consider guarding.
    """
    toks = lambda s: [t for t in s.split() if t]
    norm = lambda s: " ".join(s.split()).casefold()
    # groups: parent id -> [(name, id)] in input order; parents: parent id -> its own row's name.
    groups, parents, order = defaultdict(list), {}, []
    for name, mid, pmid in rows:
        groups[pmid].append((name, mid))
        if mid == pmid and pmid not in parents:
            parents[pmid] = name
            order.append(pmid)
    parents_list, children_dict = [], {}
    for pid in order:
        p_name = parents[pid]
        p_tok = toks(p_name)
        p_low = [w.casefold() for w in p_tok]
        n = len(p_low)
        p_last = p_low[-1]
        p_set = set(p_low)
        # kids: (name, id, original tokens, casefolded tokens, casefolded token set); includes the parent row itself.
        kids = []
        for name, mid in groups.get(pid, []):
            ot = toks(name)
            lt = [w.casefold() for w in ot]
            st = set(lt)
            kids.append((name, mid, ot, lt, st))
        # Outliers share no word with the parent name; they keep their full name as label.
        outliers = {mid for _, mid, _, _, st in kids if mid != pid and p_set.isdisjoint(st)}
        # Rows considered for the shared prefix: the parent itself, and non-outliers whose first word matches the parent's.
        prefix_non = []
        for name, mid, ot, lt, st in kids:
            if mid == pid or (mid not in outliers and lt and lt[0] == p_low[0]):
                prefix_non.append((ot, lt))
        def lcp_len(a, b):
            # Length of the longest common prefix of two token lists.
            i = 0
            m = min(len(a), len(b))
            while i < m and a[i] == b[i]:
                i += 1
            return i
        # L: number of leading parent words used as the header. If only the parent qualifies (or nothing is shared), use the whole name.
        L = n if len(prefix_non) <= 1 else min(lcp_len(lt, p_low) for _, lt in prefix_non)
        if L == 0 and len(prefix_non) > 1:
            L = n
        # If some non-outlier child also ends with the parent's last word, that word is appended to the header.
        shares_last = any(mid != pid and mid not in outliers and lt and lt[-1] == p_last for _, mid, _, lt, _ in kids)
        header_tokens_disp = p_tok[:L] + ([p_tok[-1]] if shares_last and L < n else [])
        header = " ".join(header_tokens_disp)
        header_has_last = (L == n) or (shares_last and L < n)
        prefix_low = p_low[:L]
        def startswith_prefix(lt):
            # True when lt begins with the L header words (case-insensitive).
            if L == 0 or len(lt) < L:
                return False
            for i in range(L):
                if lt[i] != prefix_low[i]:
                    return False
            return True
        def base_rem(ot, lt):
            # Words left after removing the header prefix (all words if the prefix does not match).
            return ot[L:] if startswith_prefix(lt) else ot[:]
        def trim_rem(rem, lt):
            # Also drop a trailing repeat of the parent's last word when the header already shows it.
            out = rem[:]
            if header_has_last and lt and lt[-1] == p_last and out and out[-1].casefold() == p_last:
                out = out[:-1]
            return out
        kid_infos = []
        for name, mid, ot, lt, _ in kids:
            rem_core = base_rem(ot, lt) if mid not in outliers else ot[:]
            kid_infos.append({
                "name": name,
                "mid": mid,
                "ot": ot,
                "lt": lt,
                "outlier": mid in outliers,
                "rem_core": rem_core,
                "rem_trim": trim_rem(rem_core, lt) if mid not in outliers else ot[:],
                "rem_set": {w.casefold() for w in rem_core} if mid not in outliers else set(),
                # NOTE(review): "rem_trim_set" is not read anywhere in this function.
                "rem_trim_set": {w.casefold() for w in (trim_rem(rem_core, lt) if mid not in outliers else ot[:])} if mid not in outliers else set(),
            })
        default_info = next(info for info in kid_infos if info["mid"] == pid)
        other_words = set()
        for info in kid_infos:
            if info["mid"] != pid:
                other_words |= info["rem_set"]
        # The parent's own entry shows its remaining words only when they overlap another child's; otherwise "Default".
        default_shares = bool(default_info["rem_set"] & other_words)
        def disp(info):
            # Display label for one row: outliers keep their full name; empty remainders become "Default".
            if info["outlier"]:
                return info["name"]
            if info["mid"] == pid:
                if not default_shares:
                    return "Default"
                rem = info["rem_trim"]
            else:
                rem = info["rem_trim"]
            s = " ".join(rem).strip()
            return s if s else "Default"
        # Parent entry first, then the other children in input order.
        entries = [(disp(default_info), pid)]
        for info in kid_infos:
            if info["mid"] == pid:
                continue
            entries.append((disp(info), info["mid"]))
        p_full = norm(p_name)
        full_by_mid = {mid: name for name, mid, *_ in kids}
        # Children labelled "Default" whose normalised full name equals the parent's become "Default #2", "#3", ...
        num = 2
        numbered = [entries[0]]
        for dname, mid in entries[1:]:
            if dname == "Default" and norm(full_by_mid[mid]) == p_full:
                numbered.append((f"Default #{num}", mid))
                num += 1
            else:
                numbered.append((dname, mid))
        parents_list.append((header, pid))
        children_dict[pid] = numbered
    # Parent ids that never appeared as their own row: first child's name as header, full names for children.
    for pid in groups.keys():
        if pid in parents:
            continue
        first_name = groups[pid][0][0]
        parents_list.append((first_name, pid))
        children_dict[pid] = [(name, mid) for name, mid in groups[pid]]
    # Plain (case-sensitive) string ordering of headers.
    parents_list = sorted(parents_list, key=lambda c: c[0])
    return parents_list, children_dict
def get_sorted_dropdown(deps, dropdown_types, current_model_family, current_model_type, three_levels=True):
    """Build the sorted family, model and finetune choices for the model selectors.

    Args:
        deps: DropdownDeps-like object providing the model lookups.
        dropdown_types: model types to offer.
        current_model_family: family to restrict the model choices to, in which
            case names are shortened with compact_name. ``None`` keeps every
            family and the full model names.
        current_model_type: the selected model type.
        three_levels: build the family / base-model / finetune hierarchy.

    Returns:
        ``(families, models, finetunes)``. ``families`` is a list of
        ``(family_label, family)`` ordered by ``families_infos[family][0]``.
        With ``three_levels``, ``models`` is the ``(header, parent_model_type)``
        list from create_models_hierarchy, and ``finetunes`` holds the children
        of the current model's parent. Otherwise ``models`` is the list of
        distinct base model types and ``finetunes`` the ``(name, model_type)``
        choices.
    """
    models_families = [deps.get_model_family(t, for_ui=True) for t in dropdown_types]
    # Distinct families in first-seen order; the stable sort keeps that order on ties.
    families = list(dict.fromkeys(models_families))
    sorted_families = [
        (deps.families_infos[family][1], family)
        for family in sorted(families, key=lambda family: deps.families_infos[family][0])
    ]
    if current_model_family is None:
        dropdown_choices = [(deps.families_infos[family][0], deps.get_model_name(model_type), model_type) for model_type, family in zip(dropdown_types, models_families)]
    else:
        dropdown_choices = [(deps.families_infos[family][0], compact_name(deps.families_infos[family][1], deps.get_model_name(model_type)), model_type) for model_type, family in zip(dropdown_types, models_families) if family == current_model_family]
    # Sort by family order key, then by displayed name.
    dropdown_choices = sorted(dropdown_choices, key=lambda c: (c[0], c[1]))
    if three_levels:
        dropdown_choices = [(*model[1:], deps.get_parent_model_type(model[2])) for model in dropdown_choices]
        sorted_choices, finetunes_dict = create_models_hierarchy(dropdown_choices)
        return sorted_families, sorted_choices, finetunes_dict[deps.get_parent_model_type(current_model_type)]
    # dict.fromkeys drops duplicates but keeps the sorted order, so the returned
    # list is stable across runs (a set's string order depends on hash seeding).
    dropdown_types_list = list(dict.fromkeys(deps.get_base_model_type(model[2]) for model in dropdown_choices))
    dropdown_choices = [model[1:] for model in dropdown_choices]
    return sorted_families, dropdown_types_list, dropdown_choices
def generate_dropdown_model_list(deps, current_model_type):
    """Create the family, base-model and finetune ``gr.Dropdown`` components.

    Model choices are limited to the family of *current_model_type*. Labels
    carry download-status glyphs: base models use the parent-aggregated
    status, finetunes their own status.

    Returns:
        ``(dropdown_families, dropdown_models, dropdown_finetunes)``.
    """
    use_transformer_types = len(deps.transformer_types) > 0
    dropdown_types = list(deps.transformer_types if use_transformer_types else deps.displayed_model_types)
    if current_model_type not in dropdown_types:
        dropdown_types.append(current_model_type)

    three_levels = deps.three_levels_hierarchy
    current_model_family = deps.get_model_family(current_model_type, for_ui=True)
    family_choices, model_choices, finetune_choices = get_sorted_dropdown(
        deps, dropdown_types, current_model_family, current_model_type, three_levels=three_levels
    )

    status_types = get_family_dropdown_model_types(deps, current_model_family, dropdown_types)
    if current_model_type not in status_types:
        status_types.append(current_model_type)
    direct_map, parent_map = get_model_download_status_maps(deps, status_types)
    model_choices = decorate_dropdown_choices_with_status(model_choices, parent_map)
    finetune_choices = decorate_dropdown_choices_with_status(finetune_choices, direct_map)
    several_finetunes = len(finetune_choices) > 1

    dropdown_families = gr.Dropdown(choices=family_choices, value=current_model_family, show_label=False, scale=2 if three_levels else 1, elem_id="family_list", min_width=50)
    if three_levels:
        selected_model = deps.get_parent_model_type(current_model_type)
    else:
        selected_model = deps.get_base_model_type(current_model_type)
    dropdown_models = gr.Dropdown(choices=model_choices, value=selected_model, show_label=False, scale=3 if several_finetunes else 7, elem_id="model_base_types_list", visible=three_levels)
    dropdown_finetunes = gr.Dropdown(choices=finetune_choices, value=current_model_type, show_label=False, scale=4, visible=several_finetunes or not three_levels, elem_id="model_list")
    return dropdown_families, dropdown_models, dropdown_finetunes
def change_model_family(deps, state, current_model_family):
    """Rebuild the base-model and finetune dropdowns after a family is selected.

    Args:
        deps: DropdownDeps-like object providing model lists and lookups.
        state: dict-like UI state. ``state["last_model_per_family"]``, if
            present, maps a family to the model type last used in it.
        current_model_family: the family picked in the family dropdown.

    Returns:
        ``(base_types_dropdown, finetunes_dropdown)`` as ``gr.Dropdown`` updates.
    """
    dropdown_types = list(deps.transformer_types) if len(deps.transformer_types) > 0 else list(deps.displayed_model_types)
    current_family_name = deps.families_infos[current_model_family][1]
    models_families = [deps.get_model_family(t, for_ui=True) for t in dropdown_types]
    dropdown_choices = [(compact_name(current_family_name, deps.get_model_name(model_type)), model_type) for model_type, family in zip(dropdown_types, models_families) if family == current_model_family]
    dropdown_choices = sorted(dropdown_choices, key=lambda c: c[0])
    family_dropdown_types = [choice[1] for choice in dropdown_choices]
    direct_status_map, aggregated_parent_status_map = get_model_download_status_maps(deps, family_dropdown_types)
    # Reselect the model last used in this family. Fall back to the first choice
    # when nothing usable is remembered (empty, None, or no longer offered).
    last_model_per_family = state.get("last_model_per_family", {})
    model_type = last_model_per_family.get(current_model_family, "")
    if not model_type or model_type not in family_dropdown_types:
        model_type = dropdown_choices[0][1]
    if deps.three_levels_hierarchy:
        parent_model_type = deps.get_parent_model_type(model_type)
        dropdown_choices = [(*tup, deps.get_parent_model_type(tup[1])) for tup in dropdown_choices]
        dropdown_base_types_choices, finetunes_dict = create_models_hierarchy(dropdown_choices)
        dropdown_choices = decorate_dropdown_choices_with_status(finetunes_dict[parent_model_type], direct_status_map)
        dropdown_base_types_choices = decorate_dropdown_choices_with_status(dropdown_base_types_choices, aggregated_parent_status_map)
        model_finetunes_visible = len(dropdown_choices) > 1
    else:
        parent_model_type = deps.get_base_model_type(model_type)
        model_finetunes_visible = True
        # Distinct base types in the (sorted) choice order; a set would make the order vary between runs.
        dropdown_base_types_choices = list(dict.fromkeys(deps.get_base_model_type(model[1]) for model in dropdown_choices))
        dropdown_choices = decorate_dropdown_choices_with_status(dropdown_choices, direct_status_map)
    return gr.Dropdown(choices=dropdown_base_types_choices, value=parent_model_type, scale=3 if model_finetunes_visible else 7), gr.Dropdown(choices=dropdown_choices, value=model_type, visible=model_finetunes_visible)
def change_model_base_types(deps, state, current_model_family, model_base_type_choice):
    """Rebuild the finetune dropdown after a base model type is selected.

    Only meaningful with the three-level hierarchy; otherwise both outputs are
    left unchanged.

    Args:
        deps: DropdownDeps-like object providing model lists and lookups.
        state: dict-like UI state. ``state["last_model_per_type"]``, if present,
            maps a base model type to the finetune last used with it.
        current_model_family: the currently selected family.
        model_base_type_choice: the parent model type picked in the base-model dropdown.

    Returns:
        ``(base_types_update, finetunes_dropdown)``.
    """
    if not deps.three_levels_hierarchy:
        # The normal path returns two updates (base-type dropdown, finetune
        # dropdown); return two no-ops so a two-output handler gets the count it expects.
        return gr.update(), gr.update()
    dropdown_types = list(deps.transformer_types) if len(deps.transformer_types) > 0 else list(deps.displayed_model_types)
    current_family_name = deps.families_infos[current_model_family][1]
    dropdown_choices = [(compact_name(current_family_name, deps.get_model_name(model_type)), model_type, model_base_type_choice) for model_type in dropdown_types if deps.get_parent_model_type(model_type) == model_base_type_choice and deps.get_model_family(model_type, for_ui=True) == current_model_family]
    if not dropdown_choices:
        # Nothing in this family belongs to the chosen base type (e.g. a stale
        # selection); looking it up in finetunes_dict would raise KeyError.
        return gr.update(), gr.update()
    dropdown_choices = sorted(dropdown_choices, key=lambda c: c[0])
    _, finetunes_dict = create_models_hierarchy(dropdown_choices)
    base_dropdown_types = [choice[1] for choice in dropdown_choices]
    direct_status_map, _ = get_model_download_status_maps(deps, base_dropdown_types)
    dropdown_choices = decorate_dropdown_choices_with_status(finetunes_dict[model_base_type_choice], direct_status_map)
    model_finetunes_visible = len(dropdown_choices) > 1
    # Reselect the finetune last used with this base type. Fall back to the first
    # choice when nothing usable is remembered (empty, None, or no longer offered).
    last_model_per_type = state.get("last_model_per_type", {})
    model_type = last_model_per_type.get(model_base_type_choice, "")
    if not model_type or model_type not in [choice[1] for choice in dropdown_choices]:
        model_type = dropdown_choices[0][1]
    return gr.update(scale=3 if model_finetunes_visible else 7), gr.Dropdown(choices=dropdown_choices, value=model_type, visible=model_finetunes_visible)