"""Helpers that build the family / model / finetune dropdowns and their download-status badges."""

import os
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable

import gradio as gr

# Download states. The numeric order matters: aggregation uses max(), so
# MISSING < PARTIAL < EXPECTED must hold.
MODEL_FILE_STATUS_MISSING = 0
MODEL_FILE_STATUS_PARTIAL = 1
MODEL_FILE_STATUS_EXPECTED = 2

# Glyph prepended to a dropdown label for each download state.
MODEL_STATUS_PREFIXES = {
    MODEL_FILE_STATUS_MISSING: "\u2B1B",
    MODEL_FILE_STATUS_EXPECTED: "\U0001F7E6",
    MODEL_FILE_STATUS_PARTIAL: "\U0001F7E8",
}


@dataclass
class DropdownDeps:
    """Bundle of configuration values and callbacks used by the dropdown helpers.

    The callbacks are supplied by the host application; their exact contracts
    are not visible in this module, so the helpers below only rely on the
    call shapes they actually use.
    """

    transformer_types: list
    displayed_model_types: list
    transformer_type: str
    three_levels_hierarchy: bool
    # Indexed as families_infos[family][0] (sort order) and [1] (label).
    families_infos: dict
    server_config: dict
    transformer_quantization: str
    transformer_dtype_policy: str
    text_encoder_quantization: str
    get_model_def: Callable
    get_model_recursive_prop: Callable
    get_model_filename: Callable
    get_local_model_filename: Callable
    get_lora_dir: Callable
    get_parent_model_type: Callable
    get_base_model_type: Callable
    get_model_family: Callable
    get_model_name: Callable
    get_transformer_dtype: Callable


def _as_list(value):
    """Normalize ``value`` to a list: ``None`` -> ``[]``, a list is returned as is, anything else is wrapped."""
    if value is None:
        return []
    return value if isinstance(value, list) else [value]


def _configured_model_types(deps):
    """Return a fresh list of ``deps.transformer_types``, or ``deps.displayed_model_types`` when the former is empty."""
    if len(deps.transformer_types) > 0:
        return list(deps.transformer_types)
    return list(deps.displayed_model_types)


def _common_prefix_len(a, b):
    """Return the number of leading positions at which sequences ``a`` and ``b`` are equal."""
    count = 0
    for left, right in zip(a, b):
        if left != right:
            break
        count += 1
    return count


def compact_name(family_name, model_name):
    """Strip a leading ``family_name`` (and surrounding whitespace) from ``model_name`` when present."""
    if model_name.startswith(family_name):
        return model_name[len(family_name):].strip()
    return model_name


def decorate_model_dropdown_label(label, status):
    """Prefix ``label`` with the glyph for ``status``; non-string labels and unknown statuses pass through unchanged."""
    if not isinstance(label, str):
        return label
    prefix = MODEL_STATUS_PREFIXES.get(status, "")
    return f"{prefix} {label}" if len(prefix) > 0 else label


def decorate_dropdown_choices_with_status(choices, status_map):
    """Return a new list of choices whose labels carry a status glyph.

    Only tuples of at least ``(label, value)`` are decorated; any extra tuple
    items are preserved. Values absent from ``status_map`` are treated as
    missing. Other entries (e.g. plain strings) are copied unchanged.
    """
    decorated = []
    for choice in choices:
        if not isinstance(choice, tuple) or len(choice) < 2:
            decorated.append(choice)
            continue
        label, value = choice[0], choice[1]
        status = status_map.get(value, MODEL_FILE_STATUS_MISSING)
        decorated.append((decorate_model_dropdown_label(label, status), value, *choice[2:]))
    return decorated


def get_dropdown_model_types(deps):
    """Return the de-duplicated, order-preserving list of model types to offer, always including ``deps.transformer_type``."""
    dropdown_types = _configured_model_types(deps)
    if deps.transformer_type not in dropdown_types:
        dropdown_types.append(deps.transformer_type)
    return list(dict.fromkeys(dropdown_types))


def get_family_dropdown_model_types(deps, current_model_family, dropdown_types=None):
    """Return the model types belonging to ``current_model_family``.

    When ``current_model_family`` is ``None`` the (possibly caller-supplied)
    ``dropdown_types`` list is returned as is, not copied.
    """
    if dropdown_types is None:
        dropdown_types = get_dropdown_model_types(deps)
    if current_model_family is None:
        return dropdown_types
    return [
        model_type
        for model_type in dropdown_types
        if deps.get_model_family(model_type, for_ui=True) == current_model_family
    ]


