import os
import json
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel, AutoConfig
# Inference script: run the quantized ChatGLM-6B chat model over the QiaoBan
# test set and write one JSON object per line with a "prediction" field added.

model_name = "chatglm"
model_path = "../../../base-model/chatglm-6b"

# trust_remote_code is required: ChatGLM ships its custom modeling code
# alongside the checkpoint rather than in the transformers library.
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True)
model = model.quantize(4)  # 4-bit weight quantization to cut GPU memory use
model = model.half().cuda()
model = model.eval()

input_fname = "../QiaoBan/test_data.json"  # no placeholders -> plain string, not f-string
output_fname = f"./{model_name}/result.json"
# Fix: the output directory may not exist yet; open(..., "w") would raise
# FileNotFoundError on a fresh checkout.
os.makedirs(os.path.dirname(output_fname), exist_ok=True)

with open(input_fname, "r", encoding="utf-8") as fr, open(output_fname, "w", encoding="utf-8") as fw:
    # Stream the input line by line instead of materializing it with readlines().
    for line in tqdm(fr, ncols=80):
        sample = json.loads(line.strip())
        prompt = sample["prompt"]
        history = sample["history"]
        # NOTE(review): assumes `history` is a list of (query, response) pairs
        # as ChatGLM's chat() API expects — confirm against the data file.
        response, _ = model.chat(tokenizer, prompt, history=history)
        sample["prediction"] = response
        fw.write(json.dumps(sample, ensure_ascii=False))
        fw.write("\n")