File size: 26,964 Bytes
4689c2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
import os
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass

import gradio as gr


# Download-status codes for a model's files. Larger means "more complete":
# get_model_download_status_maps combines parent statuses with max().
MODEL_FILE_STATUS_MISSING = 0   # no usable local file found
MODEL_FILE_STATUS_PARTIAL = 1   # some, but not all, required files are present
MODEL_FILE_STATUS_EXPECTED = 2  # expected files and their secondary files are present

# Emoji badge prepended to dropdown labels for each status code.
MODEL_STATUS_PREFIXES = {
    MODEL_FILE_STATUS_MISSING: "\u2b1b",        # BLACK LARGE SQUARE
    MODEL_FILE_STATUS_EXPECTED: "\U0001f7e6",   # LARGE BLUE SQUARE
    MODEL_FILE_STATUS_PARTIAL: "\U0001f7e8",    # LARGE YELLOW SQUARE
}


@dataclass
class DropdownDeps:
    """Bundle of application state and callbacks used by the dropdown helpers.

    The helpers in this module read configuration values from these fields and
    call the callback fields instead of importing application globals directly.
    Callback fields are annotated with ``Callable`` (the builtin ``callable`` is
    a predicate function, not a type).
    """

    # Model types to list; when empty, displayed_model_types is used instead.
    transformer_types: list
    displayed_model_types: list
    # Currently selected model type (always kept in the dropdown list).
    transformer_type: str
    # True: family / base model / finetune dropdowns; False: two-level layout.
    three_levels_hierarchy: bool
    # family -> indexable whose [0] is a sort key and [1] is a display label.
    families_infos: dict
    # Server settings; may override the transformer quantization / dtype policy.
    server_config: dict
    transformer_quantization: str
    transformer_dtype_policy: str
    text_encoder_quantization: str
    get_model_def: Callable
    get_model_recursive_prop: Callable
    get_model_filename: Callable
    get_local_model_filename: Callable
    get_lora_dir: Callable
    get_parent_model_type: Callable
    get_base_model_type: Callable
    get_model_family: Callable
    get_model_name: Callable
    get_transformer_dtype: Callable


def compact_name(family_name, model_name):
    """Remove a leading *family_name* from *model_name* for display.

    When *model_name* starts with *family_name*, the remainder is returned with
    surrounding whitespace stripped; otherwise *model_name* is returned as is.
    """
    if not model_name.startswith(family_name):
        return model_name
    remainder = model_name[len(family_name):]
    return remainder.strip()


def decorate_model_dropdown_label(label, status):
    """Prefix a string *label* with the badge for *status*.

    Non-string labels, and statuses without an entry in MODEL_STATUS_PREFIXES
    (or with an empty prefix), are returned unchanged.
    """
    if isinstance(label, str):
        badge = MODEL_STATUS_PREFIXES.get(status, "")
        if badge:
            return f"{badge} {label}"
    return label


def decorate_dropdown_choices_with_status(choices, status_map):
    """Return a new list of *choices* with status badges added to their labels.

    Each choice that is a tuple ``(label, value, *rest)`` gets its label
    prefixed with the badge of ``status_map[value]``. Values missing from the
    map count as MODEL_FILE_STATUS_MISSING. Anything else (plain strings,
    tuples shorter than two) is passed through untouched.
    """
    def _decorate(choice):
        if isinstance(choice, tuple) and len(choice) >= 2:
            label, value, *rest = choice
            badge_status = status_map.get(value, MODEL_FILE_STATUS_MISSING)
            return (decorate_model_dropdown_label(label, badge_status), value, *rest)
        return choice

    return [_decorate(choice) for choice in choices]


def get_dropdown_model_types(deps):
    """Return the de-duplicated model types to offer in the dropdowns.

    Uses ``deps.transformer_types`` when it is non-empty, otherwise
    ``deps.displayed_model_types``. First-occurrence order is kept, and
    ``deps.transformer_type`` is appended at the end if it is not already listed.
    """
    source = deps.transformer_types if len(deps.transformer_types) > 0 else deps.displayed_model_types
    ordered = dict.fromkeys(source)
    # setdefault only inserts (at the end) when the current type is absent.
    ordered.setdefault(deps.transformer_type, None)
    return list(ordered)


