| |
|
|
| import json |
| import logging |
| import os |
| import re |
| from dataclasses import dataclass, field |
| from pathlib import Path |
| from typing import List, Optional, Tuple |
|
|
| import torch |
| import torch.distributed as dist |
| import torch.distributed.checkpoint as dcp |
| import torch.nn as nn |
| import torch.optim.optimizer |
| from omegaconf import OmegaConf |
| from torch.distributed._tensor import DeviceMesh |
| from torch.distributed.checkpoint import FileSystemReader |
| from torch.distributed.checkpoint.format_utils import dcp_to_torch_save |
| from torch.distributed.checkpoint.state_dict import (get_model_state_dict, |
| get_state_dict, |
| set_state_dict) |
|
|
| from core.distributed import get_is_master |
|
|
logger = logging.getLogger("CHECKPOINT")

# Checkpoint folders are named after the training step, zero-padded to 10 digits.
FOLDER_NAME = "{:010d}"
RE_FOLDER = r"\d{10}"

# Pattern for distributed-checkpoint shard files (`__<rank>_<index>.distcp`).
# Both numbers may have several digits (e.g. rank 12), hence `\d+` rather
# than a single `\d`.
RE_CKPT = r"__\d+_\d+\.distcp"

# Sub-folder and file name used for the single-file (consolidated) checkpoint.
CONSOLIDATE_FOLDER = "consolidated"
CONSOLIDATE_NAME = "consolidated.pth"

# Per-checkpoint metadata files: run config and one train state per rank
# (the rank is zero-padded to 5 digits).
CONFIG_NAME = "params.json"
TRAIN_STATE_NAME = "train_state_{:05d}.json"

# Matches any run of digits; used to extract the step from a folder name.
RE_DIGITS = re.compile(r"\d+")
|
|
|
|
@dataclass
class SaveEvery:
    """Periodicity settings for one kind of checkpoint.

    Attributes:
        every: Step interval. A checkpoint belongs to this kind when its
            step number is a multiple of ``every``.
        keep: Number of most recent checkpoints of this kind to retain
            during clean-up. Values <= 0 disable pruning (keep all).
    """

    every: int = 1000
    keep: int = 0
|
|
|
|
@dataclass
class CheckpointArgs:
    """Configuration consumed by ``CheckpointManager``.

    Attributes:
        dump: Periodicity/retention of "dump" checkpoints.
        eval: Periodicity/retention of "eval" checkpoints.
        path: Root directory under which step folders are created.
        init_ckpt_path: Stored on the manager as-is; not otherwise read in
            this module.
        vision_model_path: Not read in this module (presumably forwarded to
            ``load_consolidated_checkpoint`` by the caller -- TODO confirm).
        is_consolidated_model: Not read in this module (presumably selects
            the consolidated loading path in the caller -- TODO confirm).
        continue_training_from_init: Stored on the manager as-is; not
            otherwise read in this module.
    """

    dump: SaveEvery = field(default_factory=SaveEvery)
    eval: SaveEvery = field(default_factory=SaveEvery)
    path: Optional[str] = None
    init_ckpt_path: Optional[str] = None
    vision_model_path: Optional[str] = None
    is_consolidated_model: bool = False
    continue_training_from_init: bool = False
|
|
|
|
def _get_key_step(name: str):
    """Return the last run of digits in ``name`` as an integer.

    Raises:
        IndexError: if ``name`` contains no digit at all.
    """
    digit_runs = RE_DIGITS.findall(name)
    last_run = digit_runs[-1]
    return int(last_run)
