"""Interactive chat loop for ChatGLM-6B with a P-Tuning prefix-encoder checkpoint.

Loads the base model from ``model_path``, loads the prefix-encoder weights
from ``checkpoint_path``, quantizes to 4 bit, moves the model to the GPU
in half precision and then chats with the user on stdin/stdout until EOF
(Ctrl-D) or Ctrl-C.
"""
import os

import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer

model_path = "../base-model/chatglm-6b"
# Use your own P-Tuning working directory / checkpoint here.
checkpoint_path = "./checkpoints/checkpoint-3000"
# Prefix under which the prefix-encoder weights are stored in the checkpoint.
prefix_key = "transformer.prefix_encoder."

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# NOTE(review): pre_seq_len presumably has to match the value used during
# P-Tuning training, otherwise the prefix-encoder shapes will not line up.
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True, pre_seq_len=128)
# NOTE(review): the three attributes below are custom settings, presumably
# read by a modified remote modeling file (loaded via trust_remote_code).
# Presumably -1 means "mix all experts using expert_weights" — confirm
# against that modeling code.
config.prefix_n_experts = 2
config.prefix_cur_expert = -1
config.expert_weights = [0.4, 0.6]

model = AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True)

# Load onto CPU first so a checkpoint saved from GPU tensors does not
# allocate GPU memory before the model itself is moved there.
# torch.load unpickles the file: only load checkpoints you trust.
prefix_state_dict = torch.load(
    os.path.join(checkpoint_path, "pytorch_model.bin"), map_location="cpu"
)
# Keep only the prefix-encoder weights and strip their prefix. Other keys in
# the checkpoint would otherwise be mangled into bogus names and make
# load_state_dict fail.
new_prefix_state_dict = {
    k[len(prefix_key):]: v
    for k, v in prefix_state_dict.items()
    if k.startswith(prefix_key)
}
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

# On V100 machines quantization can be skipped (comment out the next line).
model = model.quantize(4)
model = model.half().cuda()
# The prefix encoder is kept in float32 while the rest of the model is half.
model.transformer.prefix_encoder.float()
model = model.eval()

history = []
while True:
    try:
        user_prompt = input(">>>>>>>> USER: ")
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / Ctrl-C: leave the chat loop cleanly instead of a traceback.
        print()
        break
    if not user_prompt.strip():
        # Don't send empty prompts to the model.
        continue
    response, history = model.chat(tokenizer, user_prompt, history=history)
    print(f">>>>>>>> Assistant: {response}")
    print("\n")