def get_family_dropdown_model_types(deps, current_model_family, dropdown_types=None):
    """Return the dropdown model types whose UI family is *current_model_family*.

    *dropdown_types* defaults to ``get_dropdown_model_types(deps)``. When
    *current_model_family* is None, that list object itself is returned
    unfiltered. Otherwise a new list is returned, keeping the original order.
    """
    candidates = get_dropdown_model_types(deps) if dropdown_types is None else dropdown_types
    if current_model_family is not None:
        return list(filter(lambda model_type: deps.get_model_family(model_type, for_ui=True) == current_model_family, candidates))
    return candidates


def _get_module_files_for_status(deps, model_type, quantization, dtype_policy):
    transformer_dtype = deps.get_transformer_dtype(model_type, dtype_policy)
    modules = deps.get_model_recursive_prop(model_type, "modules", return_list=True)
    modules = [deps.get_model_recursive_prop(module, "modules", sub_prop_name="_list", return_list=True) if isinstance(module, str) else module for module in modules]
    module_files = []
    for module_type in modules:
        if isinstance(module_type, dict):
            URLs1 = module_type.get("URLs", None)
            if URLs1 is None:
                return None
            module_files.append(deps.get_model_filename(model_type, quantization, transformer_dtype, URLs=URLs1))
            URLs2 = module_type.get("URLs2", None)
            if URLs2 is None:
                return None
            module_files.append(deps.get_model_filename(model_type, quantization, transformer_dtype, URLs=URLs2))
        else:
            module_files.append(deps.get_model_filename(model_type, quantization, transformer_dtype, module_type=module_type))
    return module_files


def _get_status_quantization_and_dtype(deps):
    quantization = deps.server_config.get("transformer_quantization", deps.transformer_quantization)
    dtype_policy = deps.server_config.get("transformer_dtype_policy", deps.transformer_dtype_policy)
    return quantization, dtype_policy


def _append_expected_file_entry(entries, seen, filename, extra_paths=None):
    if not isinstance(filename, str) or len(filename) == 0:
        return
    if extra_paths is None:
        extra_list = []
    elif isinstance(extra_paths, list):
        extra_list = [path for path in extra_paths if isinstance(path, str) and len(path) > 0]
    else:
        extra_list = [extra_paths] if isinstance(extra_paths, str) and len(extra_paths) > 0 else []
    key = (filename.casefold(), tuple(path.casefold() for path in extra_list))
    if key in seen:
        return
    seen.add(key)
    entries.append({"filename": filename, "extra_paths": extra_list if len(extra_list) > 0 else None})


def _append_expected_local_path_entry(entries, seen, local_path):
    if not isinstance(local_path, str) or len(local_path) == 0:
        return
    path_key = local_path.casefold()
    if path_key in seen:
        return
    seen.add(path_key)
    entries.append({"path": local_path})


def get_expected_core_file_entries_for_status(deps, model_type):
    """List the de-duplicated core files *model_type* needs.

    Entries are ``{"filename": ..., "extra_paths": ...}`` dicts, produced in
    this order:

    * the main transformer file;
    * the second transformer file, when the model definition has ``URLs2``;
    * module files (skipped entirely if ``_get_module_files_for_status``
      returns None);
    * the text encoder, with the definition's ``text_encoder_folder`` as extra
      search path.

    Returns ``[]`` when *model_type* has no definition.
    """
    model_def = deps.get_model_def(model_type)
    if model_def is None:
        return []
    quantization, dtype_policy = _get_status_quantization_and_dtype(deps)
    entries, seen = [], set()

    def add(filename, extra_paths=None):
        _append_expected_file_entry(entries, seen, filename, extra_paths=extra_paths)

    add(deps.get_model_filename(model_type, quantization=quantization, dtype_policy=dtype_policy))
    if isinstance(model_def, dict) and "URLs2" in model_def:
        add(deps.get_model_filename(model_type, quantization=quantization, dtype_policy=dtype_policy, submodel_no=2))

    module_files = _get_module_files_for_status(deps, model_type, quantization, dtype_policy)
    for module_file in module_files if isinstance(module_files, list) else []:
        add(module_file)

    text_encoder_urls = deps.get_model_recursive_prop(model_type, "text_encoder_URLs", return_list=True)
    if text_encoder_urls is not None:
        text_encoder_filename = deps.get_model_filename(model_type=model_type, quantization=deps.text_encoder_quantization, dtype_policy=dtype_policy, URLs=text_encoder_urls)
        add(text_encoder_filename, extra_paths=model_def.get("text_encoder_folder", None))
    return entries


