| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import random |
| | from pathlib import Path |
| | from typing import Optional |
| |
|
| | import numpy as np |
| | import torch |
| | from safetensors.torch import load_model |
| |
|
| | from .utils import ( |
| | MODEL_NAME, |
| | OPTIMIZER_NAME, |
| | RNG_STATE_NAME, |
| | SAFE_MODEL_NAME, |
| | SAFE_WEIGHTS_NAME, |
| | SAMPLER_NAME, |
| | SCALER_NAME, |
| | SCHEDULER_NAME, |
| | WEIGHTS_NAME, |
| | get_pretty_name, |
| | is_cuda_available, |
| | is_hpu_available, |
| | is_mlu_available, |
| | is_musa_available, |
| | is_sdaa_available, |
| | is_torch_version, |
| | is_torch_xla_available, |
| | is_xpu_available, |
| | load, |
| | save, |
| | ) |
| |
|
| |
|
| | if is_torch_version(">=", "2.4.0"): |
| | from torch.amp import GradScaler |
| | else: |
| | from torch.cuda.amp import GradScaler |
| |
|
| | if is_torch_xla_available(): |
| | import torch_xla.core.xla_model as xm |
| |
|
| | from .logging import get_logger |
| | from .state import PartialState |
| |
|
| |
|
| | logger = get_logger(__name__) |
| |
|
| |
|
def save_accelerator_state(
    output_dir: str,
    model_states: list[dict],
    optimizers: list,
    schedulers: list,
    dataloaders: list,
    process_index: int,
    step: int,
    scaler: Optional[GradScaler] = None,
    save_on_each_node: bool = False,
    safe_serialization: bool = True,
):
    """
    Saves the current states of the models, optimizers, schedulers, dataloaders, scaler, and RNG generators to a
    given directory.

    <Tip>

    If `safe_serialization` is `True`, models will be saved with `safetensors` while the rest are saved using native
    `pickle`.

    </Tip>

    Args:
        output_dir (`str` or `os.PathLike`):
            The name of the folder to save all relevant weights and states.
        model_states (`List[dict]`):
            A list of model state dicts (one per model, in model order).
        optimizers (`List[torch.optim.Optimizer]`):
            A list of optimizer instances
        schedulers (`List[torch.optim.lr_scheduler._LRScheduler]`):
            A list of learning rate schedulers
        dataloaders (`List[torch.utils.data.DataLoader]`):
            A list of dataloader instances to save their sampler states
        process_index (`int`):
            The current process index in the Accelerator state
        step (`int`):
            The current step in the internal step tracker
        scaler (`torch.amp.GradScaler`, *optional*):
            An optional gradient scaler instance to save;
        save_on_each_node (`bool`, *optional*):
            Whether to save on every node, or only the main node.
        safe_serialization (`bool`, *optional*, defaults to `True`):
            Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).

    Returns:
        `pathlib.Path`: `output_dir` converted to a `Path`.
    """
    output_dir = Path(output_dir)

    # Model states. The base file name does not depend on the loop, so pick it once.
    base_weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
    for i, state in enumerate(model_states):
        # Every model after the first gets an `_{i}` suffix inserted before each "." of the file name.
        weights_name = base_weights_name.replace(".", f"_{i}.") if i > 0 else base_weights_name
        output_model_file = output_dir.joinpath(weights_name)
        save(state, output_model_file, save_on_each_node=save_on_each_node, safe_serialization=safe_serialization)
        logger.info(f"Model weights saved in {output_model_file}")

    # Optimizer states (never written with safetensors)
    for i, opt in enumerate(optimizers):
        state = opt.state_dict()
        optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
        output_optimizer_file = output_dir.joinpath(optimizer_name)
        save(state, output_optimizer_file, save_on_each_node=save_on_each_node, safe_serialization=False)
        logger.info(f"Optimizer state saved in {output_optimizer_file}")

    # Scheduler states
    for i, scheduler in enumerate(schedulers):
        state = scheduler.state_dict()
        scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"
        output_scheduler_file = output_dir.joinpath(scheduler_name)
        save(state, output_scheduler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
        logger.info(f"Scheduler state saved in {output_scheduler_file}")

    # Dataloader states.
    # The import stays function-scoped as in the original code, but is executed once instead of once per
    # dataloader.
    from .data_loader import IterableDatasetShard, SeedableRandomSampler

    for i, dataloader in enumerate(dataloaders):
        sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
        output_sampler_file = output_dir.joinpath(sampler_name)

        # The sampler object itself is only written when the dataset is an `IterableDatasetShard` whose sampler
        # is a `SeedableRandomSampler`.
        if isinstance(dataloader.dataset, IterableDatasetShard):
            sampler = dataloader.get_sampler()
            if isinstance(sampler, SeedableRandomSampler):
                save(sampler, output_sampler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
        if getattr(dataloader, "use_stateful_dataloader", False):
            dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
            output_dataloader_state_dict_file = output_dir.joinpath(dataloader_state_dict_name)
            state_dict = dataloader.state_dict()
            # Written with `torch.save` directly, i.e. not routed through the `save` helper.
            torch.save(state_dict, output_dataloader_state_dict_file)
        # NOTE: this message is emitted for every dataloader, whether or not a sampler file was written above.
        logger.info(f"Sampler state for dataloader {i} saved in {output_sampler_file}")

    # GradScaler state
    if scaler is not None:
        state = scaler.state_dict()
        output_scaler_file = output_dir.joinpath(SCALER_NAME)
        torch.save(state, output_scaler_file)
        logger.info(f"Gradient scaler state saved in {output_scaler_file}")

    # Random number generator states, stored in one file per process together with the current step.
    states_name = f"{RNG_STATE_NAME}_{process_index}.pkl"
    states = {
        "step": step,
        "random_state": random.getstate(),
        "numpy_random_seed": np.random.get_state(),
        "torch_manual_seed": torch.get_rng_state(),
    }
    # Accelerator-specific generators are only recorded for backends reported as available.
    if is_xpu_available():
        states["torch_xpu_manual_seed"] = torch.xpu.get_rng_state_all()
    if is_mlu_available():
        states["torch_mlu_manual_seed"] = torch.mlu.get_rng_state_all()
    elif is_sdaa_available():
        states["torch_sdaa_manual_seed"] = torch.sdaa.get_rng_state_all()
    elif is_musa_available():
        states["torch_musa_manual_seed"] = torch.musa.get_rng_state_all()
    if is_hpu_available():
        states["torch_hpu_manual_seed"] = torch.hpu.get_rng_state_all()
    if is_cuda_available():
        states["torch_cuda_manual_seed"] = torch.cuda.get_rng_state_all()
    if is_torch_xla_available():
        states["xm_seed"] = xm.get_rng_state()
    output_states_file = output_dir.joinpath(states_name)
    torch.save(states, output_states_file)
    logger.info(f"Random states saved in {output_states_file}")
    return output_dir
| |
|
| |
|
def load_accelerator_state(
    input_dir,
    models,
    optimizers,
    schedulers,
    dataloaders,
    process_index,
    scaler=None,
    map_location=None,
    load_kwargs=None,
    **load_model_func_kwargs,
):
    """
    Loads states of the models, optimizers, schedulers, dataloaders, scaler, and RNG generators from a given
    directory.

    Args:
        input_dir (`str` or `os.PathLike`):
            The name of the folder to load all relevant weights and states.
        models (`List[torch.nn.Module]`):
            A list of model instances
        optimizers (`List[torch.optim.Optimizer]`):
            A list of optimizer instances
        schedulers (`List[torch.optim.lr_scheduler._LRScheduler]`):
            A list of learning rate schedulers
        dataloaders (`List[torch.utils.data.DataLoader]`):
            A list of dataloader instances whose sampler / dataloader states should be restored
        process_index (`int`):
            The current process index in the Accelerator state
        scaler (`torch.amp.GradScaler`, *optional*):
            An optional *GradScaler* instance to load
        map_location (`str`, *optional*):
            What device to load the optimizer state onto. Should be one of either "cpu" or "on_device". `None` is
            treated as "cpu".
        load_kwargs (`dict`, *optional*):
            Additional arguments that can be passed to the `load` function.
        load_model_func_kwargs (`dict`, *optional*):
            Additional arguments that can be passed to the model's `load_state_dict` method (or to
            `safetensors.torch.load_model` when a safetensors file is found).

    Returns:
        `dict`: Contains the `Accelerator` attributes to override while loading the state.

    Raises:
        `TypeError`: If `map_location` is not one of `None`, `"cpu"` or `"on_device"`.
    """
    # Values the caller has to apply on its side (currently only "step", when present in the RNG file).
    override_attributes = {}
    if map_location not in [None, "cpu", "on_device"]:
        raise TypeError(
            "Unsupported optimizer map location passed, please choose one of `None`, `'cpu'`, or `'on_device'`"
        )
    if map_location is None:
        map_location = "cpu"
    elif map_location == "on_device":
        map_location = PartialState().device

    if load_kwargs is None:
        load_kwargs = {}

    input_dir = Path(input_dir)

    # Model states
    for i, model in enumerate(models):
        ending = f"_{i}" if i > 0 else ""
        input_model_file = input_dir.joinpath(f"{SAFE_MODEL_NAME}{ending}.safetensors")
        if input_model_file.exists():
            load_model(model, input_model_file, device=str(map_location), **load_model_func_kwargs)
        else:
            # No safetensors file: fall back to the `.bin` checkpoint
            input_model_file = input_dir.joinpath(f"{MODEL_NAME}{ending}.bin")
            state_dict = load(input_model_file, map_location=map_location)
            model.load_state_dict(state_dict, **load_model_func_kwargs)
    logger.info("All model weights loaded successfully")

    # Optimizer states
    for i, opt in enumerate(optimizers):
        optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
        input_optimizer_file = input_dir.joinpath(optimizer_name)
        optimizer_state = load(input_optimizer_file, map_location=map_location, **load_kwargs)
        opt.load_state_dict(optimizer_state)
    logger.info("All optimizer states loaded successfully")

    # Scheduler states
    for i, scheduler in enumerate(schedulers):
        scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"
        input_scheduler_file = input_dir.joinpath(scheduler_name)
        scheduler_state = load(input_scheduler_file, **load_kwargs)
        scheduler.load_state_dict(scheduler_state)
    logger.info("All scheduler states loaded successfully")

    # Dataloader states.
    # The import stays function-scoped as in the original code, but is executed once instead of once per
    # dataloader.
    from .data_loader import IterableDatasetShard, SeedableRandomSampler

    for i, dataloader in enumerate(dataloaders):
        sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
        input_sampler_file = input_dir.joinpath(sampler_name)

        # A sampler file is only read back under the same condition it is written: an `IterableDatasetShard`
        # dataset whose sampler is a `SeedableRandomSampler`.
        if isinstance(dataloader.dataset, IterableDatasetShard):
            sampler = dataloader.get_sampler()
            if isinstance(sampler, SeedableRandomSampler):
                dataloader.set_sampler(load(input_sampler_file))
        if getattr(dataloader, "use_stateful_dataloader", False):
            dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
            input_dataloader_state_dict_file = input_dir.joinpath(dataloader_state_dict_name)
            # A missing dataloader state file is tolerated and simply skipped.
            if input_dataloader_state_dict_file.exists():
                state_dict = load(input_dataloader_state_dict_file, **load_kwargs)
                dataloader.load_state_dict(state_dict)
    logger.info("All dataloader sampler states loaded successfully")

    # GradScaler state
    if scaler is not None:
        input_scaler_file = input_dir.joinpath(SCALER_NAME)
        scaler_state = load(input_scaler_file)
        scaler.load_state_dict(scaler_state)
        logger.info("GradScaler state loaded successfully")

    # Random states. Restoring them is best-effort: any failure is logged and otherwise ignored.
    try:
        states = load(input_dir.joinpath(f"{RNG_STATE_NAME}_{process_index}.pkl"))
        if "step" in states:
            override_attributes["step"] = states["step"]
        random.setstate(states["random_state"])
        np.random.set_state(states["numpy_random_seed"])
        torch.set_rng_state(states["torch_manual_seed"])
        # The backend checks below mirror the ones used when the RNG file is written by
        # `save_accelerator_state`, so only keys that are actually stored for this hardware are read.
        if is_xpu_available():
            torch.xpu.set_rng_state_all(states["torch_xpu_manual_seed"])
        if is_mlu_available():
            torch.mlu.set_rng_state_all(states["torch_mlu_manual_seed"])
        elif is_sdaa_available():
            torch.sdaa.set_rng_state_all(states["torch_sdaa_manual_seed"])
        elif is_musa_available():
            torch.musa.set_rng_state_all(states["torch_musa_manual_seed"])
        # HPU state is saved but was previously never restored. The key check keeps RNG files without an HPU
        # entry loadable.
        if is_hpu_available() and "torch_hpu_manual_seed" in states:
            torch.hpu.set_rng_state_all(states["torch_hpu_manual_seed"])
        # Previously an unconditional `else:` branch: without CUDA the key is absent, which raised a `KeyError`
        # and skipped the XLA restore below.
        if is_cuda_available():
            torch.cuda.set_rng_state_all(states["torch_cuda_manual_seed"])
        if is_torch_xla_available():
            xm.set_rng_state(states["xm_seed"])
        logger.info("All random states loaded successfully")
    except Exception:
        logger.info("Could not load random states")

    return override_attributes
| |
|
| |
|
def save_custom_state(obj, path, index: int = 0, save_on_each_node: bool = False):
    """
    Saves the state of `obj` to `{path}/custom_checkpoint_{index}.pkl`

    Args:
        obj:
            The object to checkpoint; its `state_dict()` method is called to obtain the data to write.
        path (`str` or `os.PathLike`):
            The folder that will contain the checkpoint file.
        index (`int`, *optional*, defaults to 0):
            The number used in the checkpoint file name.
        save_on_each_node (`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to the `save` helper.
    """
    checkpoint_file = Path(path).joinpath(f"custom_checkpoint_{index}.pkl")
    pretty_name = get_pretty_name(obj)
    logger.info(f"Saving the state of {pretty_name} to {checkpoint_file}")
    obj_state = obj.state_dict()
    save(obj_state, checkpoint_file, save_on_each_node=save_on_each_node)
| |
|
| |
|
def load_custom_state(obj, path, index: int = 0):
    """
    Loads the state of `obj` at `{path}/custom_checkpoint_{index}.pkl`. Will always set `weights_only=False` when
    loading the state.

    Args:
        obj:
            The object to restore; the loaded data is passed to its `load_state_dict()` method.
        path (`str` or `os.PathLike`):
            The folder that contains the checkpoint file.
        index (`int`, *optional*, defaults to 0):
            The number used in the checkpoint file name.
    """
    # Build the location with `Path`, exactly like `save_custom_state`, so a trailing separator in `path` does not
    # produce a doubled "//" in the file name.
    load_location = Path(path) / f"custom_checkpoint_{index}.pkl"
    logger.info(f"Loading the state of {get_pretty_name(obj)} from {load_location}")
    # SECURITY: `weights_only=False` allows arbitrary pickled objects to be deserialized. Only load custom
    # checkpoints that come from a trusted source.
    obj.load_state_dict(load(load_location, map_location="cpu", weights_only=False))
| |
|