File size: 11,418 Bytes
badcf3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
'''

    用于统计500条三元组提取中的犯错类型

    规则如下:

    正确被真实三元组记录的24种关系如下:

    出露于"、"位于"、"整合接触"、"不整合接触"、"假整合接触"、"断层接触"、"分布形态"、"大地构造位置"、

    "地层区划"、"出露地层"、"岩性"、"厚度"、"面积"、"坐标"、"长度"、"含有"、"所属年代"、"行政区划"、

    "发育"、"古生物"、"海拔"、"属于"、"吞噬"、"侵入"。

    统计规则如下:

    逐句子(500条)统计,犯错误类型

    每个句子需记录如下信息:

    24种关系的提取明细,可能遇到如下情况,

    1、真实和预测三元组中都有对应关系,统计各自某一关系的个数,该条句子的该种关系记录为(预测样本中包含该关系的个数/真实样本中包含该关系的个数,大于1则记作1)

    2、真实三元组中有对应关系,但是预测三元组中没有,该条句子的该种关系记录为0

    3、真实三元组中没有对应关系,但是预测三元组中有,该条句子的该种关系记录为(预测样本中包含该关系的个数,取负值)

    3扩充、实三元组中没有对应关系,但是预测三元组中有,但不属于24种关系,则新建一列'其他',该条句子的其他关系记录为(预测样本中包含其他系的个数,取负值)

    每条句子的评价输出结果应该是长度为25的向量,前24个元素取值取决于预测和真实三元组中包含该关系的个数,最后一个元素取值取决于预测三元组中不属于24种关系的个数

    真实和预测样本的格式如下:

    [

    {

        "text": "诺日巴尕日保组原指灰色灰绿色厚层中-细粒岩屑长石砂岩长石石英砂岩长石砂岩偶夹粉砂岩,粘土岩及泥晶灰岩组成,仅见双壳类化石,与上覆九十道班组为连续沉积。",

        "triple_list": [

            [

                "A,

                "岩性",

                "B"

            ],

            [

                "A",

                "岩性",

                "B"

            ]

        ]

    },

    ...

    rest 499

    ]

    真实和与预测样本的顺序完全一致,可以使用索引定位

'''


import json
import matplotlib.pyplot as plt
import numpy as np
import set_font
set_font.set_font()
def evaluate_triples(gold_path, pred_path):
    # 读取真实和预测的三元组数据
    with open(gold_path, 'r', encoding='utf-8') as f:
        gold_data = json.load(f)
    
    with open(pred_path, 'r', encoding='utf-8') as f:
        pred_data = json.load(f)
    
    # 定义24种关系
    relations = [
        "出露于", "位于", "整合接触", "不整合接触", "假整合接触", "断层接触", "分布形态", "大地构造位置",
        "地层区划", "出露地层", "岩性", "厚度", "面积", "坐标", "长度", "含有", "所属年代", "行政区划",
        "发育", "古生物", "海拔", "属于", "吞噬", "侵入"
    ]
    
    # 初始化错误统计和混乱程度统计
    error_stats = [0] * 24
    confusion_stats = [0] * 25
    gold_counts = [0] * 24
    pred_counts = [0] * 24
    correct_counts = [0] * 24  # 新增:正确提取的统计
    
    # 遍历每个句子
    for gold, pred in zip(gold_data, pred_data):
        gold_triples = gold['triple_list']
        pred_triples = pred['triple_list']
        
        # 统计每种关系的错误
        for i, relation in enumerate(relations):
            gold_count = sum(1 for triple in gold_triples if triple[1] == relation)
            pred_count = sum(1 for triple in pred_triples if triple[1] == relation)
            
            gold_counts[i] += gold_count
            pred_counts[i] += pred_count
            
            # 计算正确提取数
            if gold_count > 0 and pred_count > 0:
                correct_counts[i] += min(gold_count, pred_count)
            
            if gold_count > 0 and pred_count > 0:
                error_stats[i] += min(pred_count / gold_count, 1)
            elif gold_count > 0 and pred_count == 0:
                error_stats[i] += 0
            elif gold_count == 0 and pred_count > 0:
                confusion_stats[i] += pred_count
        
        # 统计'其他'关系
        other_count = sum(1 for triple in pred_triples if triple[1] not in relations)
        confusion_stats[24] += other_count
        for i in range(len(relations)):
            print(f"{relations[i]}: 真实个数 = {gold_counts[i]}")# , 预测个数 = {pred_counts[i]}")
            # 误用统计
            # print(f"{relations[i]}:  误用个数 = {confusion_stats[i]}")
        
    # 计算正确提取率
    correct_rates = [correct / (gold if gold != 0 else 1) for correct, gold in zip(correct_counts, gold_counts)]
    
    return error_stats, confusion_stats, gold_counts, pred_counts, correct_rates