def _get_module_files_for_status(deps, model_type, quantization, dtype_policy):
    """Return the filenames of the model's modules, or ``None`` when a dict module lacks ``URLs`` or ``URLs2``."""
    transformer_dtype = deps.get_transformer_dtype(model_type, dtype_policy)
    modules = deps.get_model_recursive_prop(model_type, "modules", return_list=True)
    # String entries are resolved through a second property lookup; other entries are used directly.
    modules = [
        deps.get_model_recursive_prop(module, "modules", sub_prop_name="_list", return_list=True)
        if isinstance(module, str)
        else module
        for module in modules
    ]
    module_files = []
    for module_type in modules:
        if isinstance(module_type, dict):
            URLs1 = module_type.get("URLs", None)
            if URLs1 is None:
                return None
            module_files.append(deps.get_model_filename(model_type, quantization, transformer_dtype, URLs=URLs1))
            URLs2 = module_type.get("URLs2", None)
            if URLs2 is None:
                return None
            module_files.append(deps.get_model_filename(model_type, quantization, transformer_dtype, URLs=URLs2))
        else:
            module_files.append(
                deps.get_model_filename(model_type, quantization, transformer_dtype, module_type=module_type)
            )
    return module_files


def _get_status_quantization_and_dtype(deps):
    """Return ``(quantization, dtype_policy)``, preferring ``deps.server_config`` over the ``deps`` defaults."""
    quantization = deps.server_config.get("transformer_quantization", deps.transformer_quantization)
    dtype_policy = deps.server_config.get("transformer_dtype_policy", deps.transformer_dtype_policy)
    return quantization, dtype_policy


def _append_expected_file_entry(entries, seen, filename, extra_paths=None):
    """Append ``{"filename", "extra_paths"}`` to ``entries`` unless invalid or already recorded.

    De-duplication is case-insensitive on the filename plus its extra paths;
    ``seen`` is updated in place. ``extra_paths`` may be ``None``, a string or
    a list; non-string and empty items are dropped.
    """
    if not isinstance(filename, str) or len(filename) == 0:
        return
    if isinstance(extra_paths, list):
        extra_list = [path for path in extra_paths if isinstance(path, str) and len(path) > 0]
    elif isinstance(extra_paths, str) and len(extra_paths) > 0:
        extra_list = [extra_paths]
    else:
        extra_list = []
    key = (filename.casefold(), tuple(path.casefold() for path in extra_list))
    if key in seen:
        return
    seen.add(key)
    entries.append({"filename": filename, "extra_paths": extra_list if len(extra_list) > 0 else None})


def _append_expected_local_path_entry(entries, seen, local_path):
    """Append ``{"path": local_path}`` to ``entries`` unless invalid or already recorded (case-insensitive)."""
    if not isinstance(local_path, str) or len(local_path) == 0:
        return
    path_key = local_path.casefold()
    if path_key in seen:
        return
    seen.add(path_key)
    entries.append({"path": local_path})


def get_expected_core_file_entries_for_status(deps, model_type):
    """List the core files expected for ``model_type``: main weights, optional second weights, modules, text encoder.

    Returns an empty list when the model has no definition.
    """
    model_def = deps.get_model_def(model_type)
    if model_def is None:
        return []
    quantization, dtype_policy = _get_status_quantization_and_dtype(deps)
    entries = []
    seen = set()

    expected_filename = deps.get_model_filename(model_type, quantization=quantization, dtype_policy=dtype_policy)
    _append_expected_file_entry(entries, seen, expected_filename)
    if isinstance(model_def, dict) and "URLs2" in model_def:
        expected_filename2 = deps.get_model_filename(
            model_type, quantization=quantization, dtype_policy=dtype_policy, submodel_no=2
        )
        _append_expected_file_entry(entries, seen, expected_filename2)

    module_files = _get_module_files_for_status(deps, model_type, quantization, dtype_policy)
    if isinstance(module_files, list):
        for filename in module_files:
            _append_expected_file_entry(entries, seen, filename)

    text_encoder_URLs = deps.get_model_recursive_prop(model_type, "text_encoder_URLs", return_list=True)
    if text_encoder_URLs is not None:
        text_encoder_filename = deps.get_model_filename(
            model_type=model_type,
            quantization=deps.text_encoder_quantization,
            dtype_policy=dtype_policy,
            URLs=text_encoder_URLs,
        )
        text_encoder_folder = model_def.get("text_encoder_folder", None)
        _append_expected_file_entry(entries, seen, text_encoder_filename, extra_paths=text_encoder_folder)
    return entries


