| import os |
| import json |
| import torch |
| from tqdm import tqdm |
| from transformers import AutoTokenizer, AutoModel, AutoConfig |
|
|
# Short label for this model; used below to build the per-model output directory.
model_name = "chatglm"

# Local checkpoint directory. This is a relative path, so it is resolved
# against the current working directory the script is launched from.
model_path = "../../../base-model/chatglm-6b"

# ChatGLM ships its own modelling/tokenizer code with the checkpoint,
# which is why trust_remote_code=True is required for every Auto* loader.
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True)

# Cast to fp16 *before* quantizing, in the order documented by ChatGLM-6B
# (`.half().quantize(4).cuda()`). Halving first means the INT4 quantization
# runs on 16-bit float weights rather than on the 32-bit copy produced by
# from_pretrained, and only then is the quantized model moved to the GPU.
# Assumes the remote-code `quantize` returns the model itself, as the
# original chained assignment already did.
model = model.half().quantize(4).cuda()

# Switch to inference mode (e.g. disables dropout) for deterministic evaluation.
model = model.eval()
|
|
# Evaluation input (JSON Lines) and where predictions for this model are written.
input_fname = "../QiaoBan/test_data.json"
output_fname = f"./{model_name}/result.json"

# open(..., "w") does not create parent directories, so make sure the
# per-model output directory exists before writing into it.
output_dir = os.path.dirname(output_fname)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# Each input line is one JSON object with a "prompt" key and, optionally, a
# "history" key that is passed through to model.chat unchanged. Every object is
# written back out with the model's answer added under "prediction".
with open(input_fname, "r", encoding="utf-8") as fr, open(output_fname, "w", encoding="utf-8") as fw:
    # readlines() is kept so tqdm knows the total and can show a real progress bar.
    for line in tqdm(fr.readlines(), ncols=80):
        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. an extra trailing newline) instead of
            # crashing in json.loads("").
            continue
        sample = json.loads(line)
        prompt = sample["prompt"]
        # A missing history is treated as a fresh, single-turn conversation.
        history = sample.get("history", [])
        response, _ = model.chat(tokenizer, prompt, history=history)
        sample["prediction"] = response
        fw.write(json.dumps(sample, ensure_ascii=False) + "\n")
        # Flush per record: generation is slow, so this cost is negligible, and
        # finished predictions survive if a long GPU run is killed midway.
        fw.flush()
|
|