# Copyright (c) OpenMMLab. All rights reserved.
"""Convert a trained xtuner checkpoint (.pth) into a plain full-model state dict.

Builds the model from a config, loads the (possibly partial, e.g. adapter-only)
checkpoint on top of it, and saves the merged ``model.state_dict()`` to disk.
"""
import argparse
import os
import os.path as osp

import torch
from mmengine.config import Config, DictAction
from mmengine.fileio import PetrelBackend, get_file_backend

from xtuner.configs import cfgs_name_path
from xtuner.model.utils import guess_load_checkpoint
from xtuner.registry import BUILDER
from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
                          PROMPT_TEMPLATE, SYSTEM_TEMPLATE)

# Maps the CLI ``--torch-dtype`` choice to the actual torch dtype
# ('auto' is passed through and lets the loader decide).
TORCH_DTYPE_MAP = dict(
    fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')


def parse_args():
    """Parse command-line arguments for the conversion script.

    Returns:
        argparse.Namespace: parsed arguments. Many of the chat-related
        options (streamer, plugins, sampling params) are accepted for CLI
        compatibility but are not used by :func:`main`.
    """
    parser = argparse.ArgumentParser(description='Chat with a HF model')
    parser.add_argument('config', help='config file name or path.')
    parser.add_argument('pth_model', help='pth model file')
    parser.add_argument(
        '--save-path',
        default='./work_dirs/converted.pth',
        help='save path of converted pth')
    parser.add_argument(
        '--torch-dtype',
        default='fp16',
        choices=TORCH_DTYPE_MAP.keys(),
        help='Override the default `torch.dtype` and load the model under '
        'a specific `dtype`.')
    parser.add_argument(
        '--prompt-template',
        choices=PROMPT_TEMPLATE.keys(),
        default='internlm2_chat',
        help='Specify a prompt template')
    # --system and --system-template are mutually exclusive: a literal system
    # prompt or a named template, but not both.
    system_group = parser.add_mutually_exclusive_group()
    system_group.add_argument(
        '--system', default=None, help='Specify the system text')
    system_group.add_argument(
        '--system-template',
        choices=SYSTEM_TEMPLATE.keys(),
        default=None,
        help='Specify a system template')
    parser.add_argument(
        '--bits',
        type=int,
        choices=[4, 8, None],
        default=None,
        help='LLM bits')
    parser.add_argument(
        '--bot-name', type=str, default='BOT', help='Name for Bot')
    parser.add_argument(
        '--with-plugins',
        nargs='+',
        choices=['calculate', 'solve', 'search'],
        help='Specify plugins to use')
    parser.add_argument(
        '--no-streamer', action='store_true', help='Whether to with streamer')
    parser.add_argument(
        '--lagent', action='store_true', help='Whether to use lagent')
    parser.add_argument(
        '--stop-words', nargs='+', type=str, default=[], help='Stop words')
    parser.add_argument(
        '--offload-folder',
        default=None,
        help='The folder in which to offload the model weights (or where the '
        'model weights are already offloaded).')
    parser.add_argument(
        '--max-new-tokens',
        type=int,
        default=2048,
        help='Maximum number of new tokens allowed in generated text')
    parser.add_argument(
        '--temperature',
        type=float,
        default=0.1,
        help='The value used to modulate the next token probabilities.')
    parser.add_argument(
        '--top-k',
        type=int,
        default=40,
        help='The number of highest probability vocabulary tokens to '
        'keep for top-k-filtering.')
    parser.add_argument(
        '--top-p',
        type=float,
        default=0.75,
        help='If set to float < 1, only the smallest set of most probable '
        'tokens with probabilities that add up to top_p or higher are '
        'kept for generation.')
    parser.add_argument(
        '--repetition-penalty',
        type=float,
        default=1.0,
        help='The parameter for repetition penalty. 1.0 means no penalty.')
    parser.add_argument(
        '--seed',
        type=int,
        default=0,
        help='Random seed for reproducible text generation')
    args = parser.parse_args()
    return args


def main():
    """Build the model from config, load the checkpoint and save a full pth."""
    args = parse_args()
    torch.manual_seed(args.seed)

    # Resolve the config: either a path on disk or a registered config name.
    if not osp.isfile(args.config):
        try:
            args.config = cfgs_name_path[args.config]
        except KeyError:
            raise FileNotFoundError(f'Cannot find {args.config}')

    # load config
    cfg = Config.fromfile(args.config)

    model_name = cfg.model.type if isinstance(cfg.model.type,
                                              str) else cfg.model.type.__name__
    # BUG FIX: the original condition was `'LLaVAModel' or 'OMG' in model_name`,
    # which is always True because the non-empty string 'LLaVAModel' is truthy.
    # Test each substring membership explicitly.
    if 'LLaVAModel' in model_name or 'OMG' in model_name:
        # Skip loading the pretrained weights inside BUILDER.build; the real
        # weights come from `args.pth_model` below.
        cfg.model.pretrained_pth = None

    model = BUILDER.build(cfg.model)

    # Petrel (S3-like) backends need patched fileio to read the checkpoint.
    backend = get_file_backend(args.pth_model)
    if isinstance(backend, PetrelBackend):
        from xtuner.utils.fileio import patch_fileio
        with patch_fileio():
            state_dict = guess_load_checkpoint(args.pth_model)
    else:
        state_dict = guess_load_checkpoint(args.pth_model)

    # strict=False: the checkpoint may contain only a subset of the weights
    # (e.g. LoRA adapters) — missing keys keep their freshly-built values.
    model.load_state_dict(state_dict, strict=False)
    print(f'Load PTH model from {args.pth_model}')

    # Ensure the destination directory exists before saving.
    save_dir = osp.dirname(args.save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), args.save_path)
    print(f'Save the converted pth to {args.save_path}')


if __name__ == '__main__':
    main()