| import os
|
| from collections import defaultdict
|
| from dataclasses import dataclass
|
|
|
| import gradio as gr
|
|
|
|
|
# Download-status codes for a model's files. The numeric order matters:
# get_model_download_status_maps() merges statuses with max(), so a higher
# value wins when a parent and its children disagree.
MODEL_FILE_STATUS_MISSING = 0
MODEL_FILE_STATUS_PARTIAL = 1
MODEL_FILE_STATUS_EXPECTED = 2

# Symbol prepended to a dropdown label for each status code.
MODEL_STATUS_PREFIXES = {
    MODEL_FILE_STATUS_MISSING: "\u2B1B",  # U+2B1B black large square
    MODEL_FILE_STATUS_EXPECTED: "\U0001F7E6",  # U+1F7E6 large blue square
    MODEL_FILE_STATUS_PARTIAL: "\U0001F7E8",  # U+1F7E8 large yellow square
}
|
|
|
|
|
@dataclass
class DropdownDeps:
    """Bundle of settings and lookup callables used by the dropdown helpers in this module.

    Every helper receives one instance as ``deps`` instead of reaching for
    globals. The call shapes noted below are the ones used in this module;
    the callables themselves are defined elsewhere.
    """

    # Preferred list of model types for the dropdowns; used when non-empty.
    transformer_types: list
    # Fallback list of model types when ``transformer_types`` is empty.
    displayed_model_types: list
    # Model type that get_dropdown_model_types() always includes in its result.
    transformer_type: str
    # True: family / base model / finetune dropdowns; False: family / model only.
    three_levels_hierarchy: bool
    # Maps a family id to a sequence whose item 0 is its sort order and item 1 its label.
    families_infos: dict
    # Read with .get("transformer_quantization") and .get("transformer_dtype_policy").
    server_config: dict
    # Fallbacks used when ``server_config`` lacks the matching key.
    transformer_quantization: str
    transformer_dtype_policy: str
    # Quantization passed when resolving the text encoder filename.
    text_encoder_quantization: str
    # get_model_def(model_type) -> model definition, or None when unknown.
    get_model_def: callable
    # get_model_recursive_prop(model_type, prop, return_list=True[, sub_prop_name=...]).
    get_model_recursive_prop: callable
    # get_model_filename(model_type, quantization, dtype_or_policy, URLs=..., module_type=..., submodel_no=...).
    get_model_filename: callable
    # get_local_model_filename(filename, extra_paths=...) -> None when no local file is found.
    get_local_model_filename: callable
    # get_lora_dir(model_type) -> directory joined with a LoRA file's basename.
    get_lora_dir: callable
    # get_parent_model_type(model_type) -> parent id used for three-level grouping (may be None).
    get_parent_model_type: callable
    # get_base_model_type(model_type) -> base id used for two-level grouping.
    get_base_model_type: callable
    # get_model_family(model_type, for_ui=True) -> family id (a key of ``families_infos``).
    get_model_family: callable
    # get_model_name(model_type) -> display name.
    get_model_name: callable
    # get_transformer_dtype(model_type, dtype_policy) -> dtype handed to get_model_filename.
    get_transformer_dtype: callable
|
|
|
|
|
def compact_name(family_name, model_name):
    """Return ``model_name`` without its leading ``family_name``.

    When ``model_name`` starts with ``family_name`` the prefix is cut off and
    the remainder is stripped of surrounding whitespace; otherwise the name is
    returned untouched.
    """
    if not model_name.startswith(family_name):
        return model_name
    remainder = model_name[len(family_name):]
    return remainder.strip()
|
|
|
|
|
def decorate_model_dropdown_label(label, status):
    """Prefix ``label`` with the status symbol registered for ``status``.

    Non-string labels, and statuses without a registered symbol, are returned
    unchanged. The symbol and the label are separated by one space.
    """
    if not isinstance(label, str):
        return label
    symbol = MODEL_STATUS_PREFIXES.get(status, "")
    if len(symbol) == 0:
        return label
    return f"{symbol} {label}"
|
|
|
|
|
def decorate_dropdown_choices_with_status(choices, status_map):
    """Return a new list of choices whose labels carry a status prefix.

    A choice is decorated only when it is a tuple with at least two items
    ``(label, value, ...)``; its status is looked up by ``value`` in
    ``status_map`` and defaults to "missing". Any extra tuple items are kept.
    Every other choice is passed through as is.
    """

    def _decorate(choice):
        if isinstance(choice, tuple) and len(choice) >= 2:
            label, value, *extras = choice
            status = status_map.get(value, MODEL_FILE_STATUS_MISSING)
            return (decorate_model_dropdown_label(label, status), value, *extras)
        return choice

    return [_decorate(choice) for choice in choices]
