| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import os |
| | import random |
| | from pathlib import Path |
| | from typing import List |
| |
|
| | import numpy as np |
| | import torch |
| | from torch.cuda.amp import GradScaler |
| |
|
| | from .utils import ( |
| | MODEL_NAME, |
| | OPTIMIZER_NAME, |
| | RNG_STATE_NAME, |
| | SCALER_NAME, |
| | SCHEDULER_NAME, |
| | get_pretty_name, |
| | is_tpu_available, |
| | is_xpu_available, |
| | save, |
| | ) |
| |
|
| |
|
| | if is_tpu_available(check_device=False): |
| | import torch_xla.core.xla_model as xm |
| |
|
| | from .logging import get_logger |
| | from .state import PartialState |
| |
|
| |
|
# Module-level logger used by the checkpoint save/load helpers defined in this file.
logger = get_logger(__name__)
| |
|
| |
|
def save_accelerator_state(
    output_dir: str,
    model_states: List[dict],
    optimizers: list,
    schedulers: list,
    process_index: int,
    scaler: GradScaler = None,
):
    """
    Saves the current states of the models, optimizers, scaler, and RNG generators to a given directory.

    Args:
        output_dir (`str` or `os.PathLike`):
            The name of the folder to save all relevant weights and states.
        model_states (`List[dict]`):
            A list of model state dicts. Each entry is written as-is, so callers must pass the already
            extracted state dicts rather than the model instances.
        optimizers (`List[torch.optim.Optimizer]`):
            A list of optimizer instances; the result of each one's `state_dict()` is saved.
        schedulers (`List[torch.optim.lr_scheduler._LRScheduler]`):
            A list of learning rate schedulers; the result of each one's `state_dict()` is saved.
        process_index (`int`):
            The current process index in the Accelerator state. It is used to name the per-process RNG
            state file.
        scaler (`torch.cuda.amp.GradScaler`, *optional*):
            An optional gradient scaler instance to save.

    Returns:
        The `output_dir` that was passed in.
    """

    def _indexed_name(base: str, index: int) -> str:
        # The first artifact of each kind keeps the bare name; later ones get an `_{index}` suffix.
        return f"{base}.bin" if index == 0 else f"{base}_{index}.bin"

    for i, state in enumerate(model_states):
        output_model_file = os.path.join(output_dir, _indexed_name(MODEL_NAME, i))
        save(state, output_model_file)
        logger.info(f"Model weights saved in {output_model_file}")

    for i, opt in enumerate(optimizers):
        output_optimizer_file = os.path.join(output_dir, _indexed_name(OPTIMIZER_NAME, i))
        save(opt.state_dict(), output_optimizer_file)
        logger.info(f"Optimizer state saved in {output_optimizer_file}")

    for i, scheduler in enumerate(schedulers):
        output_scheduler_file = os.path.join(output_dir, _indexed_name(SCHEDULER_NAME, i))
        save(scheduler.state_dict(), output_scheduler_file)
        logger.info(f"Scheduler state saved in {output_scheduler_file}")

    if scaler is not None:
        output_scaler_file = os.path.join(output_dir, SCALER_NAME)
        # The scaler file is shared (its name does not depend on `process_index`), so write it through
        # the same `save` helper as the model/optimizer/scheduler files instead of a raw `torch.save`.
        save(scaler.state_dict(), output_scaler_file)
        logger.info(f"Gradient scaler state saved in {output_scaler_file}")

    # RNG states differ per process, so each process writes its own file (named with
    # `process_index`) directly via `torch.save`.
    states = {
        "random_state": random.getstate(),
        "numpy_random_seed": np.random.get_state(),
        "torch_manual_seed": torch.get_rng_state(),
    }
    if is_xpu_available():
        states["torch_xpu_manual_seed"] = torch.xpu.get_rng_state_all()
    else:
        states["torch_cuda_manual_seed"] = torch.cuda.get_rng_state_all()
    if is_tpu_available():
        states["xm_seed"] = xm.get_rng_state()
    output_states_file = os.path.join(output_dir, f"{RNG_STATE_NAME}_{process_index}.pkl")
    torch.save(states, output_states_file)
    logger.info(f"Random states saved in {output_states_file}")
    return output_dir