def get_missing_core_file_entries_for_status(deps, model_type):
    """Return the expected core file entries for which ``deps.get_local_model_filename`` returns ``None``."""
    missing_entries = []
    for entry in get_expected_core_file_entries_for_status(deps, model_type):
        filename = entry.get("filename", "")
        extra_paths = entry.get("extra_paths", None)
        if deps.get_local_model_filename(filename, extra_paths=extra_paths) is None:
            missing_entries.append(entry)
    return missing_entries


def get_expected_secondary_file_entries_for_status(deps, model_type):
    """List the secondary files expected for ``model_type``: preload URLs, VAE URLs and LoRAs.

    Preload/VAE entries use the ``{"filename", "extra_paths"}`` shape; LoRA
    entries use ``{"path"}`` pointing inside ``deps.get_lora_dir(model_type)``.
    Returns an empty list when the model has no definition.
    """
    model_def = deps.get_model_def(model_type)
    if model_def is None:
        return []
    entries = []
    seen = set()

    preload_urls = _as_list(deps.get_model_recursive_prop(model_type, "preload_URLs", return_list=True))
    vae_urls = _as_list(model_def.get("VAE_URLs", []))
    for url in preload_urls + vae_urls:
        if isinstance(url, str) and len(url) > 0:
            _append_expected_file_entry(entries, seen, url)

    model_loras = _as_list(deps.get_model_recursive_prop(model_type, "loras", return_list=True))
    lora_dir = deps.get_lora_dir(model_type)
    for url in model_loras:
        if not isinstance(url, str) or len(url) == 0:
            continue
        basename = os.path.basename(url)
        if len(basename) == 0:
            continue
        _append_expected_local_path_entry(entries, seen, os.path.join(lora_dir, basename))
    return entries


def has_secondary_model_files_for_status(deps, model_type, quantization, dtype_policy):
    """Return ``True`` when the text encoder, preload/VAE files, LoRAs and module files are all present locally.

    A model without a definition is reported as complete. A module whose file
    list cannot be determined counts as incomplete.
    """
    model_def = deps.get_model_def(model_type)
    if model_def is None:
        return True

    text_encoder_URLs = deps.get_model_recursive_prop(model_type, "text_encoder_URLs", return_list=True)
    if text_encoder_URLs is not None:
        text_encoder_filename = deps.get_model_filename(
            model_type=model_type,
            quantization=deps.text_encoder_quantization,
            dtype_policy=dtype_policy,
            URLs=text_encoder_URLs,
        )
        if isinstance(text_encoder_filename, str) and len(text_encoder_filename) > 0:
            text_encoder_folder = model_def.get("text_encoder_folder", None)
            if deps.get_local_model_filename(text_encoder_filename, extra_paths=text_encoder_folder) is None:
                return False

    for prop, recursive in (("preload_URLs", True), ("VAE_URLs", False)):
        if recursive:
            urls = deps.get_model_recursive_prop(model_type, prop, return_list=True)
        else:
            urls = model_def.get(prop, [])
        for url in _as_list(urls):
            if not isinstance(url, str) or len(url) == 0:
                continue
            if deps.get_local_model_filename(url) is None:
                return False

    model_loras = _as_list(deps.get_model_recursive_prop(model_type, "loras", return_list=True))
    lora_dir = deps.get_lora_dir(model_type)
    for url in model_loras:
        if not isinstance(url, str) or len(url) == 0:
            continue
        basename = os.path.basename(url)
        # Skip URLs without a file component, matching
        # get_expected_secondary_file_entries_for_status; otherwise the LoRA
        # directory itself would be probed as a file and reported missing.
        if len(basename) == 0:
            continue
        if not os.path.isfile(os.path.join(lora_dir, basename)):
            return False

    module_files = _get_module_files_for_status(deps, model_type, quantization, dtype_policy)
    if module_files is None:
        return False
    for filename in module_files:
        if not isinstance(filename, str) or len(filename) == 0:
            continue
        if deps.get_local_model_filename(filename) is None:
            return False
    return True