|
|
|
|
|
def get_dropdown_model_types(deps):
    """Return the de-duplicated model types offered by the dropdowns.

    ``deps.transformer_types`` is used when non-empty, otherwise
    ``deps.displayed_model_types``. ``deps.transformer_type`` is appended when
    it is not already listed. First-occurrence order is preserved.
    """
    if len(deps.transformer_types) > 0:
        source = deps.transformer_types
    else:
        source = deps.displayed_model_types
    # A dict keeps insertion order, so it doubles as an ordered set here.
    ordered = dict.fromkeys(source)
    ordered.setdefault(deps.transformer_type)
    return list(ordered)
|
|
|
|
|
def get_family_dropdown_model_types(deps, current_model_family, dropdown_types=None):
    """Return the dropdown model types that belong to ``current_model_family``.

    ``dropdown_types`` defaults to ``get_dropdown_model_types(deps)``. When
    ``current_model_family`` is None no filtering happens and the list is
    returned as is (the very same object when one was passed in).
    """
    if dropdown_types is None:
        dropdown_types = get_dropdown_model_types(deps)
    if current_model_family is None:
        return dropdown_types
    matching = []
    for model_type in dropdown_types:
        if deps.get_model_family(model_type, for_ui=True) == current_model_family:
            matching.append(model_type)
    return matching
|
|
|
|
|
def _get_module_files_for_status(deps, model_type, quantization, dtype_policy):
    """Return the filenames of the modules declared for ``model_type``.

    Modules named by a string are first resolved through
    ``deps.get_model_recursive_prop(..., sub_prop_name="_list")``. A module
    given as a dict must define both "URLs" and "URLs2" and contributes one
    filename for each; any other module contributes a single filename.

    Returns None as soon as a dict module lacks "URLs" or "URLs2".
    """
    transformer_dtype = deps.get_transformer_dtype(model_type, dtype_policy)
    declared = deps.get_model_recursive_prop(model_type, "modules", return_list=True)

    # Resolve every string reference before any filename is computed.
    resolved = []
    for module in declared:
        if isinstance(module, str):
            module = deps.get_model_recursive_prop(module, "modules", sub_prop_name="_list", return_list=True)
        resolved.append(module)

    filenames = []
    for module in resolved:
        if not isinstance(module, dict):
            filenames.append(deps.get_model_filename(model_type, quantization, transformer_dtype, module_type=module))
            continue
        for url_key in ("URLs", "URLs2"):
            urls = module.get(url_key, None)
            if urls is None:
                return None
            filenames.append(deps.get_model_filename(model_type, quantization, transformer_dtype, URLs=urls))
    return filenames
|
|
|
|
|
def _get_status_quantization_and_dtype(deps):
    """Return ``(quantization, dtype_policy)`` used for status checks.

    Values found in ``deps.server_config`` win; the attributes on ``deps`` are
    the fallbacks.
    """
    config = deps.server_config
    return (
        config.get("transformer_quantization", deps.transformer_quantization),
        config.get("transformer_dtype_policy", deps.transformer_dtype_policy),
    )
|
|
|
|
|
def _append_expected_file_entry(entries, seen, filename, extra_paths=None):
    """Append a ``{"filename", "extra_paths"}`` entry unless it is a duplicate.

    ``filename`` must be a non-empty string, otherwise nothing happens.
    ``extra_paths`` may be None, a string, or a list; only non-empty strings
    are kept, and an empty result is stored as None. Duplicates are detected
    case-insensitively on the filename plus the kept extra paths, using
    ``seen``, which is updated in place along with ``entries``.
    """
    if not (isinstance(filename, str) and len(filename) > 0):
        return
    candidates = extra_paths if isinstance(extra_paths, list) else [extra_paths]
    folders = [path for path in candidates if isinstance(path, str) and len(path) > 0]
    dedupe_key = (filename.casefold(), tuple(path.casefold() for path in folders))
    if dedupe_key in seen:
        return
    seen.add(dedupe_key)
    entries.append({"filename": filename, "extra_paths": folders or None})
|
|
|
|
|
def _append_expected_local_path_entry(entries, seen, local_path):
    """Append a ``{"path": local_path}`` entry unless it is a duplicate.

    Anything other than a non-empty string is ignored. Duplicates are detected
    case-insensitively through ``seen``, which is updated in place along with
    ``entries``.
    """
    if not (isinstance(local_path, str) and len(local_path) > 0):
        return
    dedupe_key = local_path.casefold()
    if dedupe_key not in seen:
        seen.add(dedupe_key)
        entries.append({"path": local_path})
