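"""Compare prompting-based (OpenAI) attribution scores against human annotations.

Reads `openai_log_attribution.jsonl` and `human_log_attribution.jsonl` from
--data_dir, pairs up per-utterance scores for records with matching episode_id
and agent, and reports mean absolute difference, Spearman correlation,
exact-match rate, and accuracy/precision/recall/F1 for the score-3 class.
"""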
import argparse
import os
from typing import Any, Iterable, List, Tuple

import jsonlines
from scipy.stats import spearmanr


def read_data(data_dir: str) -> Tuple[Iterable[Any], Iterable[Any]]:
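    """Load the prompting (openai_log_attribution.jsonl) and human
    (human_log_attribution.jsonl) datasets from data_dir."""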
    with jsonlines.open(
        os.path.join(data_dir, "openai_log_attribution.jsonl"), "r"
    ) as reader:
        prompting_dataset = list(reader)

    with jsonlines.open(
        os.path.join(data_dir, "human_log_attribution.jsonl"), "r"
    ) as reader:
        human_dataset = list(reader)
    return prompting_dataset, human_dataset


def hard_code_key(attributed_utterances: Any) -> Any:
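    """Shift the utterance index embedded in each key up by one.

    E.g. a key like "Utterance 0" becomes "Utterance 1" (the exact key prefix
    is illustrative). Not used elsewhere in this script.
    """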
    new_attributed_utterances = {}
    for key in attributed_utterances:
        utterance_num = int(key.split(" ")[1])
        new_utterance_num = utterance_num + 1
        new_key = key.replace(str(utterance_num), str(new_utterance_num))
        new_attributed_utterances[new_key] = attributed_utterances[key]
    return new_attributed_utterances


def build_paired_scores(
    human_attributed_utterances: Any,
    prompt_attributed_utterances: Any,
    average: bool = False,
    annotator: int = 0,
) -> List[Tuple[float, int]]:
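    """Pair each human score with the corresponding prompt score for one record.

    The human entry holds a per-annotator score dict; depending on `average` and
    `annotator`, either the mean of the two annotators or a single annotator's
    score is used. Utterances whose prompt score is -1, or whose human value is
    not a per-annotator dict, are skipped.
    """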
    paired_scores = []
    for key in human_attributed_utterances:
        human_scores = human_attributed_utterances[key][-2]
        prompt_score = prompt_attributed_utterances[key][-1]
        if isinstance(human_scores, dict) and prompt_score != -1:
            sorted_human_scores = sorted(
                human_scores.items(), key=lambda x: x[0]
            )
            ann0, ann1 = sorted_human_scores[0][1], sorted_human_scores[1][1]
            if average:
                human_score = (ann0 + ann1) / 2
            else:
                human_score = ann0 if annotator == 0 else ann1
            paired_scores.append((human_score, prompt_score))
    return paired_scores


def main(data_dir: str, average: bool, annotator: int) -> None:
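    """Match human and prompting records by episode_id and agent, pair up their
    utterance scores, and print agreement metrics."""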
    prompting_dataset, human_dataset = read_data(data_dir)
    paired_scores_dataset = []
    for human_data in human_dataset:
        for prompt_data in prompting_dataset:
            if (
                human_data["episode_id"] == prompt_data["episode_id"]
                and human_data["agent"] == prompt_data["agent"]
            ):
                human_attributed_utterances = human_data[
                    "attributed_utterances"
                ]
                prompt_attributed_utterances = prompt_data[
                    "attributed_utterances"
                ]
                paired_scores = build_paired_scores(
                    human_attributed_utterances,
                    prompt_attributed_utterances,
                    average=average,
                    annotator=annotator,
                )
                paired_scores_dataset += paired_scores
                break
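    # Aggregate metrics over all paired scores: mean absolute difference,
    # Spearman correlation, and exact-match rate.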
    human_scores = [score[0] for score in paired_scores_dataset]
    prompt_scores = [score[1] for score in paired_scores_dataset]
    spearman_corr, _ = spearmanr(human_scores, prompt_scores)
    agreement_rate = len(
        [1 for score in paired_scores_dataset if score[0] == score[1]]
    ) / len(paired_scores_dataset)
    avg_diff = sum(
        [abs(score[0] - score[1]) for score in paired_scores_dataset]
    ) / len(paired_scores_dataset)
    print("average difference: {}".format(avg_diff))
    print("spearman correlation: {}".format(spearman_corr))
    print("exact match: {}".format(agreement_rate))

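    # Binarize at score == 3 (treated as the positive class) and compute
    # accuracy, precision, recall, and F1 against the human labels.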
    human_3_scores = [1 if score == 3 else 0 for score in human_scores]
    prompt_3_scores = [1 if score == 3 else 0 for score in prompt_scores]
    # calculate the accuracy
    accuracy = sum(
        [
            1
            for i in range(len(human_3_scores))
            if human_3_scores[i] == prompt_3_scores[i]
        ]
    ) / len(human_3_scores)
    print("Accuracy: {}".format(accuracy))
    # calculate the F1 score
    tp = sum(
        [
            1
            for i in range(len(human_3_scores))
            if human_3_scores[i] == 1 and prompt_3_scores[i] == 1
        ]
    )
    fp = sum(
        [
            1
            for i in range(len(human_3_scores))
            if human_3_scores[i] == 0 and prompt_3_scores[i] == 1
        ]
    )
    fn = sum(
        [
            1
            for i in range(len(human_3_scores))
            if human_3_scores[i] == 1 and prompt_3_scores[i] == 0
        ]
    )
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    print("Precision: {}".format(precision))
    print("Recall: {}".format(recall))
    print("F1 score: {}".format(f1))
    print(
        "{:.3f}, {:.3f}, {:.3f}, {:.3f}, {:.3f}, {:.3f}".format(
            avg_diff, spearman_corr, agreement_rate, f1, precision, recall
        )
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Compare prompting-based attribution scores against human annotations."
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        required=True,
        help="Directory containing data files",
    )
    parser.add_argument(
        "--average",
        action="store_true",
        help="Whether to average the human scores",
    )
    parser.add_argument(
        "--annotator",
        type=int,
        default=0,
        help="Which human annotator to use (0 or 1) when --average is not set",
    )

    args = parser.parse_args()
    print(args.data_dir, args.average, args.annotator)
    main(args.data_dir, args.average, args.annotator)