def plot_results(relations, pred_error_stats, gold_error_stats, confusion_stats, correct_rates):
    # 计算预测的error_stats和真实样本的error_stats的比值
    ratios = [min(pred / (gold if gold != 0 else 1), 1) for pred, gold in zip(pred_error_stats, gold_error_stats)]
    
    # 打印金样本和预测样本的每个分类匹配的得分,及其比值
    # for i in range(len(relations)):
    #     # print(f"{relations[i]}: 真实得分 = {gold_error_stats[i]}, 预测得分 = {pred_error_stats[i]}, 比值 = {ratios[i]}")
    #     print(f"{relations[i]}: 匹配比值 = {ratios[i]}")
    # 输出匹配分数和混乱程度
    # print(f"匹配分数:{sum(pred_error_stats)}")
    # print(f"混乱程度:{sum(confusion_stats)}")

    # 输出正确提取率
    print("\n正确提取率:")
    for i in range(len(relations)):
        print(f"{relations[i]}: {correct_rates[i]:.2f}")

    # # 绘制预测/真实比值的条形图
    # x = np.arange(len(relations))
    # plt.figure(figsize=(12, 6))
    # plt.bar(x, ratios, color='skyblue')
    # plt.xticks(x, relations, rotation=45, ha='right')
    # plt.ylabel('预测/真实 比值')
    # plt.title('预测结果与真实结果的比值')
    # plt.tight_layout()
    # plt.show()
    
    # # 绘制混乱程度的条形图
    # x_confusion = np.arange(len(confusion_stats))
    # plt.figure(figsize=(12, 6))
    # plt.bar(x_confusion, confusion_stats, color='salmon')
    # plt.xticks(x_confusion, relations + ['其他'], rotation=45, ha='right')
    # plt.ylabel('混乱程度')
    # plt.title('预测结果的混乱程度')
    # plt.tight_layout()
    # plt.show()

    # # 绘制正确提取率的条形图
    # x_correct = np.arange(len(correct_rates))
    # plt.figure(figsize=(12, 6))
    # plt.bar(x_correct, correct_rates, color='lightgreen')
    # plt.xticks(x_correct, relations, rotation=45, ha='right')
    # plt.ylabel('正确提取率')
    # plt.title('正确提取率')
    # plt.tight_layout()
    # plt.show()
    
