| import argparse | |
| import torch | |
| from internvl.model.internvl_chat import InternVLChatModel | |
| from transformers import AutoModel, AutoTokenizer | |
| argparse = argparse.ArgumentParser() | |
| argparse.add_argument('model_path', type=str, default='') | |
| argparse.add_argument('llm_path', type=str, default='') | |
| args = argparse.parse_args() | |
| if args.model_path[-1] == '/': | |
| args.model_path = args.model_path[:-1] | |
| model = InternVLChatModel.from_pretrained(args.model_path, torch_dtype=torch.bfloat16) | |
| llm = AutoModel.from_pretrained( | |
| args.llm_path, trust_remote_code=True, torch_dtype=torch.bfloat16) | |
| tokenizer = AutoTokenizer.from_pretrained( | |
| args.llm_path, trust_remote_code=True) | |
| model.language_model = llm | |
| model.config.llm_config = llm.config | |
| model.to(torch.bfloat16) | |
| output_path = args.model_path + '_replace_llm' | |
| model.save_pretrained(output_path) | |
| tokenizer.save_pretrained(output_path) | |
| print('finished') | |