| |
|
| |
|
def load_accelerator_state(
    input_dir,
    models,
    optimizers,
    schedulers,
    process_index,
    scaler=None,
    map_location=None,
    **load_model_func_kwargs,
):
    """
    Loads states of the models, optimizers, scaler, and RNG generators from a given directory.

    Args:
        input_dir (`str` or `os.PathLike`):
            The name of the folder to load all relevant weights and states.
        models (`List[torch.nn.Module]`):
            A list of model instances
        optimizers (`List[torch.optim.Optimizer]`):
            A list of optimizer instances
        schedulers (`List[torch.optim.lr_scheduler._LRScheduler]`):
            A list of learning rate schedulers
        process_index (`int`):
            The current process index in the Accelerator state. It selects which per-process RNG state
            file to read.
        scaler (`torch.cuda.amp.GradScaler`, *optional*):
            An optional *GradScaler* instance to load
        map_location (`str`, *optional*):
            Where to load the model weights and optimizer states onto. Should be one of `None` (treated as
            "cpu"), "cpu", or "on_device" (the device reported by `PartialState().device`).
        load_model_func_kwargs (`dict`, *optional*):
            Additional arguments that can be passed to the model's `load_state_dict` method.

    Raises:
        TypeError: If `map_location` is not one of `None`, "cpu", or "on_device".

    Restoring the RNG states is best-effort: any failure is logged and loading continues.
    """
    if map_location not in [None, "cpu", "on_device"]:
        raise TypeError(
            "Unsupported optimizer map location passed, please choose one of `None`, `'cpu'`, or `'on_device'`"
        )
    if map_location is None:
        map_location = "cpu"
    elif map_location == "on_device":
        map_location = PartialState().device

    def _indexed_name(base, index):
        # Mirrors the naming used when saving: bare name for index 0, `_{index}` suffix afterwards.
        return f"{base}.bin" if index == 0 else f"{base}_{index}.bin"

    for i, model in enumerate(models):
        input_model_file = os.path.join(input_dir, _indexed_name(MODEL_NAME, i))
        model.load_state_dict(torch.load(input_model_file, map_location=map_location), **load_model_func_kwargs)
    logger.info("All model weights loaded successfully")

    for i, opt in enumerate(optimizers):
        input_optimizer_file = os.path.join(input_dir, _indexed_name(OPTIMIZER_NAME, i))
        optimizer_state = torch.load(input_optimizer_file, map_location=map_location)
        opt.load_state_dict(optimizer_state)
    logger.info("All optimizer states loaded successfully")

    for i, scheduler in enumerate(schedulers):
        input_scheduler_file = os.path.join(input_dir, _indexed_name(SCHEDULER_NAME, i))
        scheduler.load_state_dict(torch.load(input_scheduler_file))
    logger.info("All scheduler states loaded successfully")

    if scaler is not None:
        input_scaler_file = os.path.join(input_dir, SCALER_NAME)
        scaler.load_state_dict(torch.load(input_scaler_file))
        logger.info("GradScaler state loaded successfully")

    # Best-effort: a checkpoint may lack the RNG file for this process, or have been saved on a
    # different accelerator type (different keys). Don't abort the load, but report why it failed
    # instead of swallowing the cause.
    try:
        states = torch.load(os.path.join(input_dir, f"{RNG_STATE_NAME}_{process_index}.pkl"))
        random.setstate(states["random_state"])
        np.random.set_state(states["numpy_random_seed"])
        torch.set_rng_state(states["torch_manual_seed"])
        if is_xpu_available():
            torch.xpu.set_rng_state_all(states["torch_xpu_manual_seed"])
        else:
            torch.cuda.set_rng_state_all(states["torch_cuda_manual_seed"])
        if is_tpu_available():
            xm.set_rng_state(states["xm_seed"])
        logger.info("All random states loaded successfully")
    except Exception as err:
        logger.warning(f"Could not load random states: {err!r}")
| |
|
| |
|
def save_custom_state(obj, path, index: int = 0):
    """
    Saves the state of `obj` to `{path}/custom_checkpoint_{index}.pkl`.

    Args:
        obj:
            Any object exposing a `state_dict()` method; the result of that call is what gets written.
        path (`str` or `os.PathLike`):
            Directory to write the checkpoint file into.
        index (`int`, *optional*, defaults to 0):
            Suffix that distinguishes multiple custom objects saved into the same directory.
    """
    save_location = Path(path) / f"custom_checkpoint_{index}.pkl"
    logger.info(f"Saving the state of {get_pretty_name(obj)} to {save_location}")
    # The file name does not depend on the process, so route it through the shared `save` helper
    # (as the accelerator-state artifacts are) rather than a raw `torch.save`.
    save(obj.state_dict(), save_location)
| |
|
| |
|
def load_custom_state(obj, path, index: int = 0):
    """
    Loads the state of `obj` at `{path}/custom_checkpoint_{index}.pkl`.

    Args:
        obj:
            Any object exposing a `load_state_dict()` method; the loaded state is passed to it.
        path (`str` or `os.PathLike`):
            Directory containing the checkpoint file.
        index (`int`, *optional*, defaults to 0):
            Suffix that selects which custom checkpoint in `path` to load.

    The state is loaded onto the CPU. `torch.load` unpickles the file, so only load checkpoints you trust.
    """
    # Build the path the same way `save_custom_state` does (pathlib) rather than by string concatenation.
    load_location = Path(path) / f"custom_checkpoint_{index}.pkl"
    logger.info(f"Loading the state of {get_pretty_name(obj)} from {load_location}")
    obj.load_state_dict(torch.load(load_location, map_location="cpu"))
| |
|