# lora_ckp / lora_checkpoints / test_inference.py
# Author: Ray121381 — commit e3e3f87
import os
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
# Interactive inference script: load the base ChatGLM-6B model, attach a trained
# P-Tuning v2 prefix-encoder checkpoint, and run a simple chat REPL.

# Base model directory. Its modeling code is loaded from the checkpoint itself,
# hence trust_remote_code=True.
model_path = "../base-model/chatglm-6b"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# pre_seq_len enables the prefix encoder with 128 prefix tokens.
# NOTE(review): presumably this must match the value used during training — confirm.
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True, pre_seq_len=128)
# Custom prefix-expert settings read by the remote modeling code.
# NOTE(review): semantics are defined in that code, not here; presumably
# prefix_cur_expert = -1 means "blend all experts using expert_weights", and
# len(expert_weights) should equal prefix_n_experts — confirm.
config.prefix_n_experts = 2
config.prefix_cur_expert = -1
config.expert_weights = [0.4, 0.6]
model = AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True)

# Point this at your own p-tuning output directory.
# map_location="cpu": the model is still on CPU here, so don't require (or
# spend memory on) the GPU the checkpoint tensors were saved from.
prefix_state_dict = torch.load(
    os.path.join("./checkpoints/checkpoint-3000", "pytorch_model.bin"),
    map_location="cpu",
)
PREFIX_ENCODER_KEY = "transformer.prefix_encoder."
# Keep only the prefix-encoder weights, renamed relative to that submodule.
# Keys with any other prefix are skipped instead of being sliced into
# meaningless names that load_state_dict would reject.
new_prefix_state_dict = {
    k[len(PREFIX_ENCODER_KEY):]: v
    for k, v in prefix_state_dict.items()
    if k.startswith(PREFIX_ENCODER_KEY)
}
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

# Bit width passed to model.quantize(); set to None to skip quantization
# (it can be skipped on V100 machines).
QUANTIZATION_BIT = 4
if QUANTIZATION_BIT is not None:
    model = model.quantize(QUANTIZATION_BIT)
model = model.half().cuda()
# Keep the prefix encoder in fp32 while the rest of the model runs in fp16.
model.transformer.prefix_encoder.float()
model = model.eval()

# Chat REPL; the conversation history returned by model.chat is fed back in.
history = []
while True:
    try:
        user_prompt = input(">>>>>>>> USER: ")
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / Ctrl-C ends the session cleanly instead of a traceback.
        print()
        break
    response, history = model.chat(tokenizer, user_prompt, history=history)
    print(f">>>>>>>> Assistant: {response}")
    print("\n")