File size: 1,042 Bytes
e3e3f87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import os
import json
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel, AutoConfig

model_name = "chatglm"  # also used as the output sub-directory name for predictions

model_path = "../../../base-model/chatglm-6b"

# trust_remote_code=True lets transformers execute the modeling/tokenizer code
# bundled with the checkpoint. The ChatGLM-specific methods used in this script
# (`quantize`, `chat`) come from that code, not from the core transformers API.
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

# Build the inference model as a single pipeline; every step returns the model:
#   quantize(4) -> checkpoint-provided method; presumably 4-bit weight
#                  quantization (confirm against the bundled modeling code)
#   half()      -> cast floating-point parameters/buffers to fp16
#   cuda()      -> move the module to the current CUDA device
#   eval()      -> inference mode (disables dropout and similar layers)
model = (
    AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True)
    .quantize(4)
    .half()
    .cuda()
    .eval()
)

input_fname = "../QiaoBan/test_data.json"
output_fname = f"./{model_name}/result.json"

# The per-model output directory is not guaranteed to exist; without this the
# open(..., "w") below raises FileNotFoundError before any prediction is made.
os.makedirs(os.path.dirname(output_fname) or ".", exist_ok=True)

# The input is read as JSON Lines: each non-blank line is one JSON object that
# must carry "prompt" and "history" keys. Each object is written back to the
# output file unchanged except for an added "prediction" field.
with open(input_fname, "r", encoding="utf-8") as fr, open(output_fname, "w", encoding="utf-8") as fw:
    for line in tqdm(fr.readlines(), ncols=80):
        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. extra newlines at the end of the file)
            # instead of crashing json.loads on an empty string.
            continue
        sample = json.loads(line)
        prompt = sample["prompt"]
        history = sample["history"]
        # chat() returns (reply, updated_history). The updated history is
        # discarded: each sample is answered from its own stored history only.
        response, _ = model.chat(tokenizer, prompt, history=history)
        sample["prediction"] = response
        fw.write(json.dumps(sample, ensure_ascii=False) + "\n")