def get_missing_core_file_entries_for_status(deps, model_type):
    """Return the core file entries of *model_type* that are not found locally.

    An entry counts as missing when ``deps.get_local_model_filename`` returns
    None for its filename, searched together with its ``extra_paths``.
    """
    def _is_missing(entry):
        local = deps.get_local_model_filename(entry.get("filename", ""), extra_paths=entry.get("extra_paths", None))
        return local is None

    return [entry for entry in get_expected_core_file_entries_for_status(deps, model_type) if _is_missing(entry)]


def get_expected_secondary_file_entries_for_status(deps, model_type):
    """List the de-duplicated auxiliary files of *model_type*.

    Includes:

    * ``preload_URLs`` (looked up through ``get_model_recursive_prop``) and
      ``VAE_URLs`` (from the model definition) as ``{"filename", "extra_paths"}``
      entries;
    * each ``loras`` URL as a ``{"path"}`` entry pointing at its basename inside
      ``deps.get_lora_dir(model_type)``. URLs with an empty basename are skipped.

    Returns ``[]`` when *model_type* has no definition.
    """
    model_def = deps.get_model_def(model_type)
    if model_def is None:
        return []

    def as_list(value):
        # Normalize None / scalar / list property values to a list.
        if value is None:
            return []
        return value if isinstance(value, list) else [value]

    entries, seen = [], set()

    preload_urls = as_list(deps.get_model_recursive_prop(model_type, "preload_URLs", return_list=True))
    vae_urls = as_list(model_def.get("VAE_URLs", []))
    for url in preload_urls + vae_urls:
        if isinstance(url, str) and url:
            _append_expected_file_entry(entries, seen, url)

    lora_urls = as_list(deps.get_model_recursive_prop(model_type, "loras", return_list=True))
    lora_dir = deps.get_lora_dir(model_type)
    for url in lora_urls:
        if not (isinstance(url, str) and url):
            continue
        lora_name = os.path.basename(url)
        if lora_name:
            _append_expected_local_path_entry(entries, seen, os.path.join(lora_dir, lora_name))

    return entries


def has_secondary_model_files_for_status(deps, model_type, quantization, dtype_policy):
    """Return True when all auxiliary files of *model_type* are available locally.

    Checks run in this order and stop at the first missing file:

    * the text encoder (``text_encoder_URLs``), also searched in the model
      definition's ``text_encoder_folder``;
    * ``preload_URLs`` (resolved through ``get_model_recursive_prop``) and
      ``VAE_URLs`` (read from the model definition only);
    * LoRAs listed under ``loras``, expected in ``deps.get_lora_dir(model_type)``
      under their URL basename. URLs without a basename are ignored, matching
      get_expected_secondary_file_entries_for_status;
    * module files from ``_get_module_files_for_status``. A None result (a module
      dict lacking ``URLs``/``URLs2``) counts as missing.

    Returns True when *deps* has no definition for *model_type*.
    """
    model_def = deps.get_model_def(model_type)
    if model_def is None:
        return True

    text_encoder_URLs = deps.get_model_recursive_prop(model_type, "text_encoder_URLs", return_list=True)
    if text_encoder_URLs is not None:
        text_encoder_filename = deps.get_model_filename(model_type=model_type, quantization=deps.text_encoder_quantization, dtype_policy=dtype_policy, URLs=text_encoder_URLs)
        if isinstance(text_encoder_filename, str) and len(text_encoder_filename) > 0:
            text_encoder_folder = model_def.get("text_encoder_folder", None)
            if deps.get_local_model_filename(text_encoder_filename, extra_paths=text_encoder_folder) is None:
                return False

    for prop, recursive in (("preload_URLs", True), ("VAE_URLs", False)):
        if recursive:
            urls = deps.get_model_recursive_prop(model_type, prop, return_list=True)
        else:
            urls = model_def.get(prop, [])
        if urls is None:
            continue
        if not isinstance(urls, list):
            urls = [urls]
        for url in urls:
            if not isinstance(url, str) or len(url) == 0:
                continue
            if deps.get_local_model_filename(url) is None:
                return False

    model_loras = deps.get_model_recursive_prop(model_type, "loras", return_list=True)
    if model_loras is None:
        model_loras = []
    if not isinstance(model_loras, list):
        model_loras = [model_loras]
    lora_dir = deps.get_lora_dir(model_type)
    for url in model_loras:
        if not isinstance(url, str) or len(url) == 0:
            continue
        basename = os.path.basename(url)
        # A URL without a file name (e.g. ending in "/") is never listed as an
        # expected secondary file. Testing join(lora_dir, "") would always fail
        # and leave the model stuck in the PARTIAL state.
        if len(basename) == 0:
            continue
        if not os.path.isfile(os.path.join(lora_dir, basename)):
            return False

    module_files = _get_module_files_for_status(deps, model_type, quantization, dtype_policy)
    if module_files is None:
        return False
    for filename in module_files:
        if not isinstance(filename, str) or len(filename) == 0:
            continue
        if deps.get_local_model_filename(filename) is None:
            return False
    return True