|
|
|
|
|
def get_expected_core_file_entries_for_status(deps, model_type):
    """List the core files a model is expected to have locally.

    The result is a list of ``{"filename", "extra_paths"}`` dicts without
    duplicates, in this order: the main model file, the second sub-model file
    (only when the definition has "URLs2"), the module files, then the text
    encoder file (looked up with the definition's "text_encoder_folder").

    Returns an empty list when ``model_type`` has no definition.
    """
    model_def = deps.get_model_def(model_type)
    if model_def is None:
        return []

    quantization, dtype_policy = _get_status_quantization_and_dtype(deps)
    entries = []
    seen = set()

    def add(filename, extra_paths=None):
        _append_expected_file_entry(entries, seen, filename, extra_paths=extra_paths)

    add(deps.get_model_filename(model_type, quantization=quantization, dtype_policy=dtype_policy))
    has_second_submodel = isinstance(model_def, dict) and "URLs2" in model_def
    if has_second_submodel:
        add(deps.get_model_filename(model_type, quantization=quantization, dtype_policy=dtype_policy, submodel_no=2))

    # None means a module definition was incomplete; nothing is listed then.
    module_files = _get_module_files_for_status(deps, model_type, quantization, dtype_policy)
    if isinstance(module_files, list):
        for module_file in module_files:
            add(module_file)

    encoder_urls = deps.get_model_recursive_prop(model_type, "text_encoder_URLs", return_list=True)
    if encoder_urls is not None:
        encoder_file = deps.get_model_filename(
            model_type=model_type,
            quantization=deps.text_encoder_quantization,
            dtype_policy=dtype_policy,
            URLs=encoder_urls,
        )
        encoder_folder = model_def.get("text_encoder_folder", None)
        add(encoder_file, extra_paths=encoder_folder)
    return entries
|
|
|
|
|
def get_missing_core_file_entries_for_status(deps, model_type):
    """Return the expected core file entries that have no local file.

    An entry counts as missing when ``deps.get_local_model_filename`` returns
    None for its filename and extra paths.
    """

    def is_missing(entry):
        located = deps.get_local_model_filename(
            entry.get("filename", ""),
            extra_paths=entry.get("extra_paths", None),
        )
        return located is None

    expected = get_expected_core_file_entries_for_status(deps, model_type)
    return [entry for entry in expected if is_missing(entry)]
|
|
|
|
|
def get_expected_secondary_file_entries_for_status(deps, model_type):
    """List the secondary files a model is expected to have locally.

    Preload URLs and the definition's "VAE_URLs" become ``{"filename", ...}``
    entries; each LoRA URL becomes a ``{"path": ...}`` entry pointing at its
    basename inside ``deps.get_lora_dir(model_type)``. Non-string or empty
    values are skipped and duplicates are dropped.

    Returns an empty list when ``model_type`` has no definition.
    """
    model_def = deps.get_model_def(model_type)
    if model_def is None:
        return []

    def as_list(value):
        # Accept None, a single value, or a list.
        if value is None:
            return []
        return value if isinstance(value, list) else [value]

    def is_text(value):
        return isinstance(value, str) and len(value) > 0

    entries = []
    seen = set()

    preload_urls = as_list(deps.get_model_recursive_prop(model_type, "preload_URLs", return_list=True))
    for url in preload_urls:
        if is_text(url):
            _append_expected_file_entry(entries, seen, url)

    vae_urls = as_list(model_def.get("VAE_URLs", []))
    for url in vae_urls:
        if is_text(url):
            _append_expected_file_entry(entries, seen, url)

    lora_urls = as_list(deps.get_model_recursive_prop(model_type, "loras", return_list=True))
    lora_dir = deps.get_lora_dir(model_type)
    for url in lora_urls:
        if not is_text(url):
            continue
        basename = os.path.basename(url)
        # A URL ending in a separator has no file name to look for.
        if len(basename) > 0:
            _append_expected_local_path_entry(entries, seen, os.path.join(lora_dir, basename))

    return entries