def get_model_download_status(deps, model_type):
    """Classify ``model_type`` as EXPECTED, PARTIAL or MISSING.

    EXPECTED: every expected main weight file and all secondary files exist.
    PARTIAL: some but not all of those exist, or another file listed under
    ``URLs`` / ``URLs2`` (e.g. a different variant) exists locally.
    MISSING: none of the above.
    """
    quantization, dtype_policy = _get_status_quantization_and_dtype(deps)
    model_def = deps.get_model_def(model_type)

    expected_filenames = []
    expected_filename = deps.get_model_filename(model_type, quantization=quantization, dtype_policy=dtype_policy)
    if isinstance(expected_filename, str) and len(expected_filename) > 0:
        expected_filenames.append(expected_filename)
    if isinstance(model_def, dict) and "URLs2" in model_def:
        expected_filename2 = deps.get_model_filename(
            model_type, quantization=quantization, dtype_policy=dtype_policy, submodel_no=2
        )
        if isinstance(expected_filename2, str) and len(expected_filename2) > 0:
            expected_filenames.append(expected_filename2)

    expected_exists = [deps.get_local_model_filename(filename) is not None for filename in expected_filenames]
    if len(expected_exists) > 0 and all(expected_exists):
        if not has_secondary_model_files_for_status(deps, model_type, quantization, dtype_policy):
            return MODEL_FILE_STATUS_PARTIAL
        return MODEL_FILE_STATUS_EXPECTED
    if any(expected_exists):
        return MODEL_FILE_STATUS_PARTIAL

    # None of the expected files exist: look for any other listed candidate on disk.
    candidate_urls = []
    for prop in ("URLs", "URLs2"):
        urls = deps.get_model_recursive_prop(model_type, prop, return_list=True)
        if not isinstance(urls, list):
            urls = [urls] if urls else []
        candidate_urls += urls

    checked_candidates = set()
    expected_set = {name.casefold() for name in expected_filenames}
    for candidate in candidate_urls:
        if not isinstance(candidate, str) or len(candidate) == 0:
            continue
        candidate_key = candidate.casefold()
        if candidate_key in checked_candidates:
            continue
        checked_candidates.add(candidate_key)
        if candidate_key in expected_set:
            # Already known to be absent from the check above.
            continue
        if deps.get_local_model_filename(candidate) is not None:
            return MODEL_FILE_STATUS_PARTIAL
    return MODEL_FILE_STATUS_MISSING


def get_model_download_status_maps(deps, dropdown_types=None):
    """Return ``(direct_status_map, aggregated_parent_status_map)`` for the given model types.

    ``direct_status_map`` maps each model type that has a definition to its
    own status. The aggregated map starts as a copy of it and raises each
    parent's status to the best status found among its children.
    """
    if dropdown_types is None:
        dropdown_types = get_dropdown_model_types(deps)

    direct_status_map = {}
    parent_to_children = defaultdict(list)
    for model_type in dropdown_types:
        if deps.get_model_def(model_type) is None:
            continue
        direct_status_map[model_type] = get_model_download_status(deps, model_type)
        parent_model_type = deps.get_parent_model_type(model_type)
        if parent_model_type is not None:
            parent_to_children[parent_model_type].append(model_type)

    aggregated_parent_status_map = dict(direct_status_map)
    for parent_model_type, children in parent_to_children.items():
        # Status constants are ordered MISSING < PARTIAL < EXPECTED, so the
        # best child status is simply the maximum.
        best_child_status = max(direct_status_map.get(child, MODEL_FILE_STATUS_MISSING) for child in children)
        aggregated_parent_status_map[parent_model_type] = max(
            aggregated_parent_status_map.get(parent_model_type, MODEL_FILE_STATUS_MISSING),
            best_child_status,
        )
    return direct_status_map, aggregated_parent_status_map


def get_model_download_status_map(deps, dropdown_types=None):
    """Return only the parent-aggregated status map from ``get_model_download_status_maps``."""
    return get_model_download_status_maps(deps, dropdown_types)[1]