def get_model_download_status(deps, model_type):
    """Classify how much of *model_type* is present on disk.

    Returns:
        MODEL_FILE_STATUS_EXPECTED when every expected transformer file (main,
        plus the second one when the definition has ``URLs2``) is found and
        ``has_secondary_model_files_for_status`` is satisfied.

        MODEL_FILE_STATUS_PARTIAL when the transformer files are all found but
        secondary files are not; or when only some of them are found; or when
        none are found but another URL listed under ``URLs``/``URLs2`` (other
        than the expected filenames) exists locally.

        MODEL_FILE_STATUS_MISSING otherwise.
    """
    quantization, dtype_policy = _get_status_quantization_and_dtype(deps)
    model_def = deps.get_model_def(model_type)

    def usable(name):
        return isinstance(name, str) and len(name) > 0

    main_files = []
    first_file = deps.get_model_filename(model_type, quantization=quantization, dtype_policy=dtype_policy)
    if usable(first_file):
        main_files.append(first_file)
    if isinstance(model_def, dict) and "URLs2" in model_def:
        second_file = deps.get_model_filename(model_type, quantization=quantization, dtype_policy=dtype_policy, submodel_no=2)
        if usable(second_file):
            main_files.append(second_file)

    present = [deps.get_local_model_filename(name) is not None for name in main_files]
    if present and all(present):
        complete = has_secondary_model_files_for_status(deps, model_type, quantization, dtype_policy)
        return MODEL_FILE_STATUS_EXPECTED if complete else MODEL_FILE_STATUS_PARTIAL
    if any(present):
        return MODEL_FILE_STATUS_PARTIAL

    # Nothing expected is on disk: look for any other listed variant.
    candidates = []
    for prop in ("URLs", "URLs2"):
        urls = deps.get_model_recursive_prop(model_type, prop, return_list=True)
        if isinstance(urls, list):
            candidates.extend(urls)
        elif urls:
            candidates.append(urls)

    # Casefolded names already handled: the expected files plus checked candidates.
    skip = {name.casefold() for name in main_files}
    for candidate in candidates:
        if not usable(candidate):
            continue
        key = candidate.casefold()
        if key in skip:
            continue
        skip.add(key)
        if deps.get_local_model_filename(candidate) is not None:
            return MODEL_FILE_STATUS_PARTIAL
    return MODEL_FILE_STATUS_MISSING


def get_model_download_status_maps(deps, dropdown_types=None):
    """Compute per-model and parent-aggregated download statuses.

    *dropdown_types* defaults to ``get_dropdown_model_types(deps)``. Types
    without a model definition are skipped.

    Returns ``(direct_status_map, aggregated_parent_status_map)``:

    * ``direct_status_map`` maps each model type to its own status;
    * ``aggregated_parent_status_map`` is a copy in which each parent type gets
      the best of its own status and its children's. EXPECTED wins over
      PARTIAL, which wins over MISSING.
    """
    if dropdown_types is None:
        dropdown_types = get_dropdown_model_types(deps)

    direct = {}
    children_by_parent = defaultdict(list)
    for model_type in dropdown_types:
        if deps.get_model_def(model_type) is None:
            continue
        direct[model_type] = get_model_download_status(deps, model_type)
        parent = deps.get_parent_model_type(model_type)
        if parent is not None:
            children_by_parent[parent].append(model_type)

    aggregated = dict(direct)
    for parent, children in children_by_parent.items():
        # Every list here is non-empty: it was created by an append above.
        statuses = [direct.get(child, MODEL_FILE_STATUS_MISSING) for child in children]
        if MODEL_FILE_STATUS_EXPECTED in statuses:
            rolled_up = MODEL_FILE_STATUS_EXPECTED
        elif MODEL_FILE_STATUS_PARTIAL in statuses:
            rolled_up = MODEL_FILE_STATUS_PARTIAL
        else:
            rolled_up = MODEL_FILE_STATUS_MISSING
        aggregated[parent] = max(aggregated.get(parent, MODEL_FILE_STATUS_MISSING), rolled_up)
    return direct, aggregated