|
|
|
|
|
def has_secondary_model_files_for_status(deps, model_type, quantization, dtype_policy):
    """Tell whether every secondary file of ``model_type`` is present locally.

    Checked in order: the text encoder file, preload URLs, the definition's
    "VAE_URLs", LoRA files (by basename inside the LoRA directory), and module
    files. The first missing item makes the result False. An incomplete module
    definition also yields False.

    Returns True when ``model_type`` has no definition.
    """
    model_def = deps.get_model_def(model_type)
    if model_def is None:
        return True

    def as_list(value):
        # Accept None, a single value, or a list.
        if value is None:
            return []
        return value if isinstance(value, list) else [value]

    def is_text(value):
        return isinstance(value, str) and len(value) > 0

    encoder_urls = deps.get_model_recursive_prop(model_type, "text_encoder_URLs", return_list=True)
    if encoder_urls is not None:
        encoder_file = deps.get_model_filename(
            model_type=model_type,
            quantization=deps.text_encoder_quantization,
            dtype_policy=dtype_policy,
            URLs=encoder_urls,
        )
        if is_text(encoder_file):
            encoder_folder = model_def.get("text_encoder_folder", None)
            if deps.get_local_model_filename(encoder_file, extra_paths=encoder_folder) is None:
                return False

    # Each source is fetched only after the previous one was fully checked.
    url_sources = (
        lambda: deps.get_model_recursive_prop(model_type, "preload_URLs", return_list=True),
        lambda: model_def.get("VAE_URLs", []),
    )
    for fetch_urls in url_sources:
        for url in as_list(fetch_urls()):
            if is_text(url) and deps.get_local_model_filename(url) is None:
                return False

    lora_urls = as_list(deps.get_model_recursive_prop(model_type, "loras", return_list=True))
    lora_dir = deps.get_lora_dir(model_type)
    for url in lora_urls:
        if is_text(url) and not os.path.isfile(os.path.join(lora_dir, os.path.basename(url))):
            return False

    module_files = _get_module_files_for_status(deps, model_type, quantization, dtype_policy)
    if module_files is None:
        return False
    for filename in module_files:
        if is_text(filename) and deps.get_local_model_filename(filename) is None:
            return False
    return True
|
|
|
|
|
def get_model_download_status(deps, model_type):
    """Classify how much of ``model_type`` is available locally.

    * EXPECTED - every expected main file exists and so do the secondary files.
    * PARTIAL  - the main files exist but a secondary file is missing, or only
      some main files exist, or none exist but another file listed under
      "URLs"/"URLs2" does.
    * MISSING  - nothing was found.

    The expected main files are the one resolved for the configured
    quantization and dtype policy, plus the second sub-model file when the
    definition has "URLs2".
    """
    quantization, dtype_policy = _get_status_quantization_and_dtype(deps)
    model_def = deps.get_model_def(model_type)

    resolved_names = [deps.get_model_filename(model_type, quantization=quantization, dtype_policy=dtype_policy)]
    if isinstance(model_def, dict) and "URLs2" in model_def:
        resolved_names.append(
            deps.get_model_filename(model_type, quantization=quantization, dtype_policy=dtype_policy, submodel_no=2)
        )
    expected_filenames = [name for name in resolved_names if isinstance(name, str) and len(name) > 0]

    found = [deps.get_local_model_filename(name) is not None for name in expected_filenames]
    if len(found) > 0 and all(found):
        if has_secondary_model_files_for_status(deps, model_type, quantization, dtype_policy):
            return MODEL_FILE_STATUS_EXPECTED
        return MODEL_FILE_STATUS_PARTIAL
    if any(found):
        return MODEL_FILE_STATUS_PARTIAL

    # No expected file exists: look for any other variant the definition lists.
    candidate_urls = []
    for prop in ("URLs", "URLs2"):
        urls = deps.get_model_recursive_prop(model_type, prop, return_list=True)
        if not isinstance(urls, list):
            urls = [urls] if urls else []
        candidate_urls += urls

    # Seeding with the expected names skips them exactly like repeated candidates.
    already_handled = {name.casefold() for name in expected_filenames}
    for candidate in candidate_urls:
        if not isinstance(candidate, str) or len(candidate) == 0:
            continue
        candidate_key = candidate.casefold()
        if candidate_key in already_handled:
            continue
        already_handled.add(candidate_key)
        if deps.get_local_model_filename(candidate) is not None:
            return MODEL_FILE_STATUS_PARTIAL
    return MODEL_FILE_STATUS_MISSING