def create_models_hierarchy(rows):
    """Group models under their parents and derive compact display names.

    Args:
        rows: iterable of ``(model_name, model_id, parent_model_id)``. A row
            whose ``model_id`` equals its ``parent_model_id`` defines the parent.

    Returns:
        ``(parents_list, children_dict)`` where ``parents_list`` is a list of
        ``(parent_header, parent_id)`` sorted by header, and ``children_dict``
        maps each parent id to a list of ``(child_display_name, child_id)``
        with the parent's own entry first.
    """
    def toks(s):
        return [t for t in s.split() if t]

    def norm(s):
        return " ".join(s.split()).casefold()

    groups, parents, order = defaultdict(list), {}, []
    for name, mid, pmid in rows:
        groups[pmid].append((name, mid))
        if mid == pmid and pmid not in parents:
            parents[pmid] = name
            order.append(pmid)

    parents_list, children_dict = [], {}

    for pid in order:
        p_name = parents[pid]
        p_tok = toks(p_name)
        p_low = [w.casefold() for w in p_tok]
        n = len(p_low)
        # Guard against an empty/whitespace-only parent name.
        p_last = p_low[-1] if p_low else None
        p_set = set(p_low)

        kids = []
        for name, mid in groups.get(pid, []):
            ot = toks(name)
            lt = [w.casefold() for w in ot]
            kids.append((name, mid, ot, lt, set(lt)))

        # Children sharing no word with the parent keep their full name.
        outliers = {mid for _, mid, _, _, st in kids if mid != pid and p_set.isdisjoint(st)}

        prefix_non = [
            (ot, lt)
            for _, mid, ot, lt, _ in kids
            if mid == pid or (mid not in outliers and lt and lt[0] == p_low[0])
        ]

        # L = number of leading parent words shared by every prefix-matching
        # child; these words form the header and are stripped from child names.
        if len(prefix_non) <= 1:
            L = n
        else:
            L = min(_common_prefix_len(lt, p_low) for _, lt in prefix_non)
            if L == 0:
                L = n

        shares_last = any(
            mid != pid and mid not in outliers and lt and lt[-1] == p_last
            for _, mid, _, lt, _ in kids
        )
        header_tokens_disp = p_tok[:L] + ([p_tok[-1]] if shares_last and L < n else [])
        header = " ".join(header_tokens_disp)
        header_has_last = (L == n) or (shares_last and L < n)

        prefix_low = p_low[:L]

        def base_rem(ot, lt):
            # Drop the shared prefix only when the name really starts with it.
            if L > 0 and lt[:L] == prefix_low:
                return ot[L:]
            return ot[:]

        def trim_rem(rem, lt):
            # Drop a trailing word already shown as the header's last word.
            out = rem[:]
            if header_has_last and lt and lt[-1] == p_last and out and out[-1].casefold() == p_last:
                out = out[:-1]
            return out

        kid_infos = []
        for name, mid, ot, lt, _ in kids:
            is_outlier = mid in outliers
            if is_outlier:
                rem_trim = ot[:]
                rem_set = set()
            else:
                rem_core = base_rem(ot, lt)
                rem_trim = trim_rem(rem_core, lt)
                rem_set = {w.casefold() for w in rem_core}
            kid_infos.append({
                "name": name,
                "mid": mid,
                "outlier": is_outlier,
                "rem_trim": rem_trim,
                "rem_set": rem_set,
            })

        default_info = next(info for info in kid_infos if info["mid"] == pid)
        other_words = set()
        for info in kid_infos:
            if info["mid"] != pid:
                other_words |= info["rem_set"]
        default_shares = bool(default_info["rem_set"] & other_words)

        def disp(info):
            if info["outlier"]:
                return info["name"]
            if info["mid"] == pid and not default_shares:
                return "Default"
            s = " ".join(info["rem_trim"]).strip()
            return s if s else "Default"

        entries = [(disp(default_info), pid)]
        entries.extend((disp(info), info["mid"]) for info in kid_infos if info["mid"] != pid)

        # Children whose full name equals the parent's would all read
        # "Default"; number them to keep labels distinct.
        p_full = norm(p_name)
        full_by_mid = {mid: name for name, mid, *_ in kids}
        num = 2
        numbered = [entries[0]]
        for dname, mid in entries[1:]:
            if dname == "Default" and norm(full_by_mid[mid]) == p_full:
                numbered.append((f"Default #{num}", mid))
                num += 1
            else:
                numbered.append((dname, mid))

        parents_list.append((header, pid))
        children_dict[pid] = numbered

    # Groups whose parent row is absent: use the first member's name as header.
    for pid, members in groups.items():
        if pid in parents:
            continue
        parents_list.append((members[0][0], pid))
        children_dict[pid] = list(members)

    parents_list = sorted(parents_list, key=lambda c: c[0])
    return parents_list, children_dict


