|
|
|
|
|
import torch |
|
|
from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, |
|
|
PROMPT_TEMPLATE, SYSTEM_TEMPLATE) |
|
|
|
|
|
import argparse |
|
|
import os.path as osp |
|
|
|
|
|
from mmengine.config import Config, DictAction |
|
|
from mmengine.fileio import PetrelBackend, get_file_backend |
|
|
|
|
|
from xtuner.configs import cfgs_name_path |
|
|
from xtuner.model.utils import guess_load_checkpoint |
|
|
from xtuner.registry import BUILDER |
|
|
|
|
|
# Maps the CLI dtype names accepted by --torch-dtype to the actual
# ``torch.dtype`` objects; 'auto' is passed through so the loader can pick.
TORCH_DTYPE_MAP = {
    'fp16': torch.float16,
    'bf16': torch.bfloat16,
    'fp32': torch.float32,
    'auto': 'auto',
}
|
|
|
|
|
def parse_args():
    """Build and parse the command-line options for this script.

    Returns:
        argparse.Namespace: the parsed arguments.

    NOTE(review): the description string still says 'Chat with a HF model'
    and several options (e.g. --bot-name, --with-plugins, --lagent) look
    like leftovers from a chat script — they are preserved verbatim here
    for CLI compatibility; confirm before pruning.
    """
    arg_parser = argparse.ArgumentParser(description='Chat with a HF model')

    # Positional arguments: the training config and the checkpoint to load.
    arg_parser.add_argument('config', help='config file name or path.')
    arg_parser.add_argument('pth_model', help='pth model file')

    # Where the converted checkpoint is written.
    arg_parser.add_argument(
        '--save-path',
        default='./work_dirs/converted.pth',
        help='save path of converted pth')

    # Model loading dtype.
    arg_parser.add_argument(
        '--torch-dtype',
        default='fp16',
        choices=TORCH_DTYPE_MAP.keys(),
        help='Override the default `torch.dtype` and load the model under '
        'a specific `dtype`.')

    # Prompting options.
    arg_parser.add_argument(
        '--prompt-template',
        choices=PROMPT_TEMPLATE.keys(),
        default='internlm2_chat',
        help='Specify a prompt template')

    # --system and --system-template are mutually exclusive ways of
    # supplying the system text.
    sys_text_group = arg_parser.add_mutually_exclusive_group()
    sys_text_group.add_argument(
        '--system', default=None, help='Specify the system text')
    sys_text_group.add_argument(
        '--system-template',
        choices=SYSTEM_TEMPLATE.keys(),
        default=None,
        help='Specify a system template')

    # Quantization / runtime knobs.
    arg_parser.add_argument(
        '--bits', type=int, choices=[4, 8, None], default=None,
        help='LLM bits')
    arg_parser.add_argument(
        '--bot-name', type=str, default='BOT', help='Name for Bot')
    arg_parser.add_argument(
        '--with-plugins',
        nargs='+',
        choices=['calculate', 'solve', 'search'],
        help='Specify plugins to use')
    arg_parser.add_argument(
        '--no-streamer', action='store_true', help='Whether to with streamer')
    arg_parser.add_argument(
        '--lagent', action='store_true', help='Whether to use lagent')
    arg_parser.add_argument(
        '--stop-words', nargs='+', type=str, default=[], help='Stop words')
    arg_parser.add_argument(
        '--offload-folder',
        default=None,
        help='The folder in which to offload the model weights (or where the '
        'model weights are already offloaded).')

    # Generation hyper-parameters.
    arg_parser.add_argument(
        '--max-new-tokens',
        type=int,
        default=2048,
        help='Maximum number of new tokens allowed in generated text')
    arg_parser.add_argument(
        '--temperature',
        type=float,
        default=0.1,
        help='The value used to modulate the next token probabilities.')
    arg_parser.add_argument(
        '--top-k',
        type=int,
        default=40,
        help='The number of highest probability vocabulary tokens to '
        'keep for top-k-filtering.')
    arg_parser.add_argument(
        '--top-p',
        type=float,
        default=0.75,
        help='If set to float < 1, only the smallest set of most probable '
        'tokens with probabilities that add up to top_p or higher are '
        'kept for generation.')
    arg_parser.add_argument(
        '--repetition-penalty',
        type=float,
        default=1.0,
        help='The parameter for repetition penalty. 1.0 means no penalty.')
    arg_parser.add_argument(
        '--seed',
        type=int,
        default=0,
        help='Random seed for reproducible text generation')

    return arg_parser.parse_args()
|
|
|
|
|
def main():
    """Build the model from the config, merge the given .pth checkpoint
    into it and save the resulting full state dict to ``--save-path``.
    """
    args = parse_args()
    torch.manual_seed(args.seed)

    # Allow passing a bundled config *name* instead of a file path.
    if not osp.isfile(args.config):
        try:
            args.config = cfgs_name_path[args.config]
        except KeyError:
            raise FileNotFoundError(f'Cannot find {args.config}')

    cfg = Config.fromfile(args.config)

    # cfg.model.type may be a class object or a registry string.
    model_name = cfg.model.type if isinstance(cfg.model.type,
                                              str) else cfg.model.type.__name__
    # BUG FIX: the original condition was
    #   `if 'LLaVAModel' or 'OMG' in model_name:`
    # which is always True ('LLaVAModel' is a non-empty string), so
    # pretrained_pth was cleared for *every* model type. Test each
    # substring explicitly.
    if 'LLaVAModel' in model_name or 'OMG' in model_name:
        # Skip loading pretrained weights during build; the checkpoint
        # loaded below supplies the final weights.
        cfg.model.pretrained_pth = None

    model = BUILDER.build(cfg.model)

    # Petrel (ceph) backends need the fileio patch to read the checkpoint.
    backend = get_file_backend(args.pth_model)
    if isinstance(backend, PetrelBackend):
        from xtuner.utils.fileio import patch_fileio
        with patch_fileio():
            state_dict = guess_load_checkpoint(args.pth_model)
    else:
        state_dict = guess_load_checkpoint(args.pth_model)

    # strict=False: the checkpoint may hold only a subset of the model's
    # parameters (e.g. adapter / partially trained weights).
    model.load_state_dict(state_dict, strict=False)
    print(f'Load PTH model from {args.pth_model}')

    # Persist the *full* merged state dict.
    state_dict = model.state_dict()
    torch.save(state_dict, args.save_path)
    print('Save the converted pth to {}'.format(args.save_path))
|
|
|
|
|
# Run the conversion only when executed as a script (not on import).
if __name__ == '__main__':
    main()
|
|
|