# Source: DenseLabelDev/projects/omg_llava/tools/convert_deepspeed2pth.py
# (uploaded by zhouyik via huggingface_hub, commit 032e687, verified)
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
PROMPT_TEMPLATE, SYSTEM_TEMPLATE)
import argparse
import os.path as osp
from mmengine.config import Config, DictAction
from mmengine.fileio import PetrelBackend, get_file_backend
from xtuner.configs import cfgs_name_path
from xtuner.model.utils import guess_load_checkpoint
from xtuner.registry import BUILDER
# Lookup table for the ``--torch-dtype`` CLI choices (its keys are the
# allowed values). 'auto' maps to the literal string 'auto' rather than a
# torch dtype.
TORCH_DTYPE_MAP = {
    'fp16': torch.float16,
    'bf16': torch.bfloat16,
    'fp32': torch.float32,
    'auto': 'auto',
}
def parse_args():
    """Parse command-line arguments for the checkpoint conversion script.

    Only ``config``, ``pth_model``, ``--save-path`` and ``--seed`` are read by
    ``main()`` in this file. The remaining options (dtype, prompt template,
    generation settings, plugins, ...) appear to be carried over from the
    xtuner chat script; they are still accepted so existing command lines keep
    working, but they do not affect the conversion.

    Returns:
        argparse.Namespace: the parsed command-line arguments.
    """
    parser = argparse.ArgumentParser(
        description='Convert a DeepSpeed / pth training checkpoint into a '
        'single pth state dict')
    parser.add_argument('config', help='config file name or path.')
    parser.add_argument('pth_model', help='pth model file')
    parser.add_argument(
        '--save-path',
        default='./work_dirs/converted.pth',
        help='save path of converted pth')
    parser.add_argument(
        '--torch-dtype',
        default='fp16',
        choices=TORCH_DTYPE_MAP.keys(),
        help='Override the default `torch.dtype` and load the model under '
        'a specific `dtype`.')
    parser.add_argument(
        '--prompt-template',
        choices=PROMPT_TEMPLATE.keys(),
        default="internlm2_chat",
        help='Specify a prompt template')
    # --system and --system-template cannot be given together.
    system_group = parser.add_mutually_exclusive_group()
    system_group.add_argument(
        '--system', default=None, help='Specify the system text')
    system_group.add_argument(
        '--system-template',
        choices=SYSTEM_TEMPLATE.keys(),
        default=None,
        help='Specify a system template')
    parser.add_argument(
        '--bits',
        type=int,
        choices=[4, 8, None],
        default=None,
        help='LLM bits')
    parser.add_argument(
        '--bot-name', type=str, default='BOT', help='Name for Bot')
    parser.add_argument(
        '--with-plugins',
        nargs='+',
        choices=['calculate', 'solve', 'search'],
        help='Specify plugins to use')
    parser.add_argument(
        '--no-streamer',
        action='store_true',
        help='Whether to disable the streamer')
    parser.add_argument(
        '--lagent', action='store_true', help='Whether to use lagent')
    parser.add_argument(
        '--stop-words', nargs='+', type=str, default=[], help='Stop words')
    parser.add_argument(
        '--offload-folder',
        default=None,
        help='The folder in which to offload the model weights (or where the '
        'model weights are already offloaded).')
    parser.add_argument(
        '--max-new-tokens',
        type=int,
        default=2048,
        help='Maximum number of new tokens allowed in generated text')
    parser.add_argument(
        '--temperature',
        type=float,
        default=0.1,
        help='The value used to modulate the next token probabilities.')
    parser.add_argument(
        '--top-k',
        type=int,
        default=40,
        help='The number of highest probability vocabulary tokens to '
        'keep for top-k-filtering.')
    parser.add_argument(
        '--top-p',
        type=float,
        default=0.75,
        help='If set to float < 1, only the smallest set of most probable '
        'tokens with probabilities that add up to top_p or higher are '
        'kept for generation.')
    parser.add_argument(
        '--repetition-penalty',
        type=float,
        default=1.0,
        help='The parameter for repetition penalty. 1.0 means no penalty.')
    parser.add_argument(
        '--seed',
        type=int,
        default=0,
        help='Random seed for reproducible text generation')
    return parser.parse_args()
def main():
    """Load a training checkpoint into the configured model and re-save it.

    Steps:
      1. Resolve ``args.config``, which is either an existing file or a name
         looked up in ``cfgs_name_path``, and load it with ``Config.fromfile``.
      2. For models whose type name contains 'LLaVAModel' or 'OMG', set
         ``cfg.model.pretrained_pth = None`` before building. NOTE(review):
         presumably this stops the model from loading separate pretrained
         weights at build time; confirm against the model class.
      3. Build the model with ``BUILDER`` and read ``args.pth_model`` with
         ``guess_load_checkpoint``. The read goes through ``patch_fileio``
         when the path resolves to a ``PetrelBackend``.
      4. Apply the weights with ``strict=False`` and report mismatched keys.
      5. Save ``model.state_dict()`` to ``args.save_path``, creating the
         parent directory if needed.

    Raises:
        FileNotFoundError: if ``args.config`` is neither an existing file nor
            a known config name.
    """
    import os  # local import: only needed to create the output directory

    args = parse_args()
    torch.manual_seed(args.seed)

    # Accept either a config path on disk or a registered config name.
    if not osp.isfile(args.config):
        try:
            args.config = cfgs_name_path[args.config]
        except KeyError:
            raise FileNotFoundError(f'Cannot find {args.config}') from None
    cfg = Config.fromfile(args.config)

    model_type = cfg.model.type
    model_name = model_type if isinstance(model_type, str) \
        else model_type.__name__
    # Each substring must be tested on its own: `'LLaVAModel' or ...` would
    # always be truthy and touch every model config.
    if 'LLaVAModel' in model_name or 'OMG' in model_name:
        cfg.model.pretrained_pth = None

    model = BUILDER.build(cfg.model)

    backend = get_file_backend(args.pth_model)
    if isinstance(backend, PetrelBackend):
        from xtuner.utils.fileio import patch_fileio
        with patch_fileio():
            state_dict = guess_load_checkpoint(args.pth_model)
    else:
        state_dict = guess_load_checkpoint(args.pth_model)

    # strict=False tolerates key mismatches; surface them instead of hiding
    # them. getattr guards against a subclass returning something other than
    # torch's (missing_keys, unexpected_keys) result.
    load_result = model.load_state_dict(state_dict, strict=False)
    print(f'Load PTH model from {args.pth_model}')
    missing = getattr(load_result, 'missing_keys', None) or []
    unexpected = getattr(load_result, 'unexpected_keys', None) or []
    if missing:
        print(f'Warning: {len(missing)} model keys were not found in the '
              'checkpoint')
    if unexpected:
        print(f'Warning: {len(unexpected)} checkpoint keys were not used by '
              'the model')

    # NOTE(review): --torch-dtype is parsed but never applied here; tensors
    # are saved with whatever dtypes model.state_dict() returns.
    save_dir = osp.dirname(args.save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), args.save_path)
    print('Save the converted pth to {}'.format(args.save_path))
if __name__ == '__main__':
    # Run the conversion only when executed as a script, not on import.
    main()