|
|
|
|
def consolidate_checkpoints(ckpt_dir: str):
    """
    Merge the distributed checkpoint stored in ``ckpt_dir`` into one file.

    The result is written to ``<ckpt_dir>/<CONSOLIDATE_FOLDER>/<CONSOLIDATE_NAME>``
    and the run config (``CONFIG_NAME``) is copied next to it. Nothing is done
    when the consolidated file already exists.

    Parameters:
        ckpt_dir: str - path to the directory containing the checkpoint

    Returns the path of the folder holding the consolidated checkpoint
    """
    source_dir = Path(ckpt_dir)
    consolidate_path = source_dir / CONSOLIDATE_FOLDER
    target_file = consolidate_path / CONSOLIDATE_NAME

    # Already consolidated: reuse the existing output.
    if target_file.exists():
        return consolidate_path

    consolidate_path.mkdir(exist_ok=True)
    logger.info(f"Consolidating to: {str(consolidate_path)}")
    dcp_to_torch_save(ckpt_dir, str(target_file))

    # Keep the config alongside the weights so the folder is self-contained.
    config_text = (source_dir / CONFIG_NAME).read_text()
    (consolidate_path / CONFIG_NAME).write_text(config_text)
    logger.info("Consolidated !")
    return consolidate_path
|
|
|
|
| def load_from_checkpoint( |
| ckpt_dir: str, |
| model: nn.Module, |
| optimizer: Optional[torch.optim.Optimizer] = None, |
| model_key: str = "model", |
| optim_key: str = "optim", |
| ): |
| if not (Path(ckpt_dir) / ".metadata").exists(): |
| raise ValueError( |
| f"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it" |
| ) |
|
|
| state_dict = {} |
| if optimizer is not None: |
| state_dict[model_key], state_dict[optim_key] = get_state_dict(model, optimizer) |
| else: |
| state_dict[model_key] = get_model_state_dict(model) |
| if model_key == "": |
| state_dict = state_dict.pop(model_key) |
|
|
| dcp.load(state_dict, checkpoint_id=ckpt_dir) |
|
|
|
|
class CheckpointManager:
    """Save, reload and prune step-numbered checkpoint folders under ``args.path``.

    Each checkpoint lives in a sub-folder named after the training step
    (``FOLDER_NAME``). ``save`` writes into it the distributed model/optimizer
    state, the run config (``CONFIG_NAME``) and one train-state JSON per
    data-parallel rank (``TRAIN_STATE_NAME``).
    """

    def __init__(self, args: CheckpointArgs):
        """Bind to an existing checkpoint root and index the saves found there.

        Raises:
            AssertionError: if ``args.path`` does not exist; use
                ``instantiate_and_make_dir`` to create it first.
        """
        self.path = args.path
        self.dump_every = args.dump
        self.eval_every = args.eval
        self.init_ckpt_path = args.init_ckpt_path
        self.continue_training_from_init = args.continue_training_from_init

        assert os.path.exists(
            self.path
        ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)"

        self.existing_saves = self.get_existing_saves()

    @staticmethod
    def _barrier() -> None:
        """Synchronize all ranks; a no-op when no process group is initialized."""
        if dist.is_initialized():
            dist.barrier()

    @staticmethod
    def _is_rank_zero() -> bool:
        """Tell whether this process should perform the single-writer file work."""
        if dist.is_initialized():
            return dist.get_rank() == 0
        # Without a process group, fall back to the project helper.
        return get_is_master()

    @staticmethod
    def _matches_period(step: int, every: int) -> bool:
        """Return True when ``step`` is a multiple of ``every`` (never for ``every == 0``)."""
        return every != 0 and step % every == 0

    def get_existing_saves(self) -> List[Path]:
        """Return the step folders found under ``self.path``, sorted by step."""
        folders = [
            p
            for p in Path(self.path).iterdir()
            # fullmatch: only exact 10-digit names count as checkpoints, so a
            # folder such as "0000001000_backup" is never picked up (or pruned).
            if p.is_dir() and re.fullmatch(RE_FOLDER, p.name)
        ]
        folders.sort(key=lambda p: _get_key_step(p.name))
        return folders

    def clean_up(self):
        """Delete old checkpoints according to the dump/eval retention policy.

        A folder is a "dump" (resp. "eval") checkpoint when its step is a
        multiple of ``dump_every.every`` (resp. ``eval_every.every``). When
        ``keep > 0`` only the last ``keep`` folders of that kind survive.
        Folders matching neither period are always kept. Files are removed by
        a single process, and all ranks synchronize afterwards.
        """
        logger.info("Cleaning up checkpoints...")
        dump_folders = []
        eval_folders = []
        other_folders = []
        for p in self.existing_saves:
            step = _get_key_step(p.name)
            is_dump = self._matches_period(step, self.dump_every.every)
            is_eval = self._matches_period(step, self.eval_every.every)
            if is_dump:
                dump_folders.append(p)
            if is_eval:
                eval_folders.append(p)
            if not (is_dump or is_eval):
                other_folders.append(p)

        logger.info(f"Dump folders: {dump_folders}")
        logger.info(f"Eval folders: {eval_folders}")
        logger.info(f"Other folders: {other_folders}")

        # existing_saves is sorted by step, so the tail holds the most recent.
        if self.dump_every.keep > 0:
            dump_folders = dump_folders[-self.dump_every.keep :]
        if self.eval_every.keep > 0:
            eval_folders = eval_folders[-self.eval_every.keep :]

        folder_to_keep = set(other_folders + dump_folders + eval_folders)
        folder_to_remove = set(self.existing_saves) - folder_to_keep

        logger.info(f"Removing folders: {folder_to_remove}")

        if self._is_rank_zero():
            for folder in folder_to_remove:
                for file in folder.iterdir():
                    if file.is_file():
                        file.unlink()
                    elif file.is_dir():
                        # Only the consolidated sub-folder is expected here;
                        # refuse to delete anything else.
                        assert file.name in [
                            CONSOLIDATE_FOLDER
                        ], f"Unexpected sub-folder in checkpoint: {file}"
                        for f in file.iterdir():
                            f.unlink()
                        file.rmdir()
                folder.rmdir()

        self._barrier()

        self.existing_saves = list(folder_to_keep)
        self.existing_saves.sort(key=lambda p: _get_key_step(p.name))

    def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]:
        """Return the most recent save holding a train state for ``dp_rank``, or None."""
        path = None
        for p in reversed(self.existing_saves):
            if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file():
                path = p
                break
        return path

    def _create_folder(self, base_path: Path, folder_name: str) -> Path:
        """Create ``base_path / folder_name`` on the master process and sync all ranks."""
        folder = base_path / folder_name
        if get_is_master():
            folder.mkdir(parents=False, exist_ok=True)
        self._barrier()
        return folder

    def _get_dp_tp_mesh(
        self, device_mesh: Optional[DeviceMesh] = None
    ) -> Tuple[int, int]:
        """Return ``(dp_rank, tp_rank)`` of this process; ``(0, 0)`` without a mesh."""
        dp_rank = 0
        tp_rank = 0
        if device_mesh is not None:
            if "dp_replicate" in device_mesh.mesh_dim_names:
                dp_rank = device_mesh.get_local_rank("dp_replicate")
                if "dp_shard" in device_mesh.mesh_dim_names:
                    # NOTE(review): a row-major flattening of a
                    # (dp_replicate, dp_shard) mesh would multiply by the
                    # dp_shard size, not the dp_replicate size. Left as-is
                    # because the resulting rank names the train-state files
                    # of existing checkpoints -- confirm against the caller's
                    # own dp_rank computation before changing.
                    dp_rank = dp_rank * device_mesh[
                        "dp_replicate"
                    ].size() + device_mesh.get_local_rank("dp_shard")
            if "tp" in device_mesh.mesh_dim_names:
                tp_rank = device_mesh.get_local_rank("tp")
        return dp_rank, tp_rank

    @torch.no_grad()
    def get_state_dict(
        self,
        model,
        optimizer,
    ):
        """Return ``{"model": ..., "optim": ...}`` state dicts for ``model``/``optimizer``."""
        model_sd, optim_sd = get_state_dict(model, optimizer)
        return {"model": model_sd, "optim": optim_sd}

    def save(
        self,
        model,
        optimizer,
        train_state,
        config,
        device_mesh: Optional[DeviceMesh] = None,
    ) -> bool:
        """Write a checkpoint for ``train_state.step`` and prune old ones.

        Parameters:
            model, optimizer: objects whose state is saved with ``dcp.save``.
            train_state: must expose ``step`` and ``state_dict()``; the latter
                is dumped as JSON.
            config: run config, dumped as JSON via OmegaConf.
            device_mesh: used to derive the data/tensor-parallel ranks.

        Returns:
            Always True.
        """
        path = Path(self.path)
        curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step))
        logger.info(f"Saving to: {str(curr_save_dir)}")

        self._barrier()

        logger.info("Saving...")
        state_dict = self.get_state_dict(model, optimizer)
        dcp.save(state_dict, checkpoint_id=curr_save_dir)
        logger.info("State dict saved!")

        self._barrier()

        # The config is identical on every rank: write it once.
        if get_is_master():
            with open(curr_save_dir / CONFIG_NAME, "w") as f:
                json.dump(
                    OmegaConf.to_container(OmegaConf.structured(config), resolve=True),
                    f,
                    indent=4,
                )

        # One train-state file per data-parallel rank, written by the first
        # tensor-parallel rank of each group.
        dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
        if tp_rank == 0:
            train_state_name = TRAIN_STATE_NAME.format(dp_rank)
            logger.info(
                f"Saving train state to: {str(curr_save_dir / train_state_name)}"
            )
            with open(curr_save_dir / train_state_name, "w") as f:
                json.dump(train_state.state_dict(), f)
            logger.info("Train state saved !")

        self.existing_saves.append(curr_save_dir)

        self.clean_up()

        self._barrier()
        return True

    @torch.no_grad()
    def load(
        self,
        model: nn.Module,
        optimizer,
        train_state,
        device_mesh: DeviceMesh,
        path: Optional[Path] = None,
    ):
        """Restore train state, model and optimizer from a checkpoint folder.

        Parameters:
            model, optimizer: objects to restore in place.
            train_state: must expose ``load_state_dict``; receives the JSON
                train state of this process' data-parallel rank.
            device_mesh: used to derive the data-parallel rank.
            path: checkpoint folder (``Path`` or ``str``). When omitted, the
                latest save holding a train state for this rank is used. If
                there is none, the call returns without loading anything.
        """
        dp_rank, _ = self._get_dp_tp_mesh(device_mesh)

        # Explicit path first (accepting plain strings), else the latest save.
        path = Path(path) if path else self.get_last_step_path(dp_rank=dp_rank)

        if path is None:
            # No checkpoint available: nothing to do.
            return

        train_state_name = TRAIN_STATE_NAME.format(dp_rank)
        logger.info("Reloading train state")
        with open(path / train_state_name, "r") as f:
            train_state_dict = json.load(f)
        train_state.load_state_dict(train_state_dict)
        logger.info("Train state reloaded")

        logger.info(f"Loading from: {str(path)}")
        state_dict = self.get_state_dict(
            model=model,
            optimizer=optimizer,
        )
        dcp.load(state_dict, checkpoint_id=path)
        logger.info("State dict loaded.")

        logger.info("Reloading model and optim")

        set_state_dict(
            model,
            optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
        )
        logger.info("Model and optim reloaded")

    @classmethod
    def instantiate_and_make_dir(cls, args: CheckpointArgs):
        """Create ``args.path`` on the master process, sync, and build a manager."""
        if get_is_master():
            os.makedirs(args.path, exist_ok=True)
        # Guarded like every other barrier in this class, so single-process
        # runs without a process group do not crash here.
        cls._barrier()

        return cls(args)
