import functools
import gc
import os
import re

import torch
import torch.distributed as dist
from packaging import version

from craftsman.utils.config import config_to_primitive
from craftsman.utils.typing import *

# tinycudann is only required by cleanup() below; keep it optional so this
# module still imports when tiny-cuda-nn is not installed
try:
    import tinycudann as tcnn
except ImportError:
    tcnn = None


def parse_version(ver: str):
    return version.parse(ver)


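# Usage sketch: gate version-dependent behavior on the installed torch release.
#
#   if parse_version(torch.__version__) >= parse_version("2.0.0"):
#       ...  # take a torch>=2.0-only code path

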
def get_rank():
    # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
    # therefore LOCAL_RANK needs to be checked first
    rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
    for key in rank_keys:
        rank = os.environ.get(key)
        if rank is not None:
            return int(rank)
    return 0


def get_world_size():
    world_size_keys = ("WORLD_SIZE", "SLURM_NTASKS", "JSM_NAMESPACE_SIZE")
    for key in world_size_keys:
        world_size = os.environ.get(key)
        if world_size is not None:
            return int(world_size)
    return 1


def get_device():
    # Assumes one process per GPU, with the process rank doubling as the CUDA index
    return torch.device(f"cuda:{get_rank()}")


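# Usage sketch (assumes a launcher such as `torchrun --nproc_per_node=4` that
# sets RANK/WORLD_SIZE in the environment; MyModel is a hypothetical module):
#
#   device = get_device()          # e.g. cuda:2 in the rank-2 process
#   model = MyModel().to(device)
#   print(get_rank(), "/", get_world_size())

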
def load_module_weights(
    path, module_name=None, ignore_modules=None, map_location=None
) -> Tuple[dict, int, int]:
    if module_name is not None and ignore_modules is not None:
        raise ValueError("module_name and ignore_modules cannot both be set")
    if map_location is None:
        map_location = get_device()

    ckpt = torch.load(path, map_location=map_location)
    state_dict = ckpt["state_dict"]
    state_dict_to_load = state_dict

    if ignore_modules is not None:
        # Drop every key that belongs to one of the ignored submodules
        state_dict_to_load = {}
        for k, v in state_dict.items():
            ignore = any(
                k.startswith(ignore_module + ".") for ignore_module in ignore_modules
            )
            if ignore:
                continue
            state_dict_to_load[k] = v

    if module_name is not None:
        # Keep only the keys under module_name, stripping the prefix
        state_dict_to_load = {}
        for k, v in state_dict.items():
            m = re.match(rf"^{module_name}\.(.*)$", k)
            if m is None:
                continue
            state_dict_to_load[m.group(1)] = v

    return state_dict_to_load, ckpt["epoch"], ckpt["global_step"]


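# Usage sketch (hypothetical checkpoint path and submodule name): restore only
# the weights stored under the "shape_model." prefix of a Lightning-style
# checkpoint, with the prefix stripped so they load into the bare submodule:
#
#   state_dict, epoch, global_step = load_module_weights(
#       "ckpts/last.ckpt", module_name="shape_model", map_location="cpu"
#   )
#   shape_model.load_state_dict(state_dict)

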
def C(value: Any, epoch: int, global_step: int) -> float:
    """Resolve a scalar that may be given as a linear-interpolation schedule."""
    if not isinstance(value, (int, float)):
        value = config_to_primitive(value)
        if not isinstance(value, list):
            raise TypeError(f"Scalar specification only supports list, got {type(value)}")
        if len(value) == 3:
            # A 3-element spec [start_value, end_value, end_step] starts at step 0
            value = [0] + value
        assert len(value) == 4
        start_step, start_value, end_value, end_step = value
        if isinstance(end_step, int):
            # An integer end_step schedules against the global step
            current_step = global_step
        elif isinstance(end_step, float):
            # A float end_step schedules against the epoch count
            current_step = epoch
        else:
            raise TypeError(f"end_step must be int or float, got {type(end_step)}")
        value = start_value + (end_value - start_value) * max(
            min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0
        )
    return value


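# Worked example: the spec [0, 0.0, 1.0, 1000] has an integer end_step, so it
# interpolates over global steps; at global_step=250 it yields
# 0.0 + (1.0 - 0.0) * min(1.0, 250 / 1000) = 0.25. A float end_step such as
# 10.0 would interpolate over epochs instead, and the value clamps to
# end_value once current_step passes end_step.

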
def cleanup():
    gc.collect()
    torch.cuda.empty_cache()
    if tcnn is not None:
        tcnn.free_temporary_memory()


def finish_with_cleanup(func: Callable):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        out = func(*args, **kwargs)
        cleanup()
        return out

    return wrapper


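# Usage sketch (hypothetical function name): release cached GPU memory after
# each call to a memory-hungry routine.
#
#   @finish_with_cleanup
#   def run_validation(batch):
#       ...

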
def _distributed_available():
    return torch.distributed.is_available() and torch.distributed.is_initialized()


def barrier():
    if not _distributed_available():
        return
    torch.distributed.barrier()


def broadcast(tensor, src=0):
    if not _distributed_available():
        return tensor
    torch.distributed.broadcast(tensor, src=src)
    return tensor


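# Usage sketch (assumes an initialized process group): make all ranks agree on
# a value sampled on rank 0.
#
#   noise = torch.randn(4, device=get_device())
#   noise = broadcast(noise, src=0)  # every rank now holds rank 0's values

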
def enable_gradient(model, enabled: bool = True) -> None:
    for param in model.parameters():
        param.requires_grad_(enabled)


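# Usage sketch (hypothetical module name): freeze a pretrained encoder during
# fine-tuning.
#
#   enable_gradient(image_encoder, enabled=False)

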
def all_gather_batch(tensors):
    """
    Perform an all_gather on the provided tensor(s).

    Note: torch.distributed.all_gather does not propagate gradients; the
    gathered copies are detached from the autograd graph.
    """
    world_size = get_world_size()
    # No reduction is needed in the single-process case
    if world_size == 1:
        return tensors
    if not isinstance(tensors, list):
        is_list = False
        tensors = [tensors]
    else:
        is_list = True
    tensor_list = []
    output_tensor = []
    for tensor in tensors:
        # Gather one buffer per rank for this tensor
        tensor_all = [torch.ones_like(tensor) for _ in range(world_size)]
        dist.all_gather(tensor_all, tensor, async_op=False)
        tensor_list.append(tensor_all)
    # Concatenate the per-rank shards along the batch dimension
    for tensor_all in tensor_list:
        output_tensor.append(torch.cat(tensor_all, dim=0))
    if not is_list:
        return output_tensor[0]
    return output_tensor
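

# Usage sketch (assumes an initialized process group): gather per-rank feature
# batches before computing a batch-global loss such as a contrastive loss.
#
#   local_feats = model(batch)                  # [B, D] on each rank
#   all_feats = all_gather_batch(local_feats)   # [B * world_size, D]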