|
|
|
|
|
def get_model_download_status_maps(deps, dropdown_types=None):
    """Compute download statuses for a list of model types.

    Returns ``(direct, aggregated)``:

    * ``direct`` maps every model type that has a definition to its own status.
    * ``aggregated`` starts as a copy of ``direct``; each parent type is then
      raised to the best status found among its children (EXPECTED beats
      PARTIAL beats MISSING), never lowered.

    ``dropdown_types`` defaults to ``get_dropdown_model_types(deps)``.
    """
    if dropdown_types is None:
        dropdown_types = get_dropdown_model_types(deps)

    direct_status_map = {}
    children_by_parent = defaultdict(list)
    for model_type in dropdown_types:
        if deps.get_model_def(model_type) is None:
            continue
        direct_status_map[model_type] = get_model_download_status(deps, model_type)
        parent_model_type = deps.get_parent_model_type(model_type)
        if parent_model_type is not None:
            children_by_parent[parent_model_type].append(model_type)

    aggregated_status_map = dict(direct_status_map)
    for parent_model_type, children in children_by_parent.items():
        child_statuses = {direct_status_map.get(child, MODEL_FILE_STATUS_MISSING) for child in children}
        if MODEL_FILE_STATUS_EXPECTED in child_statuses:
            best_child_status = MODEL_FILE_STATUS_EXPECTED
        elif MODEL_FILE_STATUS_PARTIAL in child_statuses:
            best_child_status = MODEL_FILE_STATUS_PARTIAL
        else:
            best_child_status = MODEL_FILE_STATUS_MISSING
        own_status = aggregated_status_map.get(parent_model_type, MODEL_FILE_STATUS_MISSING)
        aggregated_status_map[parent_model_type] = max(own_status, best_child_status)
    return direct_status_map, aggregated_status_map
|
|
|
|
|
def get_model_download_status_map(deps, dropdown_types=None):
    """Return only the parent-aggregated map from ``get_model_download_status_maps``."""
    status_maps = get_model_download_status_maps(deps, dropdown_types)
    aggregated_status_map = status_maps[1]
    return aggregated_status_map
|
|
|
|
|
def _common_prefix_len(first, second):
    """Return how many leading items the two sequences have in common."""
    shared = 0
    for left, right in zip(first, second):
        if left != right:
            break
        shared += 1
    return shared


def _build_parent_group(pid, parent_name, members):
    """Return ``(header, children)`` for one parent and its ``(name, id)`` members.

    ``members`` must contain the parent's own row (the one whose id is ``pid``).
    """
    parent_tokens = parent_name.split()
    parent_low = [tok.casefold() for tok in parent_tokens]
    token_count = len(parent_low)
    # A blank parent name has no tokens. None never equals a real token, so
    # every comparison below is simply False instead of raising IndexError.
    parent_first = parent_low[0] if parent_low else None
    parent_last = parent_low[-1] if parent_low else None
    parent_words = set(parent_low)

    kids = []
    for name, mid in members:
        tokens = name.split()
        kids.append((name, mid, tokens, [tok.casefold() for tok in tokens]))

    # Members that share no word with the parent keep their full name.
    outliers = {mid for _, mid, _, lowered in kids if mid != pid and parent_words.isdisjoint(lowered)}

    # The header is the word prefix common to the parent and every member
    # that starts with the parent's first word.
    prefix_candidates = [
        lowered
        for _, mid, _, lowered in kids
        if mid == pid or (mid not in outliers and lowered and lowered[0] == parent_first)
    ]
    if len(prefix_candidates) <= 1:
        prefix_len = token_count
    else:
        prefix_len = min(_common_prefix_len(lowered, parent_low) for lowered in prefix_candidates)
        if prefix_len == 0:
            prefix_len = token_count

    # When a sibling ends with the parent's last word (e.g. a size suffix),
    # that word is moved into the header too.
    shares_last = any(
        mid != pid and mid not in outliers and lowered and lowered[-1] == parent_last
        for _, mid, _, lowered in kids
    )
    append_last = shares_last and prefix_len < token_count
    header = " ".join(parent_tokens[:prefix_len] + ([parent_tokens[-1]] if append_last else []))
    header_has_last = prefix_len == token_count or append_last
    prefix_low = parent_low[:prefix_len]

    def strip_prefix(tokens, lowered):
        # Drop the header prefix when the name starts with it.
        if prefix_len > 0 and lowered[:prefix_len] == prefix_low:
            return tokens[prefix_len:]
        return tokens[:]

    def strip_last(rest, lowered):
        # Drop the trailing word when the header already shows it.
        if header_has_last and lowered and lowered[-1] == parent_last and rest and rest[-1].casefold() == parent_last:
            return rest[:-1]
        return rest[:]

    infos = []
    for name, mid, tokens, lowered in kids:
        if mid in outliers:
            infos.append({"name": name, "mid": mid, "outlier": True, "rem_trim": tokens[:], "rem_set": set()})
            continue
        rest = strip_prefix(tokens, lowered)
        infos.append({
            "name": name,
            "mid": mid,
            "outlier": False,
            "rem_trim": strip_last(rest, lowered),
            "rem_set": {word.casefold() for word in rest},
        })

    default_info = next(info for info in infos if info["mid"] == pid)
    other_words = set()
    for info in infos:
        if info["mid"] != pid:
            other_words |= info["rem_set"]
    # The parent keeps a descriptive label only when it shares a leftover
    # word with a sibling; otherwise it is shown as "Default".
    default_shares = bool(default_info["rem_set"] & other_words)

    def label(info):
        if info["outlier"]:
            return info["name"]
        if info["mid"] == pid and not default_shares:
            return "Default"
        text = " ".join(info["rem_trim"]).strip()
        return text if text else "Default"

    entries = [(label(default_info), pid)]
    entries.extend((label(info), info["mid"]) for info in infos if info["mid"] != pid)

    def normalized(text):
        return " ".join(text.split()).casefold()

    # Siblings whose full name equals the parent's are numbered so they
    # do not all read "Default".
    parent_norm = normalized(parent_name)
    full_names = {mid: name for name, mid, *_ in kids}
    numbered = [entries[0]]
    next_number = 2
    for text, mid in entries[1:]:
        if text == "Default" and normalized(full_names[mid]) == parent_norm:
            numbered.append((f"Default #{next_number}", mid))
            next_number += 1
        else:
            numbered.append((text, mid))
    return header, numbered