def get_model_download_status_map(deps, dropdown_types=None):
    """Return only the parent-aggregated status map.

    This is the second item returned by get_model_download_status_maps.
    """
    _direct, aggregated = get_model_download_status_maps(deps, dropdown_types)
    return aggregated


def create_models_hierarchy(rows):
    """Group model rows under their parent and derive compact display labels.

    Args:
        rows: iterable of ``(model_name, model_id, parent_model_id)`` tuples.
            A row whose ``model_id == parent_model_id`` is the parent's own row.

    Returns:
        ``(parents_list, children_dict)``:

        * ``parents_list``: list of ``(parent_header, parent_id)``, sorted by
          header text.
        * ``children_dict``: maps ``parent_id`` to a list of
          ``(child_display_name, child_id)``. When the parent has its own row,
          that row comes first and the others follow in input order.

    Labels for parents that have their own row are derived as follows.

    * Header: the leading run of the parent's words shared by the parent and by
      every non-outlier child starting with the parent's first word. The
      parent's last word is appended when some non-outlier child also ends with
      it.
    * Child labels: the words left after removing that prefix, minus a trailing
      parent last word that the header already shows.
    * Outliers: children that share no word with the parent keep their full
      name.
    * Empty remainders become "Default". Additional children labelled
      "Default" whose normalized name equals the parent's are numbered
      "Default #2", "Default #3", ...

    Parent ids with no row of their own use their first child's name as the
    header and keep full child names as labels.

    NOTE(review): a parent row whose name has no words raises IndexError at
    ``p_last = p_low[-1]``.
    """
    # toks: whitespace-separated words; norm: whitespace-collapsed, casefolded name.
    toks = lambda s: [t for t in s.split() if t]
    norm = lambda s: " ".join(s.split()).casefold()

    # Bucket rows by parent id; remember each parent's own name and first-seen order.
    groups, parents, order = defaultdict(list), {}, []
    for name, mid, pmid in rows:
        groups[pmid].append((name, mid))
        if mid == pmid and pmid not in parents:
            parents[pmid] = name
            order.append(pmid)

    parents_list, children_dict = [], {}

    for pid in order:
        p_name = parents[pid]
        p_tok = toks(p_name)
        p_low = [w.casefold() for w in p_tok]
        n = len(p_low)
        p_last = p_low[-1]
        p_set = set(p_low)

        # Word views of every row in this group (the parent's own row included).
        kids = []
        for name, mid in groups.get(pid, []):
            ot = toks(name)
            lt = [w.casefold() for w in ot]
            st = set(lt)
            kids.append((name, mid, ot, lt, st))

        # Outliers: children sharing no word with the parent; they keep their full name.
        outliers = {mid for _, mid, _, _, st in kids if mid != pid and p_set.isdisjoint(st)}

        # Parent + non-outlier children starting with the parent's first word;
        # these decide how much of the parent's name is a common prefix.
        prefix_non = []
        for name, mid, ot, lt, st in kids:
            if mid == pid or (mid not in outliers and lt and lt[0] == p_low[0]):
                prefix_non.append((ot, lt))

        # Length of the longest common prefix of two word lists.
        def lcp_len(a, b):
            i = 0
            m = min(len(a), len(b))
            while i < m and a[i] == b[i]:
                i += 1
            return i

        # L: number of leading parent words shared by all prefix_non rows (whole
        # name when only the parent qualifies). The L == 0 reset looks
        # unreachable, since every prefix_non row shares the first word.
        L = n if len(prefix_non) <= 1 else min(lcp_len(lt, p_low) for _, lt in prefix_non)
        if L == 0 and len(prefix_non) > 1:
            L = n

        # When a non-outlier child ends with the parent's last word, show it in the header too.
        shares_last = any(mid != pid and mid not in outliers and lt and lt[-1] == p_last for _, mid, _, lt, _ in kids)
        header_tokens_disp = p_tok[:L] + ([p_tok[-1]] if shares_last and L < n else [])
        header = " ".join(header_tokens_disp)
        # True when the header text ends with the parent's last word.
        header_has_last = (L == n) or (shares_last and L < n)

        prefix_low = p_low[:L]

        # True when lt begins with the first L casefolded parent words.
        def startswith_prefix(lt):
            if L == 0 or len(lt) < L:
                return False
            for i in range(L):
                if lt[i] != prefix_low[i]:
                    return False
            return True

        # Words left after removing the header prefix (all words if the prefix does not match).
        def base_rem(ot, lt):
            return ot[L:] if startswith_prefix(lt) else ot[:]

        # Also drop a trailing parent last word when the header already shows it.
        def trim_rem(rem, lt):
            out = rem[:]
            if header_has_last and lt and lt[-1] == p_last and out and out[-1].casefold() == p_last:
                out = out[:-1]
            return out

        # Per-row label data. Note: "rem_trim_set" is computed but never read below.
        kid_infos = []
        for name, mid, ot, lt, _ in kids:
            rem_core = base_rem(ot, lt) if mid not in outliers else ot[:]
            kid_infos.append({
                "name": name,
                "mid": mid,
                "ot": ot,
                "lt": lt,
                "outlier": mid in outliers,
                "rem_core": rem_core,
                "rem_trim": trim_rem(rem_core, lt) if mid not in outliers else ot[:],
                "rem_set": {w.casefold() for w in rem_core} if mid not in outliers else set(),
                "rem_trim_set": {w.casefold() for w in (trim_rem(rem_core, lt) if mid not in outliers else ot[:])} if mid not in outliers else set(),
            })

        # If none of the parent's remaining words appear in other children's
        # remainders, the parent row is simply labelled "Default".
        default_info = next(info for info in kid_infos if info["mid"] == pid)
        other_words = set()
        for info in kid_infos:
            if info["mid"] != pid:
                other_words |= info["rem_set"]
        default_shares = bool(default_info["rem_set"] & other_words)

        # Display label for one row (both non-outlier branches use the trimmed remainder).
        def disp(info):
            if info["outlier"]:
                return info["name"]
            if info["mid"] == pid:
                if not default_shares:
                    return "Default"
                rem = info["rem_trim"]
            else:
                rem = info["rem_trim"]
            s = " ".join(rem).strip()
            return s if s else "Default"

        # Parent row first, then the other rows in input order.
        entries = [(disp(default_info), pid)]
        for info in kid_infos:
            if info["mid"] == pid:
                continue
            entries.append((disp(info), info["mid"]))

        # Children labelled "Default" with the same normalized name as the parent
        # become "Default #2", "Default #3", ... to keep labels distinct.
        p_full = norm(p_name)
        full_by_mid = {mid: name for name, mid, *_ in kids}
        num = 2
        numbered = [entries[0]]
        for dname, mid in entries[1:]:
            if dname == "Default" and norm(full_by_mid[mid]) == p_full:
                numbered.append((f"Default #{num}", mid))
                num += 1
            else:
                numbered.append((dname, mid))

        parents_list.append((header, pid))
        children_dict[pid] = numbered

    # Parent ids with no row of their own: header is the first child's name,
    # children keep their full names.
    for pid in groups.keys():
        if pid in parents:
            continue
        first_name = groups[pid][0][0]
        parents_list.append((first_name, pid))
        children_dict[pid] = [(name, mid) for name, mid in groups[pid]]

    # Case-sensitive sort on header text.
    parents_list = sorted(parents_list, key=lambda c: c[0])
    return parents_list, children_dict


