| import argparse |
| import os |
|
|
| from transformers import AutoConfig, AutoProcessor |
|
|
| from lingbotvla.checkpoint import bytecheckpoint_ckpt_to_state_dict |
| from lingbotvla.models import save_model_weights |
| from lingbotvla.utils import helper |
|
|
|
|
| logger = helper.create_logger(__name__) |
|
|
|
|
def merge_to_hf_pt(load_dir: str, save_path: str, model_assets_dir: str = None):
    """Turn the checkpoint under ``load_dir`` into a state dict and save it.

    The state dict comes from ``bytecheckpoint_ckpt_to_state_dict`` (called
    with ``load_dir`` as the checkpoint path and ``save_path`` as its output
    directory) and is then handed to ``save_model_weights``.

    Args:
        load_dir: Directory holding the checkpoint to convert.
        save_path: Destination passed to both the converter and the saver.
        model_assets_dir: Optional directory from which a config and a
            processor are loaded with ``AutoConfig`` / ``AutoProcessor`` and
            forwarded to ``save_model_weights`` as ``model_assets``. When
            ``None``, only the weights are passed on.
    """
    weights = bytecheckpoint_ckpt_to_state_dict(save_checkpoint_path=load_dir, output_dir=save_path)

    # No asset directory given: save the weights alone and stop here.
    if model_assets_dir is None:
        save_model_weights(save_path, weights)
        return

    hf_config = AutoConfig.from_pretrained(model_assets_dir)
    # Only the processor is loaded with trust_remote_code enabled.
    hf_processor = AutoProcessor.from_pretrained(model_assets_dir, trust_remote_code=True)
    save_model_weights(save_path, weights, model_assets=[hf_config, hf_processor])
|
|
|
|
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the merge script.

    Options:
        --load-dir: Required checkpoint directory to read from.
        --save-dir: Optional output directory (default ``None``; the script
            entry point then falls back to ``<load-dir>/hf_ckpt``).
        --model-assets-dir / --model_assets_dir: Optional directory with
            model assets. The hyphenated spelling matches the other flags;
            the underscore spelling is kept so existing invocations still
            work. Both store into ``args.model_assets_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--load-dir", type=str, required=True, help="Checkpoint directory to merge.")
    parser.add_argument(
        "--save-dir",
        type=str,
        default=None,
        help="Output directory (default: <load-dir>/hf_ckpt).",
    )
    parser.add_argument(
        "--model-assets-dir",
        "--model_assets_dir",
        dest="model_assets_dir",
        type=str,
        default=None,
        help="Directory to load the model config and processor from.",
    )
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    load_dir = args.load_dir
    # Default the output location to a sub-directory of the checkpoint directory.
    save_dir = os.path.join(load_dir, "hf_ckpt") if args.save_dir is None else args.save_dir
    model_assets_dir = args.model_assets_dir
    logger.info(f"Merge Args: {args}")
    merge_to_hf_pt(load_dir, save_dir, model_assets_dir)
    logger.info(f"Merge to hf pt success! Save to: {save_dir}")
|
|