def create_models_hierarchy(rows):
    """Group models under their parent and shorten the displayed names.

    rows: list of (model_name, model_id, parent_model_id)
    returns:
        parents_list: list[(parent_header, parent_id)], sorted by header
        children_dict: dict[parent_id] -> list[(child_display_name, child_id)]

    A row whose model_id equals its parent_model_id defines the parent. Its
    children are listed parent first, then in row order, with the words
    already shown in the header removed from their names. Groups whose parent
    has no row of its own use their first member's name as header and keep
    the children's names unchanged.

    A parent with an empty or whitespace-only name gets an empty header, and
    its members keep their full names, instead of raising IndexError.
    """
    groups = defaultdict(list)
    parents = {}  # insertion-ordered: parent_id -> parent name
    for name, mid, pmid in rows:
        groups[pmid].append((name, mid))
        if mid == pmid and pmid not in parents:
            parents[pmid] = name

    parents_list, children_dict = [], {}
    for pid, parent_name in parents.items():
        header, children = _build_parent_group(pid, parent_name, groups[pid])
        parents_list.append((header, pid))
        children_dict[pid] = children

    for pid, members in groups.items():
        if pid in parents:
            continue
        parents_list.append((members[0][0], pid))
        children_dict[pid] = list(members)

    parents_list.sort(key=lambda item: item[0])
    return parents_list, children_dict
|
|
|
|
|
def get_sorted_dropdown(deps, dropdown_types, current_model_family, current_model_type, three_levels=True):
    """Build the sorted choices for the family and model dropdowns.

    Returns a 3-tuple whose first item is always the list of
    ``(family_label, family_id)`` sorted by the family's configured order.

    * ``three_levels`` true: ``(families, parent_choices, finetune_choices)``
      where the last item holds the children of ``current_model_type``'s
      parent, as produced by ``create_models_hierarchy``.
    * ``three_levels`` false: ``(families, base_model_types, model_choices)``
      where ``model_choices`` is a list of ``(name, model_type)``.

    Models are limited to ``current_model_family`` unless it is None, in which
    case all models are listed under their full names.
    """
    infos = deps.families_infos
    models_families = [deps.get_model_family(model_type, for_ui=True) for model_type in dropdown_types]

    # dict.fromkeys keeps one entry per family in first-seen order.
    ranked_families = sorted(
        ((infos[family][0], infos[family][1], family) for family in dict.fromkeys(models_families)),
        key=lambda item: item[0],
    )
    sorted_families = [item[1:] for item in ranked_families]

    ranked_models = []
    for model_type, family in zip(dropdown_types, models_families):
        if current_model_family is None:
            ranked_models.append((infos[family][0], deps.get_model_name(model_type), model_type))
        elif family == current_model_family:
            short_name = compact_name(infos[family][1], deps.get_model_name(model_type))
            ranked_models.append((infos[family][0], short_name, model_type))
    ranked_models.sort(key=lambda item: (item[0], item[1]))

    if three_levels:
        rows = [(name, model_type, deps.get_parent_model_type(model_type)) for _, name, model_type in ranked_models]
        sorted_choices, finetunes_dict = create_models_hierarchy(rows)
        current_parent = deps.get_parent_model_type(current_model_type)
        return sorted_families, sorted_choices, finetunes_dict[current_parent]

    base_types = list({deps.get_base_model_type(model_type) for _, _, model_type in ranked_models})
    model_choices = [(name, model_type) for _, name, model_type in ranked_models]
    return sorted_families, base_types, model_choices
