lingbot-vla / lingbotvla /checkpoint /format_utils.py
bazaar-research's picture
Upload folder using huggingface_hub
fb11af9 verified
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from abc import ABC
from typing import Any, Dict, Union
import torch
from ..utils.import_utils import is_torch_version_greater_than
from ..utils.logging import get_logger
# The torch.distributed.checkpoint (DCP) loading helpers used by
# dcp_to_torch_state_dict are only imported when
# is_torch_version_greater_than("2.4") holds; on other torch versions these
# names are left undefined and DCP conversion cannot be used.
if is_torch_version_greater_than("2.4"):
    from torch.distributed.checkpoint import FileSystemReader
    from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
    from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
    from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
else:
    # Placeholder so the `STATE_DICT_TYPE` return annotation on
    # dcp_to_torch_state_dict still resolves when the real type alias
    # could not be imported.
    STATE_DICT_TYPE = ABC

logger = get_logger(__name__)

# Names of the entries under a checkpoint path that, when present, are loaded
# instead of the path itself: the regular model weights and the EMA weights.
_MODEL_DIR = "model"
_EMA_DIR = "ema"
def ckpt_to_state_dict(
    save_checkpoint_path: Union[str, os.PathLike],
    output_dir: Union[str, os.PathLike],
    ckpt_manager: str = "bytecheckpoint",
    ema: bool = False,
) -> Dict[str, Any]:
    """
    Interface to convert a checkpoint to a state_dict.

    Supported checkpoint managers:
        - bytecheckpoint
        - dcp
        - native

    Args:
        save_checkpoint_path: Path to the checkpoint.
        output_dir: Path to the output directory. Only used by the
            ``bytecheckpoint`` manager, which hands it to its merge tool.
        ckpt_manager: Checkpoint manager that produced the checkpoint.
        ema: Load the EMA weights instead of the regular model weights. Only
            honored by the ``dcp`` manager; the other managers log a warning
            and load the regular weights.

    Returns:
        state_dict: State dict.

    Raises:
        ValueError: If ``ckpt_manager`` is not a supported manager.
    """
    if ckpt_manager not in ("bytecheckpoint", "dcp", "native"):
        raise ValueError(f"Unknown checkpoint manager: {ckpt_manager}")

    # Only the DCP path knows where EMA weights live; make the fallback visible
    # instead of silently returning non-EMA weights.
    if ema and ckpt_manager != "dcp":
        logger.warning(
            f"ema=True is only supported by the 'dcp' checkpoint manager; "
            f"'{ckpt_manager}' will load the regular model weights instead."
        )

    if ckpt_manager == "bytecheckpoint":
        return bytecheckpoint_ckpt_to_state_dict(save_checkpoint_path, output_dir)
    if ckpt_manager == "dcp":
        return dcp_to_torch_state_dict(save_checkpoint_path, ema)

    # native: a torch.save'd checkpoint; prefer the "model" entry under the
    # checkpoint path when it exists.
    model_dir = os.path.join(save_checkpoint_path, _MODEL_DIR)
    if os.path.exists(model_dir):
        save_checkpoint_path = model_dir
    return torch.load(save_checkpoint_path)
def bytecheckpoint_ckpt_to_state_dict(
    save_checkpoint_path: Union[str, os.PathLike], output_dir: Union[str, os.PathLike]
):
    """
    Convert a Bytecheckpoint checkpoint directory into a Torch state_dict.

    The Bytecheckpoint merge tool is run in FSDP, model-only mode and asked to
    return its merged result as a dict; only the ``"model"`` entry is returned.

    Args:
        save_checkpoint_path: Directory containing the Bytecheckpoint checkpoint.
        output_dir: Directory passed to the merge tool as its ``output_path``.

    Returns:
        The ``"model"`` entry of the merged checkpoint dict.
    """
    # Imported inside the function, so the bytecheckpoint package is only
    # needed when this converter actually runs.
    from bytecheckpoint.utilities.ckpt_format.merge_tool import bytecheckpoint_ckpt_to_pytorch_ckpt

    merge_options = {
        "save_path": save_checkpoint_path,
        "output_path": output_dir,
        "framework": "fsdp",
        "model_only": True,
        "return_dict": True,
    }
    merged_checkpoint = bytecheckpoint_ckpt_to_pytorch_ckpt(**merge_options)
    return merged_checkpoint["model"]
def dcp_to_torch_state_dict(save_checkpoint_path: Union[str, os.PathLike], ema: bool = False) -> STATE_DICT_TYPE:
    """
    Given a directory containing a DCP checkpoint, this function will convert it into a
    Torch state_dict.

    Args:
        save_checkpoint_path: Directory containing the DCP checkpoint. If it has an
            ``ema`` entry (when ``ema=True``) or a ``model`` entry (otherwise), that
            entry is loaded instead of the directory itself.
        ema: Load the EMA weights from the ``ema`` entry.

    Returns:
        The ``"model"`` entry of the loaded state dict, after unwrapping a top-level
        ``"state"`` entry if one is present.

    Raises:
        RuntimeError: If the DCP loading helpers were not imported for the installed
            torch version.
        KeyError: If the loaded checkpoint has no ``"model"`` entry.

    .. warning::
        To avoid OOM, it's recommended to only run this function on a single rank.
    """
    # The DCP helpers are only imported at module level under this same check;
    # fail with a clear message instead of a NameError on older torch.
    if not is_torch_version_greater_than("2.4"):
        raise RuntimeError(
            f"dcp_to_torch_state_dict requires torch.distributed.checkpoint loading APIs "
            f"that are not available with torch {torch.__version__}; please upgrade torch."
        )

    sub_dir = _EMA_DIR if ema else _MODEL_DIR
    candidate_dir = os.path.join(save_checkpoint_path, sub_dir)
    if os.path.exists(candidate_dir):
        save_checkpoint_path = candidate_dir
    elif ema:
        # Falling back to the root may load non-EMA weights; surface that.
        logger.warning(
            f"EMA checkpoint directory '{candidate_dir}' not found; "
            f"loading the DCP checkpoint at '{save_checkpoint_path}' instead."
        )

    # Load the state_dict from the DCP checkpoint without a process group.
    state_dict: STATE_DICT_TYPE = {}
    _load_state_dict(
        state_dict,
        storage_reader=FileSystemReader(save_checkpoint_path),
        planner=_EmptyStateDictLoadPlanner(),
        no_dist=True,
    )

    # Unwrap a top-level "state" entry if the checkpoint has one.
    if "state" in state_dict:
        state_dict = state_dict["state"]
    if "model" not in state_dict:
        raise KeyError(
            f"DCP checkpoint at '{save_checkpoint_path}' has no 'model' entry; "
            f"available keys: {list(state_dict)}"
        )
    return state_dict["model"]