import argparse import os from transformers import AutoConfig, AutoProcessor from lingbotvla.checkpoint import bytecheckpoint_ckpt_to_state_dict from lingbotvla.models import save_model_weights from lingbotvla.utils import helper logger = helper.create_logger(__name__) def merge_to_hf_pt(load_dir: str, save_path: str, model_assets_dir: str = None): # save model in huggingface's format state_dict = bytecheckpoint_ckpt_to_state_dict( save_checkpoint_path=load_dir, output_dir=save_path, ) if model_assets_dir is not None: config = AutoConfig.from_pretrained(model_assets_dir) processor = AutoProcessor.from_pretrained(model_assets_dir, trust_remote_code=True) save_model_weights(save_path, state_dict, model_assets=[config, processor]) else: save_model_weights(save_path, state_dict) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--load-dir", type=str, required=True) parser.add_argument("--save-dir", type=str, default=None) parser.add_argument("--model_assets_dir", type=str, default=None) args = parser.parse_args() load_dir = args.load_dir save_dir = os.path.join(load_dir, "hf_ckpt") if args.save_dir is None else args.save_dir model_assets_dir = args.model_assets_dir logger.info(f"Merge Args: {args}") merge_to_hf_pt(load_dir, save_dir, model_assets_dir) logger.info(f"Merge to hf pt success! Save to: {save_dir}")