def get_sorted_dropdown(deps, dropdown_types, current_model_family, current_model_type, three_levels=True):
    """Build sorted choices for the family, model and finetune dropdowns.

    Args:
        deps: DropdownDeps-like object providing families_infos and lookups.
        dropdown_types: model types to consider.
        current_model_family: family to restrict models to, or None for all.
            When given, model names are shortened with compact_name.
        current_model_type: the selected model type.
        three_levels: build the family / base model / finetune hierarchy.

    Returns:
        ``(families, second, third)``. ``families`` is a list of
        ``(family_label, family)`` ordered by ``families_infos[family][0]``.

        With *three_levels*, ``second`` is the parent list from
        create_models_hierarchy and ``third`` is the finetune list of the
        current model's parent.

        Otherwise, ``second`` lists the base model types in the order their
        models appear in the sorted choices, and ``third`` is the list of
        ``(model_name, model_type)`` choices.
    """
    models_families = [deps.get_model_family(t, for_ui=True) for t in dropdown_types]
    # Unique families in first-seen order; the stable sort keeps ties in that order.
    unique_families = list(dict.fromkeys(models_families))
    sorted_families = [(deps.families_infos[family][1], family) for family in sorted(unique_families, key=lambda family: deps.families_infos[family][0])]

    if current_model_family is None:
        dropdown_choices = [(deps.families_infos[family][0], deps.get_model_name(model_type), model_type) for model_type, family in zip(dropdown_types, models_families)]
    else:
        dropdown_choices = [(deps.families_infos[family][0], compact_name(deps.families_infos[family][1], deps.get_model_name(model_type)), model_type) for model_type, family in zip(dropdown_types, models_families) if family == current_model_family]
    # Sort by family order, then by display name.
    dropdown_choices = sorted(dropdown_choices, key=lambda c: (c[0], c[1]))

    if three_levels:
        rows = [(name, model_type, deps.get_parent_model_type(model_type)) for _, name, model_type in dropdown_choices]
        sorted_choices, finetunes_dict = create_models_hierarchy(rows)
        return sorted_families, sorted_choices, finetunes_dict[deps.get_parent_model_type(current_model_type)]

    # Order-preserving de-duplication: a set here made the dropdown order
    # depend on per-process string hash randomization.
    dropdown_types_list = list(dict.fromkeys(deps.get_base_model_type(model_type) for _, _, model_type in dropdown_choices))
    choices = [(name, model_type) for _, name, model_type in dropdown_choices]
    return sorted_families, dropdown_types_list, choices


