import os
import re

import torch
from packaging import version

from tgs.utils.typing import *
def parse_version(ver: str):
    return version.parse(ver)
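# Usage sketch (illustrative, not part of the original module): parse_version
# gives PEP 440-aware comparisons, so version gates read naturally, e.g.
#
#   if parse_version(torch.__version__) >= parse_version("2.0.0"):
#       ...  # take the torch >= 2.0 code path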
def get_rank():
    # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
    # therefore LOCAL_RANK needs to be checked first
    rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
    for key in rank_keys:
        rank = os.environ.get(key)
        if rank is not None:
            return int(rank)
    return 0
def get_device():
    return torch.device(f"cuda:{get_rank()}")
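# Usage sketch (the launcher command and MyModel are assumptions, not from this
# repo): under a single-node torchrun launch, each worker process sees its own
# RANK/LOCAL_RANK, so get_device() pins that process to its own GPU.
#
#   # torchrun --nproc_per_node=4 train.py
#   device = get_device()           # cuda:0 .. cuda:3, one per process
#   model = MyModel().to(device)    # MyModel is a hypothetical module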
def load_module_weights(
    path, module_name=None, ignore_modules=None, map_location=None
) -> dict:
    if module_name is not None and ignore_modules is not None:
        raise ValueError("module_name and ignore_modules cannot be both set")
    if map_location is None:
        map_location = get_device()

    ckpt = torch.load(path, map_location=map_location)
    state_dict = ckpt["state_dict"]
    state_dict_to_load = state_dict

    if ignore_modules is not None:
        # drop every entry belonging to one of the ignored modules
        state_dict_to_load = {}
        for k, v in state_dict.items():
            ignore = any(
                k.startswith(ignore_module + ".") for ignore_module in ignore_modules
            )
            if ignore:
                continue
            state_dict_to_load[k] = v

    if module_name is not None:
        # keep only entries under module_name, stripping its prefix
        state_dict_to_load = {}
        for k, v in state_dict.items():
            m = re.match(rf"^{module_name}\.(.*)$", k)
            if m is None:
                continue
            state_dict_to_load[m.group(1)] = v

    return state_dict_to_load
# convert a function into recursive style to handle nested dict/list/tuple variables
def make_recursive_func(func):
    def wrapper(vars, *args, **kwargs):
        if isinstance(vars, list):
            return [wrapper(x, *args, **kwargs) for x in vars]
        elif isinstance(vars, tuple):
            return tuple(wrapper(x, *args, **kwargs) for x in vars)
        elif isinstance(vars, dict):
            return {k: wrapper(v, *args, **kwargs) for k, v in vars.items()}
        else:
            return func(vars, *args, **kwargs)

    return wrapper
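# Usage sketch (tensor2numpy is a hypothetical example, not defined in this
# module): the decorator lifts a leaf-level function over arbitrarily nested
# containers, applying it only at non-container leaves.
#
#   @make_recursive_func
#   def tensor2numpy(x):
#       return x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else x
#
#   out = tensor2numpy({"rgb": torch.rand(3), "aux": [torch.rand(2)]})
#   # -> {"rgb": ndarray, "aux": [ndarray]}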
@make_recursive_func
def todevice(vars, device="cuda"):
    if isinstance(vars, torch.Tensor):
        return vars.to(device)
    elif isinstance(vars, (str, bool, float, int)):
        # pass plain Python scalars and strings through unchanged
        return vars
    else:
        raise NotImplementedError(
            "invalid input type {} for todevice".format(type(vars))
        )
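# Usage sketch (batch contents are illustrative): via make_recursive_func,
# todevice walks nested containers, moves every tensor leaf to the target
# device, and leaves scalars and strings untouched.
#
#   batch = {"rgb": torch.rand(2, 3), "meta": {"name": "scene_0", "scale": 1.0}}
#   batch = todevice(batch, device=get_device())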