import os
from typing import Any, Set

import torch

from cosmos_predict1.checkpointer.ddp import Checkpointer as DDPCheckpointer
from cosmos_predict1.utils import distributed, log
from cosmos_predict1.utils.model import Model


class Checkpointer(DDPCheckpointer):
    """
    Checkpointer for PEFT in distributed training. It behaves like the DDP checkpointer,
    except that disabling `broadcast_via_filesystem` is not supported, and it can load a
    pre-trained model checkpoint whose filename carries no type postfix.

    Note:
        - Fully Sharded Data Parallelism (FSDP) is not supported by this checkpointer.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # PEFT checkpointing relies on filesystem-based broadcast of the loaded
        # state, so `broadcast_via_filesystem` must stay enabled.
        if not self.broadcast_via_filesystem:
            raise ValueError("self.broadcast_via_filesystem=False is not implemented for PEFT checkpointer.")

    def add_type_postfix_to_checkpoint_path(self, key: str, checkpoint_path: str, model: Model) -> str:
        """
        Override the base-class (DDP checkpointer) `add_type_postfix_to_checkpoint_path`
        so that a pre-trained model can be loaded from a path without a type postfix.
        """
        checkpoint_path = super().add_type_postfix_to_checkpoint_path(key, checkpoint_path, model)
        # Undo the doubled postfix ("model_model.pt") that the base class produces
        # for checkpoints already named "model.pt".
        checkpoint_path = checkpoint_path.replace("model_model.pt", "model.pt")
        return checkpoint_path
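
    # Illustrative example (paths are hypothetical): if the base class maps
    # "dir/model.pt" to "dir/model_model.pt" for the "model" key, the override
    # above restores it to "dir/model.pt", so an unpostfixed pre-trained
    # checkpoint can be found on disk.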

    def load_broadcast_state_dict(self, checkpoint_path: str, model: Model, resume_keys: Set) -> dict[str, Any]:
        """
        Load the state_dict and broadcast it for the PEFT checkpointer.

        This function is identical to `load_broadcast_state_dict` in the base class
        (DDP checkpointer), except that `broadcast_via_filesystem=False` is not supported.

        Args:
            checkpoint_path (str): The base path of the checkpoint.
            model (Model): The model being loaded.
            resume_keys (Set): Set of keys to resume from the checkpoint.

        Returns:
            dict[str, Any]: A dictionary containing the loaded state for each resumed key.
        """
        state_dict = {}
        sorted_resume_keys = sorted(resume_keys)

        # Only the source rank in each DP (with CP) group reads the checkpoint from
        # storage; the other ranks pick it up via the filesystem after the barrier below.
        if self.rank_dp_w_cp == 0:
            for key in sorted_resume_keys:
                _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_path, model)
                local_cache_path = os.path.join(self.load_dirname, os.path.basename(_ckpt_path))
                if os.path.exists(local_cache_path):
                    # If the checkpoint is already in the local cache, load from there.
                    self.print(f"Checkpoint is already in local cache: {local_cache_path}. Loading...")
                    _state_dict = torch.load(
                        local_cache_path, map_location=lambda storage, loc: storage, weights_only=False
                    )
                else:
                    # Otherwise, load it from the resolved checkpoint path.
                    self.print(f"Loading checkpoint from: {_ckpt_path}")
                    _state_dict = torch.load(_ckpt_path, map_location=lambda storage, loc: storage, weights_only=False)
                state_dict[key] = _state_dict

        # Ensure the source rank has finished loading before the other ranks read.
        distributed.barrier()

        log.info(
            "Start broadcasting checkpoint from the source rank to all other ranks in the same DDP group.",
            rank0_only=True,
        )
        for key in sorted_resume_keys:
            if self.broadcast_via_filesystem:
                # Non-source ranks read the checkpoint from the file system instead of
                # receiving it via collective communication.
                if self.rank_dp_w_cp != 0:
                    _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_path, model)
                    local_cache_path = os.path.join(self.load_dirname, os.path.basename(_ckpt_path))
                    if os.path.exists(local_cache_path):
                        self.print(f"Loading checkpoint from: {local_cache_path}")
                        state_dict[key] = torch.load(
                            local_cache_path, map_location=lambda storage, loc: storage, weights_only=False
                        )
                    else:
                        self.print(f"Loading checkpoint from: {_ckpt_path}")
                        state_dict[key] = torch.load(
                            _ckpt_path, map_location=lambda storage, loc: storage, weights_only=False
                        )
            else:
                raise ValueError("self.broadcast_via_filesystem=False is not implemented for PEFT checkpointer.")

        return state_dict
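
# Usage sketch (illustrative only; the constructor arguments and the `load` call
# below are assumptions about the surrounding trainer API, not part of this module):
# the PEFT checkpointer is meant as a drop-in replacement for the DDP checkpointer,
# as long as `broadcast_via_filesystem` remains enabled in the checkpoint config.
#
#     checkpointer = Checkpointer(config_checkpoint, config_job, callbacks=callbacks)
#     checkpointer.load(model, optimizer, scheduler, grad_scaler)  # hypothetical call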