File size: 1,468 Bytes
e3e3f87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import os
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig

# Interactive multi-turn chat with ChatGLM-6B plus a P-Tuning prefix encoder
# configured with custom "prefix expert" settings.

model_path = "../base-model/chatglm-6b"
# NOTE: trust_remote_code=True runs the modeling code shipped inside model_path;
# only point this at a model directory you trust.
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# pre_seq_len presumably has to match the value used for P-Tuning training —
# confirm against the training script.
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True, pre_seq_len=128)
# Custom, non-standard config keys read by the remote modeling code.
# NOTE(review): presumably number of prefix experts, the active expert
# (-1 = blend all?), and per-expert blend weights — verify against the modeling code.
config.prefix_n_experts = 2
config.prefix_cur_expert = -1
config.expert_weights = [0.4, 0.6]

model = AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True)

# Point this at your P-Tuning output (checkpoint) directory.
prefix_checkpoint = os.path.join("./checkpoints/checkpoint-3000", "pytorch_model.bin")
prefix_key = "transformer.prefix_encoder."

# The model has not been moved to the GPU yet, so load the checkpoint onto the CPU
# as well; this avoids an extra GPU copy and works for GPU-saved checkpoints.
prefix_state_dict = torch.load(prefix_checkpoint, map_location="cpu")
# Keep only the prefix-encoder weights and strip their module prefix. Slicing every
# key unconditionally would mangle unrelated entries and make the strict
# load_state_dict below fail with unexpected keys.
new_prefix_state_dict = {
    k[len(prefix_key):]: v
    for k, v in prefix_state_dict.items()
    if k.startswith(prefix_key)
}
if not new_prefix_state_dict:
    raise RuntimeError(
        f"No '{prefix_key}*' weights found in {prefix_checkpoint}; "
        "is this a P-Tuning checkpoint?"
    )
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

# Quantize to 4 bits; on V100 machines quantization can be skipped.
model = model.quantize(4)
model = model.half().cuda()
# Cast the prefix encoder back to fp32 after the fp16 cast of the whole model.
model.transformer.prefix_encoder.float()
model = model.eval()

# Chat loop: the history returned by model.chat is passed back on the next turn.
# Ctrl-D / Ctrl-C at the prompt ends the session instead of raising a traceback.
history = []

while True:
    try:
        user_prompt = input(">>>>>>>> USER:  ")
    except (EOFError, KeyboardInterrupt):
        print()
        break
    response, history = model.chat(tokenizer, user_prompt, history=history)
    print(f">>>>>>>> Assistant:  {response}")
    print("\n")