def generate_dropdown_model_list(deps, current_model_type):
    """Create the family, model and finetune dropdowns for *current_model_type*.

    Model and finetune labels carry download-status badges. Statuses are
    computed for the current family's model types (all types when the family is
    None), plus *current_model_type*.

    Returns:
        ``(family_dropdown, model_dropdown, finetune_dropdown)`` as gr.Dropdown
        components, created in that order.
    """
    three_levels = deps.three_levels_hierarchy
    pool = deps.transformer_types if len(deps.transformer_types) > 0 else deps.displayed_model_types
    dropdown_types = list(pool)
    if current_model_type not in dropdown_types:
        dropdown_types.append(current_model_type)

    family = deps.get_model_family(current_model_type, for_ui=True)
    family_choices, model_choices, finetune_choices = get_sorted_dropdown(deps, dropdown_types, family, current_model_type, three_levels=three_levels)

    status_types = get_family_dropdown_model_types(deps, family, dropdown_types)
    if current_model_type not in status_types:
        status_types.append(current_model_type)
    direct_map, parent_map = get_model_download_status_maps(deps, status_types)
    model_choices = decorate_dropdown_choices_with_status(model_choices, parent_map)
    finetune_choices = decorate_dropdown_choices_with_status(finetune_choices, direct_map)

    several_finetunes = len(finetune_choices) > 1
    family_dropdown = gr.Dropdown(
        choices=family_choices,
        value=family,
        show_label=False,
        scale=2 if three_levels else 1,
        elem_id="family_list",
        min_width=50,
    )
    model_dropdown = gr.Dropdown(
        choices=model_choices,
        value=deps.get_parent_model_type(current_model_type) if three_levels else deps.get_base_model_type(current_model_type),
        show_label=False,
        scale=3 if several_finetunes else 7,
        elem_id="model_base_types_list",
        visible=three_levels,
    )
    finetune_dropdown = gr.Dropdown(
        choices=finetune_choices,
        value=current_model_type,
        show_label=False,
        scale=4,
        visible=several_finetunes or not three_levels,
        elem_id="model_list",
    )
    return family_dropdown, model_dropdown, finetune_dropdown


