# lora_checkpoints/data/baselines/inference_model.py
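"""Baseline inference for ChatGLM-6B on the QiaoBan test set.

Reads one JSON sample per line from the test file, generates a response with
the model's chat() interface, and writes each sample back out as JSON Lines
with an added "prediction" field.
"""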
import os
import json
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel, AutoConfig
model_name = "chatglm"
model_path = "../../../base-model/chatglm-6b"

# Load the tokenizer and model; trust_remote_code is required because
# ChatGLM-6B ships its own modeling code alongside the checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True)

# quantize(4) (provided by the ChatGLM remote code) compresses the weights to
# INT4 to cut GPU memory; the rest runs in FP16 on the GPU.
model = model.quantize(4)
model = model.half().cuda()
model = model.eval()
input_fname = "../QiaoBan/test_data.json"
output_fname = f"./{model_name}/result.json"

# Make sure the output directory exists before opening the result file.
os.makedirs(model_name, exist_ok=True)
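# Assumed input format, inferred from the field accesses below: each line of
# test_data.json is a JSON object along the lines of
#   {"prompt": "...", "history": [["user turn", "model turn"], ...]}
# where "history" follows ChatGLM's (query, response) pair convention.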
with open(input_fname, "r", encoding="utf-8") as fr, open(output_fname, "w", encoding="utf-8") as fw:
    for line in tqdm(fr.readlines(), ncols=80):
        sample = json.loads(line.strip())
        prompt = sample["prompt"]
        history = sample["history"]
        # chat() returns (response, updated_history); only the response
        # is kept as this sample's prediction.
        response, _ = model.chat(tokenizer, prompt, history=history)
        sample["prediction"] = response
        # Write the augmented sample back out as one JSON object per line.
        fw.write(json.dumps(sample, ensure_ascii=False))
        fw.write("\n")
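
# Quick sanity check of the output (a sketch, assuming the paths above):
#   with open(f"./{model_name}/result.json", encoding="utf-8") as f:
#       preds = [json.loads(l) for l in f]
#   print(len(preds), preds[0]["prediction"])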