|
|
|
|
|
def generate_dropdown_model_list(deps, current_model_type):
    """Create the family, model and finetune dropdowns for ``current_model_type``.

    The offered model types are ``deps.transformer_types`` (or
    ``deps.displayed_model_types`` when that is empty) plus
    ``current_model_type``. Choice labels are prefixed with a download-status
    symbol: parent-aggregated status for the model dropdown, direct status for
    the finetune dropdown.

    Returns ``(dropdown_families, dropdown_models, dropdown_finetunes)`` as
    ``gr.Dropdown`` components.
    """
    if len(deps.transformer_types) > 0:
        dropdown_types = list(deps.transformer_types)
    else:
        dropdown_types = list(deps.displayed_model_types)
    if current_model_type not in dropdown_types:
        dropdown_types.append(current_model_type)

    current_model_family = deps.get_model_family(current_model_type, for_ui=True)
    family_choices, model_choices, finetune_choices = get_sorted_dropdown(
        deps, dropdown_types, current_model_family, current_model_type, three_levels=deps.three_levels_hierarchy
    )

    status_model_types = get_family_dropdown_model_types(deps, current_model_family, dropdown_types)
    if current_model_type not in status_model_types:
        status_model_types.append(current_model_type)
    direct_status_map, aggregated_parent_status_map = get_model_download_status_maps(deps, status_model_types)
    model_choices = decorate_dropdown_choices_with_status(model_choices, aggregated_parent_status_map)
    finetune_choices = decorate_dropdown_choices_with_status(finetune_choices, direct_status_map)

    dropdown_families = gr.Dropdown(
        choices=family_choices,
        value=current_model_family,
        show_label=False,
        scale=2 if deps.three_levels_hierarchy else 1,
        elem_id="family_list",
        min_width=50,
    )

    if deps.three_levels_hierarchy:
        selected_model = deps.get_parent_model_type(current_model_type)
    else:
        selected_model = deps.get_base_model_type(current_model_type)
    dropdown_models = gr.Dropdown(
        choices=model_choices,
        value=selected_model,
        show_label=False,
        # The model dropdown takes the finetune dropdown's room when that one is hidden.
        scale=3 if len(finetune_choices) > 1 else 7,
        elem_id="model_base_types_list",
        visible=deps.three_levels_hierarchy,
    )

    dropdown_finetunes = gr.Dropdown(
        choices=finetune_choices,
        value=current_model_type,
        show_label=False,
        scale=4,
        visible=len(finetune_choices) > 1 or not deps.three_levels_hierarchy,
        elem_id="model_list",
    )
    return dropdown_families, dropdown_models, dropdown_finetunes