def get_sorted_dropdown(deps, dropdown_types, current_model_family, current_model_type, three_levels=True):
    """Build the sorted choices for the family, model and finetune dropdowns.

    Returns ``(families, models, finetunes)``. ``families`` is a list of
    ``(label, family)`` sorted by family order. With ``three_levels`` the
    models are parent headers and the finetunes are the children of the
    current model's parent; otherwise ``models`` is the list of distinct base
    model types and ``finetunes`` the ``(name, model_type)`` pairs.
    """
    models_families = [deps.get_model_family(t, for_ui=True) for t in dropdown_types]
    families = list(dict.fromkeys(models_families))
    families_orders = [deps.families_infos[family][0] for family in families]
    families_labels = [deps.families_infos[family][1] for family in families]
    sorted_familes = [
        info[1:]
        for info in sorted(zip(families_orders, families_labels, families), key=lambda c: c[0])
    ]

    if current_model_family is None:
        dropdown_choices = [
            (deps.families_infos[family][0], deps.get_model_name(model_type), model_type)
            for model_type, family in zip(dropdown_types, models_families)
        ]
    else:
        dropdown_choices = [
            (
                deps.families_infos[family][0],
                compact_name(deps.families_infos[family][1], deps.get_model_name(model_type)),
                model_type,
            )
            for model_type, family in zip(dropdown_types, models_families)
            if family == current_model_family
        ]
    dropdown_choices = sorted(dropdown_choices, key=lambda c: (c[0], c[1]))

    if three_levels:
        dropdown_choices = [(*model[1:], deps.get_parent_model_type(model[2])) for model in dropdown_choices]
        sorted_choices, finetunes_dict = create_models_hierarchy(dropdown_choices)
        return sorted_familes, sorted_choices, finetunes_dict[deps.get_parent_model_type(current_model_type)]

    # Order-preserving de-duplication keeps the result deterministic.
    dropdown_types_list = list(dict.fromkeys(deps.get_base_model_type(model[2]) for model in dropdown_choices))
    dropdown_choices = [model[1:] for model in dropdown_choices]
    return sorted_familes, dropdown_types_list, dropdown_choices


def generate_dropdown_model_list(deps, current_model_type):
    """Create the three Gradio dropdowns (family, model, finetune) for ``current_model_type``, with status glyphs."""
    dropdown_types = _configured_model_types(deps)
    if current_model_type not in dropdown_types:
        dropdown_types.append(current_model_type)
    current_model_family = deps.get_model_family(current_model_type, for_ui=True)
    sorted_familes, sorted_models, sorted_finetunes = get_sorted_dropdown(
        deps, dropdown_types, current_model_family, current_model_type, three_levels=deps.three_levels_hierarchy
    )

    status_model_types = get_family_dropdown_model_types(deps, current_model_family, dropdown_types)
    if current_model_type not in status_model_types:
        status_model_types.append(current_model_type)
    direct_status_map, aggregated_parent_status_map = get_model_download_status_maps(deps, status_model_types)
    sorted_models = decorate_dropdown_choices_with_status(sorted_models, aggregated_parent_status_map)
    sorted_finetunes = decorate_dropdown_choices_with_status(sorted_finetunes, direct_status_map)

    dropdown_families = gr.Dropdown(
        choices=sorted_familes,
        value=current_model_family,
        show_label=False,
        scale=2 if deps.three_levels_hierarchy else 1,
        elem_id="family_list",
        min_width=50,
    )
    dropdown_models = gr.Dropdown(
        choices=sorted_models,
        value=deps.get_parent_model_type(current_model_type)
        if deps.three_levels_hierarchy
        else deps.get_base_model_type(current_model_type),
        show_label=False,
        scale=3 if len(sorted_finetunes) > 1 else 7,
        elem_id="model_base_types_list",
        visible=deps.three_levels_hierarchy,
    )
    dropdown_finetunes = gr.Dropdown(
        choices=sorted_finetunes,
        value=current_model_type,
        show_label=False,
        scale=4,
        visible=len(sorted_finetunes) > 1 or not deps.three_levels_hierarchy,
        elem_id="model_list",
    )
    return dropdown_families, dropdown_models, dropdown_finetunes