relations = [
    "出露于", "位于", "整合接触", "不整合接触", "假整合接触", "断层接触", "分布形态", "大地构造位置",
    "地层区划", "出露地层", "岩性", "厚度", "面积", "坐标", "长度", "含有", "所属年代", "行政区划",
    "发育", "古生物", "海拔", "属于", "吞噬", "侵入"
]
# 使用示例
gold_path = './data/GT_500.json'
model_paths = [
    # # gpt-3.5
    './data/GT_500.json',
    # 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/gpt-3.5-turbo.json',
    # 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/gpt-4o.json',
    # 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/gemini-1.5-pro-002.json',
    # 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/claude-3-5-haiku-20241022.json',
    # 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/deepseek-ai/deepseek-V3.json',
    # 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/deepseek-ai/deepseek-R1.json',
    # 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/meta-llama/Meta-Llama-3.1-405B-Instruct.json',
    # 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/Qwen/Qwen2.5-72B-Instruct.json',
    # 'F:/GeoLLM/output/output_result/Task1/nomal/two_shot/deepseek-ai/deepseek-V3.json',
    # 'F:/GeoLLM/output/output_result/Task1/knn/three_shot/deepseek-ai/deepseek-R1.json',
    # 'F:/GeoLLM/output/output_result/Task1/Knowledge-guided/one_shot/deepseek-ai/deepseek-R1.json'
    # 'F:/GeoLLM/output/Knowledge-guided_rerun/one_shot/deepseek-ai/deepseek-V3.json'


    # 'F:/GeoLLM/output/Knowledge-guided_rerun/one_shot/gpt-3.5-turbo.json',
    # 'F:/GeoLLM/output/Knowledge-guided_rerun/one_shot/gpt-3.5-turbo_0407.json'    

]

for model_path in model_paths:
    pred_path = model_path
    print(pred_path)
    # error_stats, confusion_stats, gold_counts, pred_counts = evaluate_triples(gold_path, pred_path)
    # # 打印预测样本的每个分类匹配的个数
    # for i in range(len(relations)):
    #     print(relations[i], gold_counts[i], pred_counts[i])
    # # 打印金样本和预测样本的每个分类匹配的得分
    # for i in range(len(relations)):
    #     print(relations[i], error_stats[i])
    # plot_results(relations, error_stats, gold_counts, confusion_stats)
    # 计算预测样本的得分
    pred_error_stats, confusion_stats, _, _, correct_rates = evaluate_triples(gold_path, pred_path)

    # 计算真实样本的得分
    gold_error_stats, _, _, _, _ = evaluate_triples(gold_path, gold_path)

    plot_results(relations, pred_error_stats, gold_error_stats, confusion_stats, correct_rates)


# # 预测样本
# model_paths = [
#     # # gpt-3.5
#     './data/GT_500.json',
#     # 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/deepseek-ai/deepseek-V3.json',        # 零样本
#     # 'F:/GeoLLM/output/output_result/Task1/nomal/one_shot/gpt-3p5-turbo.json',        # 单样本
#     # 'F:/GeoLLM/output/output_result/Task1/nomal/two_shot/gpt-3p5-turbo.json',        # 双样本
#     # 'F:/GeoLLM/output/output_result/Task1/nomal/three_shot/gpt-3p5-turbo.json',        # 三样本
#     # 'F:/GeoLLM/output/output_result/Task1/knn/one_shot/gpt-3p5-turbo.json',        # KNN单样本    
#     # 'F:/GeoLLM/output/output_result/Task1/knn/two_shot/gpt-3p5-turbo.json',        # KNN双样本
#     # 'F:/GeoLLM/output/output_result/Task1/knn/three_shot/gpt-3p5-turbo.json',        # KNN三样本
#     # 'F:/GeoLLM/output/Knowledge-guided_rerun/one_shot/deepseek-ai/deepseek-V3.json',        # 知识引导单样本
# ]
# for model_path in model_paths:
#     # 假设你想要选择第一个路径作为预测样本路径
#     pred_path = model_path
#     # 使用示例
#     error_stats, confusion_stats = evaluate_triples(gold_path, pred_path)
#     print("匹配统计结果:\n")
#     for i in range(len(relations)):
#         print(relations[i], error_stats[i])
#     # 使用内置的 sum() 函数计算总和
#     print("匹配分数:", sum(error_stats))  
#     print('--------------------------------')
#     print("误用统计:\n")
#     for i in range(len(relations)):
#         print(relations[i], confusion_stats[i])
#     print("混乱程度总和:", sum(confusion_stats))