File size: 1,748 Bytes
e3e3f87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import json
import jieba
from rouge_chinese import Rouge
import numpy as np
from evaluate import load
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def compute_metrics(eval_preds):
    """Score predictions against references with ROUGE-1/2/L and BLEU-1..4.

    Every (prediction, reference) pair is scored on its own. Each per-sample
    score is scaled to 0-100 and rounded to 4 decimals; the returned value for
    each metric is the mean over all pairs.

    ROUGE F-scores are computed on jieba word segmentation (tokens joined by
    single spaces before being handed to ``Rouge.get_scores``). BLEU is computed
    on character sequences (``list(text)``) with NLTK ``sentence_bleu`` and
    smoothing method3; BLEU-n uses uniform weights over 1..n-grams.

    Args:
        eval_preds: A 2-item sequence ``(preds, labels)`` of prediction strings
            and reference strings. Pairs are formed with ``zip``, so surplus
            items in the longer sequence are ignored.

    Returns:
        dict mapping ``"rouge-1"``, ``"rouge-2"``, ``"rouge-l"`` and
        ``"bleu-1"`` .. ``"bleu-4"`` to the mean score as a float. Every value
        is ``nan`` when no pairs were scored.
    """
    preds, labels = eval_preds

    rouge_keys = ("rouge-1", "rouge-2", "rouge-l")
    score_dict = {
        "rouge-1": [],
        "rouge-2": [],
        "rouge-l": [],
        "bleu-1": [],
        "bleu-2": [],
        "bleu-3": [],
        "bleu-4": [],
    }

    # Loop-invariant objects: build them once instead of once per sample.
    rouge = Rouge()
    smoothing = SmoothingFunction().method3
    # BLEU-n: uniform weight over 1..n-grams, zero weight on higher orders.
    bleu_weights = [
        [1.0, 0.0, 0.0, 0.0],
        [0.5, 0.5, 0.0, 0.0],
        [1.0 / 3] * 3 + [0.0],
        [0.25] * 4,
    ]

    for pred, label in zip(preds, labels):
        hypothesis = " ".join(jieba.cut(pred))
        reference = " ".join(jieba.cut(label))
        if hypothesis.split() and reference.split():
            result = rouge.get_scores(hypothesis, reference)[0]
        else:
            # Rouge.get_scores is known to raise ValueError on empty input,
            # which would abort the whole run; treat an empty side as having
            # no overlap at all.
            result = {key: {"f": 0.0} for key in rouge_keys}

        for key in rouge_keys:
            score_dict[key].append(round(result[key]["f"] * 100, 4))

        # BLEU is deliberately character-level, not jieba-token-level.
        pred_chars = list(pred)
        label_chars = list(label)
        for n, weights in enumerate(bleu_weights, start=1):
            bleu_score = sentence_bleu(
                [label_chars],
                pred_chars,
                weights=weights,
                smoothing_function=smoothing,
            )
            score_dict[f"bleu-{n}"].append(round(bleu_score * 100, 4))

    # np.mean([]) warns and yields nan; produce nan explicitly for no samples.
    return {
        key: float(np.mean(values)) if values else float("nan")
        for key, values in score_dict.items()
    }

# Which model's outputs to evaluate; set to "chatglm" to score ChatGLM instead.
model_name = "chatgpt"
fname = f"./{model_name}/result.json"

# result.json holds one JSON object per line, with the reference text under
# "response" and the model output under "prediction".
references, predictions = [], []
# Explicit UTF-8: the samples are Chinese text, and relying on the platform
# default encoding (e.g. a legacy code page on Windows) can fail or mis-decode.
with open(fname, "r", encoding="utf-8") as fr:
    for line in fr:
        record = line.strip()
        if not record:  # json.loads("") would raise on blank lines
            continue
        sample = json.loads(record)
        references.append(sample["response"])
        predictions.append(sample["prediction"])

result = compute_metrics([predictions, references])
print(result)