import json
import os

from tqdm import tqdm
from transformers import AutoConfig, AutoModel, AutoTokenizer

model_name = "chatglm"
model_path = "../../../base-model/chatglm-6b"

# Load the ChatGLM-6B tokenizer and model; trust_remote_code is required
# because the checkpoint ships its own modeling code.
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True)

# Quantize to INT4, cast to fp16, move to the GPU, and switch to eval mode.
model = model.quantize(4)
model = model.half().cuda()
model = model.eval()

input_fname = "../QiaoBan/test_data.json"  # JSON Lines: one sample per line
output_fname = f"./{model_name}/result.json"
os.makedirs(os.path.dirname(output_fname), exist_ok=True)

with open(input_fname, "r", encoding="utf-8") as fr, \
        open(output_fname, "w", encoding="utf-8") as fw:
    # readlines() materializes the file so tqdm can show a total.
    for line in tqdm(fr.readlines(), ncols=80):
        sample = json.loads(line.strip())
        prompt = sample["prompt"]
        history = sample["history"]
        # chat() returns (response, updated_history); the updated history
        # is not needed because each sample carries its own history.
        response, _ = model.chat(tokenizer, prompt, history=history)
        sample["prediction"] = response
        fw.write(json.dumps(sample, ensure_ascii=False))
        fw.write("\n")
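
# A minimal sketch of an input line this script expects (hypothetical content;
# the real QiaoBan test set defines the actual values), assuming ChatGLM's
# chat() convention of history as a list of [query, response] pairs:
#
#   {"prompt": "current user turn", "history": [["earlier user turn", "earlier model reply"]]}
#
# The corresponding output line is the same object with an added
# "prediction" field holding the model's generated reply.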