Spaces:
Running
Running
| import os | |
| import sys | |
| import inspect | |
| sys.path.append(os.getcwd()) | |
| from main.library.speaker_diarization.speechbrain import fetch, run_on_main | |
| from main.library.speaker_diarization.features import DEFAULT_TRANSFER_HOOKS, DEFAULT_LOAD_HOOKS | |
def get_default_hook(obj, default_hooks):
    """Look up the hook registered for ``obj``'s type or its nearest base.

    The classes of ``type(obj)`` are searched in method-resolution order, so
    a hook registered for a subclass wins over one registered for a parent.

    Args:
        obj: the object whose type drives the lookup.
        default_hooks: mapping of class -> hook callable.

    Returns:
        The first matching hook, or ``None`` when no class in the MRO is a key.
    """
    candidates = (
        default_hooks[klass]
        for klass in inspect.getmro(type(obj))
        if klass in default_hooks
    )
    return next(candidates, None)
class Pretrainer:
    """Resolve parameter files for named objects and load them via hooks.

    Workflow: ``collect_files`` resolves a file location for every enabled
    loadable (through ``fetch``), then ``load_collected`` passes each
    resolved path to a loading hook for the corresponding object.

    Args:
        loadables: mapping of name -> object whose parameters get loaded.
        paths: mapping of name -> "source/filename" string to fetch from.
        custom_hooks: mapping of name -> callable(obj, path); when present it
            is used instead of the default hook lookup for that name.
        conditions: mapping of name -> bool-like value or zero-argument
            callable; a falsy result disables that loadable.
    """

    def __init__(self, loadables=None, paths=None, custom_hooks=None, conditions=None):
        self.loadables = {}
        if loadables is not None:
            self.add_loadables(loadables)
        self.paths = {}
        if paths is not None:
            self.add_paths(paths)
        self.custom_hooks = {}
        if custom_hooks is not None:
            self.add_custom_hooks(custom_hooks)
        self.conditions = {}
        if conditions is not None:
            self.add_conditions(conditions)
        # Names whose file location has been resolved by collect_files();
        # load_collected() refuses to load anything not listed here.
        self.is_local = []

    def add_loadables(self, loadables):
        """Register (or replace) name -> object entries."""
        self.loadables.update(loadables)

    def add_paths(self, paths):
        """Register (or replace) name -> "source/filename" entries."""
        self.paths.update(paths)

    def add_custom_hooks(self, custom_hooks):
        """Register (or replace) name -> callable(obj, path) load hooks."""
        self.custom_hooks.update(custom_hooks)

    def add_conditions(self, conditions):
        """Register (or replace) name -> enabling condition entries."""
        self.conditions.update(conditions)

    @staticmethod
    def split_path(path):
        """Split ``path`` at its last "/" into ``(source, filename)``.

        A path containing no "/" is treated as a file in the current
        directory and yields ``("./", path)``.
        """
        # Must be a staticmethod: collect_files calls it as
        # ``self.split_path(...)``; without the decorator the instance was
        # passed as ``path`` and the call raised TypeError.
        source, sep, filename = path.rpartition("/")
        if not sep:
            return "./", path
        return source, filename

    def collect_files(self, default_source=None):
        """Resolve a file location for every enabled loadable.

        Args:
            default_source: source used, with filename "<name>.ckpt", for any
                loadable that has no entry in ``self.paths``.

        Returns:
            dict mapping each enabled name to the value returned by ``fetch``.
            ``self.paths`` is updated with ``str()`` of that value and the
            name is recorded in ``self.is_local``.

        Raises:
            ValueError: a loadable has no path and no ``default_source``.
        """
        loadable_paths = {}
        for name in self.loadables:
            if not self.is_loadable(name):
                continue
            if name in self.paths:
                source, filename = self.split_path(self.paths[name])
            elif default_source is not None:
                source, filename = default_source, name + ".ckpt"
            else:
                raise ValueError(
                    f"No path given for loadable {name!r} and no default_source provided"
                )
            fetch_kwargs = {"filename": filename, "source": source}
            path = None

            def run_fetch(**kwargs):
                # Store the fetch result in the enclosing ``path`` variable.
                nonlocal path
                path = fetch(**kwargs)

            # NOTE(review): run_fetch is passed both as the main function and
            # as post_func; presumably run_on_main fetches once on the main
            # process and then lets every process resolve the path — confirm.
            run_on_main(run_fetch, kwargs=fetch_kwargs, post_func=run_fetch, post_kwargs=fetch_kwargs)
            loadable_paths[name] = path
            self.paths[name] = str(path)
            # Avoid duplicate entries when collect_files() is called repeatedly.
            if name not in self.is_local:
                self.is_local.append(name)
        return loadable_paths

    def is_loadable(self, name):
        """Return whether ``name`` is enabled.

        Names without a condition are enabled. A callable condition is
        called and its result returned as-is; any other condition is
        converted with ``bool``.
        """
        if name not in self.conditions:
            return True
        condition = self.conditions[name]
        if callable(condition):
            return condition()
        return bool(condition)

    def load_collected(self):
        """Load every enabled loadable from the path stored by collect_files.

        Raises:
            ValueError: an enabled loadable has not been collected yet.
        """
        paramfiles = {}
        for name in self.loadables:
            if not self.is_loadable(name):
                continue
            if name not in self.is_local:
                raise ValueError(
                    f"Loadable {name!r} has not been collected; call collect_files() first"
                )
            paramfiles[name] = self.paths[name]
        self._call_load_hooks(paramfiles)

    def _call_load_hooks(self, paramfiles):
        """Invoke the load hook for every enabled loadable.

        Hook precedence: a custom hook for the name, then a default transfer
        hook for the object's type, then a default load hook (called with
        ``end_of_epoch=False``).

        Raises:
            RuntimeError: no hook of any kind applies to an object.
        """
        for name, obj in self.loadables.items():
            if not self.is_loadable(name):
                continue
            loadpath = paramfiles[name]
            if name in self.custom_hooks:
                self.custom_hooks[name](obj, loadpath)
                continue
            default_hook = get_default_hook(obj, DEFAULT_TRANSFER_HOOKS)
            if default_hook is not None:
                default_hook(obj, loadpath)
                continue
            default_hook = get_default_hook(obj, DEFAULT_LOAD_HOOKS)
            if default_hook is not None:
                end_of_epoch = False
                default_hook(obj, loadpath, end_of_epoch)
                continue
            raise RuntimeError(
                f"Don't know how to load {name!r} of type {type(obj).__name__}; "
                "no custom, transfer or load hook is registered"
            )