|
|
|
|
|
def change_model_family(deps, state, current_model_family):
    """Rebuild the model and finetune dropdowns after the family changed.

    The model selected is the one remembered for this family in
    ``state["last_model_per_family"]`` when it is still offered, otherwise the
    first model of the family (sorted by compacted name).

    Args:
        deps: DropdownDeps-like bundle of settings and lookup callables.
        state: dict-like UI state; only "last_model_per_family" is read.
        current_model_family: family id, a key of ``deps.families_infos``.

    Returns:
        ``(models_dropdown, finetunes_dropdown)`` as ``gr.Dropdown`` updates.

    Raises:
        IndexError: when the family has no model among the offered types.
    """
    if len(deps.transformer_types) > 0:
        dropdown_types = list(deps.transformer_types)
    else:
        dropdown_types = list(deps.displayed_model_types)
    current_family_name = deps.families_infos[current_model_family][1]
    models_families = [deps.get_model_family(t, for_ui=True) for t in dropdown_types]
    dropdown_choices = [
        (compact_name(current_family_name, deps.get_model_name(model_type)), model_type)
        for model_type, family in zip(dropdown_types, models_families)
        if family == current_model_family
    ]
    dropdown_choices = sorted(dropdown_choices, key=lambda c: c[0])
    family_dropdown_types = [choice[1] for choice in dropdown_choices]
    direct_status_map, aggregated_parent_status_map = get_model_download_status_maps(deps, family_dropdown_types)

    last_model_per_family = state.get("last_model_per_family", {})
    model_type = last_model_per_family.get(current_model_family, "")
    # Fixed: the original compared len(model_type) to "" (never true) and
    # raised TypeError on a remembered None; a truthiness test covers both.
    if not model_type or model_type not in family_dropdown_types:
        model_type = dropdown_choices[0][1]

    if deps.three_levels_hierarchy:
        parent_model_type = deps.get_parent_model_type(model_type)
        rows = [(*choice, deps.get_parent_model_type(choice[1])) for choice in dropdown_choices]
        dropdown_base_types_choices, finetunes_dict = create_models_hierarchy(rows)
        dropdown_choices = decorate_dropdown_choices_with_status(finetunes_dict[parent_model_type], direct_status_map)
        dropdown_base_types_choices = decorate_dropdown_choices_with_status(dropdown_base_types_choices, aggregated_parent_status_map)
        # A lone finetune needs no dropdown of its own.
        model_finetunes_visible = len(dropdown_choices) > 1
    else:
        parent_model_type = deps.get_base_model_type(model_type)
        model_finetunes_visible = True
        dropdown_base_types_choices = list({deps.get_base_model_type(model[1]) for model in dropdown_choices})
        dropdown_choices = decorate_dropdown_choices_with_status(dropdown_choices, direct_status_map)

    return (
        gr.Dropdown(choices=dropdown_base_types_choices, value=parent_model_type, scale=3 if model_finetunes_visible else 7),
        gr.Dropdown(choices=dropdown_choices, value=model_type, visible=model_finetunes_visible),
    )
|
|
|
|
|
def change_model_base_types(deps, state, current_model_family, model_base_type_choice):
    """Rebuild the finetune dropdown after the base model changed.

    The finetune selected is the one remembered for this base model in
    ``state["last_model_per_type"]`` when it is still offered, otherwise the
    first entry of the rebuilt list.

    Args:
        deps: DropdownDeps-like bundle of settings and lookup callables.
        state: dict-like UI state; only "last_model_per_type" is read.
        current_model_family: family id, a key of ``deps.families_infos``.
        model_base_type_choice: parent model type picked in the model dropdown.

    Returns:
        ``(models_dropdown_update, finetunes_dropdown)`` in three-level mode.
        In two-level mode a single ``gr.update()`` is returned.

    Raises:
        KeyError: when no offered model has ``model_base_type_choice`` as its
            parent within the family.
    """
    if not deps.three_levels_hierarchy:
        # NOTE(review): one value here versus two below - confirm this matches
        # the outputs wired to the event in two-level mode.
        return gr.update()

    if len(deps.transformer_types) > 0:
        dropdown_types = list(deps.transformer_types)
    else:
        dropdown_types = list(deps.displayed_model_types)
    current_family_name = deps.families_infos[current_model_family][1]
    dropdown_choices = [
        (compact_name(current_family_name, deps.get_model_name(model_type)), model_type, model_base_type_choice)
        for model_type in dropdown_types
        if deps.get_parent_model_type(model_type) == model_base_type_choice
        and deps.get_model_family(model_type, for_ui=True) == current_model_family
    ]
    dropdown_choices = sorted(dropdown_choices, key=lambda c: c[0])
    _, finetunes_dict = create_models_hierarchy(dropdown_choices)
    base_dropdown_types = [choice[1] for choice in dropdown_choices]
    direct_status_map, _ = get_model_download_status_maps(deps, base_dropdown_types)
    dropdown_choices = decorate_dropdown_choices_with_status(finetunes_dict[model_base_type_choice], direct_status_map)
    # A lone finetune needs no dropdown of its own.
    model_finetunes_visible = len(dropdown_choices) > 1

    last_model_per_type = state.get("last_model_per_type", {})
    model_type = last_model_per_type.get(model_base_type_choice, "")
    # Fixed: the original compared len(model_type) to "" (never true) and
    # raised TypeError on a remembered None; a truthiness test covers both.
    if not model_type or model_type not in [choice[1] for choice in dropdown_choices]:
        model_type = dropdown_choices[0][1]
    return (
        gr.update(scale=3 if model_finetunes_visible else 7),
        gr.Dropdown(choices=dropdown_choices, value=model_type, visible=model_finetunes_visible),
    )
|
|
|