def change_model_family(deps, state, current_model_family):
    """Rebuild the model and finetune dropdowns after the family selection changed.

    The model remembered in ``state["last_model_per_family"]`` is reselected
    when it is still offered; otherwise the first choice is used.
    Returns ``(models_dropdown, finetunes_dropdown)``.
    """
    dropdown_types = _configured_model_types(deps)
    current_family_name = deps.families_infos[current_model_family][1]
    models_families = [deps.get_model_family(t, for_ui=True) for t in dropdown_types]
    dropdown_choices = [
        (compact_name(current_family_name, deps.get_model_name(model_type)), model_type)
        for model_type, family in zip(dropdown_types, models_families)
        if family == current_model_family
    ]
    dropdown_choices = sorted(dropdown_choices, key=lambda c: c[0])
    family_dropdown_types = [choice[1] for choice in dropdown_choices]
    direct_status_map, aggregated_parent_status_map = get_model_download_status_maps(deps, family_dropdown_types)

    last_model_per_family = state.get("last_model_per_family", {})
    model_type = last_model_per_family.get(current_model_family, "")
    # The original test `len(model_type) == ""` could never be true; a
    # truthiness check also tolerates a stored None.
    if not model_type or model_type not in family_dropdown_types:
        model_type = dropdown_choices[0][1]

    if deps.three_levels_hierarchy:
        parent_model_type = deps.get_parent_model_type(model_type)
        dropdown_choices = [(*tup, deps.get_parent_model_type(tup[1])) for tup in dropdown_choices]
        dropdown_base_types_choices, finetunes_dict = create_models_hierarchy(dropdown_choices)
        dropdown_choices = decorate_dropdown_choices_with_status(finetunes_dict[parent_model_type], direct_status_map)
        dropdown_base_types_choices = decorate_dropdown_choices_with_status(
            dropdown_base_types_choices, aggregated_parent_status_map
        )
        model_finetunes_visible = len(dropdown_choices) > 1
    else:
        parent_model_type = deps.get_base_model_type(model_type)
        model_finetunes_visible = True
        # Order-preserving de-duplication keeps the result deterministic.
        dropdown_base_types_choices = list(
            dict.fromkeys(deps.get_base_model_type(model[1]) for model in dropdown_choices)
        )
        dropdown_choices = decorate_dropdown_choices_with_status(dropdown_choices, direct_status_map)

    return (
        gr.Dropdown(
            choices=dropdown_base_types_choices,
            value=parent_model_type,
            scale=3 if model_finetunes_visible else 7,
        ),
        gr.Dropdown(choices=dropdown_choices, value=model_type, visible=model_finetunes_visible),
    )


def change_model_base_types(deps, state, current_model_family, model_base_type_choice):
    """Rebuild the finetune dropdown after the model (parent) selection changed.

    The model remembered in ``state["last_model_per_type"]`` is reselected
    when it is still offered; otherwise the first choice is used.
    Returns ``(models_dropdown_update, finetunes_dropdown)``.
    """
    if not deps.three_levels_hierarchy:
        # NOTE(review): a single update is returned here while the main path
        # returns a pair; confirm against the outputs wired by the caller.
        return gr.update()

    dropdown_types = _configured_model_types(deps)
    current_family_name = deps.families_infos[current_model_family][1]
    dropdown_choices = [
        (compact_name(current_family_name, deps.get_model_name(model_type)), model_type, model_base_type_choice)
        for model_type in dropdown_types
        if deps.get_parent_model_type(model_type) == model_base_type_choice
        and deps.get_model_family(model_type, for_ui=True) == current_model_family
    ]
    dropdown_choices = sorted(dropdown_choices, key=lambda c: c[0])
    _, finetunes_dict = create_models_hierarchy(dropdown_choices)
    base_dropdown_types = [choice[1] for choice in dropdown_choices]
    direct_status_map, _ = get_model_download_status_maps(deps, base_dropdown_types)
    dropdown_choices = decorate_dropdown_choices_with_status(
        finetunes_dict[model_base_type_choice], direct_status_map
    )
    model_finetunes_visible = len(dropdown_choices) > 1

    last_model_per_type = state.get("last_model_per_type", {})
    model_type = last_model_per_type.get(model_base_type_choice, "")
    # The original test `len(model_type) == ""` could never be true; a
    # truthiness check also tolerates a stored None.
    if not model_type or model_type not in [choice[1] for choice in dropdown_choices]:
        model_type = dropdown_choices[0][1]

    return (
        gr.update(scale=3 if model_finetunes_visible else 7),
        gr.Dropdown(choices=dropdown_choices, value=model_type, visible=model_finetunes_visible),
    )