def change_model_family(deps, state, current_model_family):
    """Rebuild the model and finetune dropdowns after a family change.

    The selected model is ``state["last_model_per_family"][family]`` when it is
    still listed, otherwise the first model of the family in name order.
    Assumes the family has at least one listed model (otherwise IndexError).

    Returns:
        ``(model_dropdown, finetune_dropdown)`` as gr.Dropdown components.
    """
    dropdown_types = list(deps.transformer_types) if len(deps.transformer_types) > 0 else list(deps.displayed_model_types)
    current_family_name = deps.families_infos[current_model_family][1]
    models_families = [deps.get_model_family(t, for_ui=True) for t in dropdown_types]
    dropdown_choices = [(compact_name(current_family_name, deps.get_model_name(model_type)), model_type) for model_type, family in zip(dropdown_types, models_families) if family == current_model_family]
    dropdown_choices = sorted(dropdown_choices, key=lambda c: c[0])
    family_dropdown_types = [choice[1] for choice in dropdown_choices]
    direct_status_map, aggregated_parent_status_map = get_model_download_status_maps(deps, family_dropdown_types)

    last_model_per_family = state.get("last_model_per_family", {})
    model_type = last_model_per_family.get(current_model_family, "")
    # Fall back when nothing (or a stale model) was remembered. The previous
    # `len(model_type) == ""` test was always False and crashed on None.
    if not model_type or model_type not in family_dropdown_types:
        model_type = dropdown_choices[0][1]

    if deps.three_levels_hierarchy:
        parent_model_type = deps.get_parent_model_type(model_type)
        rows = [(*choice, deps.get_parent_model_type(choice[1])) for choice in dropdown_choices]
        base_type_choices, finetunes_dict = create_models_hierarchy(rows)
        finetune_choices = decorate_dropdown_choices_with_status(finetunes_dict[parent_model_type], direct_status_map)
        base_type_choices = decorate_dropdown_choices_with_status(base_type_choices, aggregated_parent_status_map)
        model_finetunes_visible = len(finetune_choices) > 1
    else:
        parent_model_type = deps.get_base_model_type(model_type)
        model_finetunes_visible = True
        # Order-preserving de-duplication (a set gave hash-dependent ordering).
        base_type_choices = list(dict.fromkeys(deps.get_base_model_type(choice[1]) for choice in dropdown_choices))
        finetune_choices = decorate_dropdown_choices_with_status(dropdown_choices, direct_status_map)

    return gr.Dropdown(choices=base_type_choices, value=parent_model_type, scale=3 if model_finetunes_visible else 7), gr.Dropdown(choices=finetune_choices, value=model_type, visible=model_finetunes_visible)


def change_model_base_types(deps, state, current_model_family, model_base_type_choice):
    """Rebuild the finetune dropdown after a base-model change (three-level mode).

    The selected finetune is ``state["last_model_per_type"][base_type]`` when it
    is still listed, otherwise the first listed finetune.

    Returns:
        ``(model_dropdown_update, finetune_dropdown)``. In two-level mode both
        items are no-op ``gr.update()`` values, so the arity matches the normal
        path.
    """
    if not deps.three_levels_hierarchy:
        # The normal path feeds two outputs; return a no-op for each of them.
        return gr.update(), gr.update()
    dropdown_types = list(deps.transformer_types) if len(deps.transformer_types) > 0 else list(deps.displayed_model_types)
    current_family_name = deps.families_infos[current_model_family][1]
    dropdown_choices = [(compact_name(current_family_name, deps.get_model_name(model_type)), model_type, model_base_type_choice) for model_type in dropdown_types if deps.get_parent_model_type(model_type) == model_base_type_choice and deps.get_model_family(model_type, for_ui=True) == current_model_family]
    dropdown_choices = sorted(dropdown_choices, key=lambda c: c[0])
    _, finetunes_dict = create_models_hierarchy(dropdown_choices)
    base_dropdown_types = [choice[1] for choice in dropdown_choices]
    direct_status_map, _ = get_model_download_status_maps(deps, base_dropdown_types)
    finetune_choices = decorate_dropdown_choices_with_status(finetunes_dict[model_base_type_choice], direct_status_map)
    model_finetunes_visible = len(finetune_choices) > 1

    last_model_per_type = state.get("last_model_per_type", {})
    model_type = last_model_per_type.get(model_base_type_choice, "")
    # Fall back when nothing (or a stale model) was remembered. The previous
    # `len(model_type) == ""` test was always False and crashed on None.
    if not model_type or model_type not in [choice[1] for choice in finetune_choices]:
        model_type = finetune_choices[0][1]
    return gr.update(scale=3 if model_finetunes_visible else 7), gr.Dropdown(choices=finetune_choices, value=model_type, visible=model_finetunes_visible)