|
|
|
|
def get_consolidated_ckpt_path(ckpt_dir: Path, mp_rank: int = 0, mp_size: int = 1):
    """Return the consolidated checkpoint file for a model-parallel rank.

    With a single model-parallel rank, an existing ``consolidated.pth`` is
    preferred. In every other case the rank-suffixed name
    ``consolidated.<rank:02d>.pth`` is returned, whether or not it exists.

    Raises:
        AssertionError: if ``mp_size == 1`` and ``mp_rank != 0``.
    """
    single_rank = mp_size == 1
    if single_rank:
        assert mp_rank == 0
    if single_rank and (ckpt_dir / "consolidated.pth").exists():
        return ckpt_dir / "consolidated.pth"
    ranked_name = f"consolidated.{mp_rank:02d}.pth"
    return ckpt_dir / ranked_name
|
|
|
|
def _read_consolidated_part(ckpt_file: Path):
    """Load one consolidated file and unwrap an optional top-level "model" key."""
    part = torch.load(ckpt_file, weights_only=True)
    if "model" in part:
        part = part["model"]
    return part


def load_consolidated_checkpoint(
    model: nn.Module,
    consolidated_path: str,
    vision_model_path: Optional[str] = None,
):
    """
    Loads a consolidated checkpoint into the model.
    This version supports both:
    - a single file named 'consolidated.pth'
    - multiple parts named like 'consolidated.00.pth', 'consolidated.01.pth', etc.

    When 'consolidated.pth' exists it is used alone. Otherwise every
    'consolidated.*.pth' file is loaded in sorted order and the parts are
    merged with ``dict.update`` (for a key present in several parts, the last
    part wins).

    Parameters:
        model: module to load into. It must expose ``vision_projector`` and
            ``vision_model`` (each with ``init_tensors()``) and
            ``rope_embeddings`` (with ``reset_parameters()``).
        consolidated_path: directory containing the consolidated file(s).
        vision_model_path: optional separate checkpoint passed to
            ``model.vision_model.load_ckpt``. When given, missing
            ``vision_model.`` keys are not reported.

    Raises:
        FileNotFoundError: if the directory holds no consolidated file.
    """
    ckpt_path = Path(consolidated_path)

    # Only the un-suffixed file selects the single-file path. Previously an
    # existing 'consolidated.00.pth' was also treated as a single file, which
    # silently skipped 'consolidated.01.pth' and later parts.
    single_file = ckpt_path / "consolidated.pth"
    if single_file.exists():
        checkpoint_files = [single_file]
    else:
        checkpoint_files = sorted(ckpt_path.glob("consolidated.*.pth"))
        if not checkpoint_files:
            raise FileNotFoundError(
                f"No consolidated checkpoint file found in {ckpt_path}."
            )

    if len(checkpoint_files) == 1:
        # Use the loaded mapping as-is so nothing attached to it is lost.
        st_dict = _read_consolidated_part(checkpoint_files[0])
    else:
        st_dict = {}
        for ckpt_file in checkpoint_files:
            st_dict.update(_read_consolidated_part(ckpt_file))

    # Re-initialize these sub-modules before the checkpoint values are copied in.
    model.vision_projector.init_tensors()
    model.vision_model.init_tensors()
    model.rope_embeddings.reset_parameters()

    if vision_model_path is not None:
        model.vision_model.load_ckpt(vision_model_path)

    missing_keys, unexpected_keys = model.load_state_dict(st_dict, strict=False)
    # Tied weights are not stored separately, so their absence is not reported.
    missing_keys = [k for k in missing_keys if "tied_module.weight" not in k]
    if vision_model_path is not None:
        # The vision model came from its own checkpoint above.
        missing_keys = [k for k in missing_keys if "vision_model." not in k]
    if len(missing_keys) > 0:
        logger.warning(f"Missing keys when reloading: {missing_keys}")
    if len(unexpected_keys) > 0:
        logger.warning(f"Unexpected keys when reloading: {unexpected_keys}")
|
|