| # ztrain/io.py | |
| # Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted | |
| import os | |
| from glob import glob | |
def flatten_index(model_paths: list[str], allow_list: list[str]) -> tuple[dict[str, int], list[str], list[str]]:
    """Filter model paths by basename and label each kept entry with a subtype.

    Paths are visited in sorted order. A path is kept only when its basename
    appears in ``allow_list``.

    Args:
        model_paths: Candidate model paths.
        allow_list: Basenames to keep.

    Returns:
        ``(index, flat, subtype)``:

        - ``flat``: the kept basenames, in sorted-path order.
        - ``index``: maps each kept basename to its position in ``flat``.
          If the same basename is kept more than once, the last position wins.
        - ``subtype``: parallel to ``flat``; each item is ``'base'``,
          ``'instruct'`` or ``'other'``.
    """
    # Build the set once: O(1) membership instead of a list scan per path.
    allowed = set(allow_list)
    flat: list[str] = []
    subtype: list[str] = []
    index: dict[str, int] = {}
    for g in sorted(model_paths):
        name = os.path.basename(g)
        if name not in allowed:
            continue
        # Position in ``flat``, taken from the list itself so the index
        # always points at the entry just appended.
        index[name] = len(flat)
        flat.append(name)
        # Plain substring test on the FULL path (not just the basename):
        # 'base' is checked first, and it also matches e.g. 'database'.
        if 'base' in g:
            subtype.append('base')
        elif 'instruct' in g:
            subtype.append('instruct')
        else:
            subtype.append('other')
    return index, flat, subtype
def list_for_path(path: str, include_folders: list[str], search: str = "/**/*") -> tuple[list[str], list[str], list[str], dict[str, int]]:
    """Glob under ``path`` and select the entries whose basename is allow-listed.

    Args:
        path: Root path. ``search`` is appended to it verbatim (string
            concatenation, no separator is inserted).
        include_folders: Basenames to keep; forwarded to ``flatten_index``.
        search: Glob pattern suffix.

            NOTE(review): ``glob`` is called without ``recursive=True``, so
            ``**`` behaves like ``*``. The default therefore matches entries
            exactly two levels below ``path``. Confirm this is intended
            before relying on deeper recursion.

    Returns:
        ``(model_names, subtypes, model_list, group_idx)``:

        - ``model_names``: the kept basenames.
        - ``subtypes``: the parallel subtype labels.
        - ``model_list``: every globbed path, sorted.
        - ``group_idx``: the basename -> index map produced by
          ``flatten_index``.
    """
    model_list = sorted(glob(path + search))
    group_idx, model_names, subtypes = flatten_index(model_list, include_folders)
    return model_names, subtypes, model_list, group_idx