File size: 11,418 Bytes
badcf3c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 |
'''
用于统计500条三元组提取中的犯错类型
规则如下:
正确被真实三元组记录的24种关系如下:
出露于"、"位于"、"整合接触"、"不整合接触"、"假整合接触"、"断层接触"、"分布形态"、"大地构造位置"、
"地层区划"、"出露地层"、"岩性"、"厚度"、"面积"、"坐标"、"长度"、"含有"、"所属年代"、"行政区划"、
"发育"、"古生物"、"海拔"、"属于"、"吞噬"、"侵入"。
统计规则如下:
逐句子(500条)统计,犯错误类型
每个句子需记录如下信息:
24种关系的提取明细,可能遇到如下情况,
1、真实和预测三元组中都有对应关系,统计各自某一关系的个数,该条句子的该种关系记录为(预测样本中包含该关系的个数/真实样本中包含该关系的个数,大于1则记作1)
2、真实三元组中有对应关系,但是预测三元组中没有,该条句子的该种关系记录为0
3、真实三元组中没有对应关系,但是预测三元组中有,该条句子的该种关系记录为(预测样本中包含该关系的个数,取负值)
3扩充、实三元组中没有对应关系,但是预测三元组中有,但不属于24种关系,则新建一列'其他',该条句子的其他关系记录为(预测样本中包含其他系的个数,取负值)
每条句子的评价输出结果应该是长度为25的向量,前24个元素取值取决于预测和真实三元组中包含该关系的个数,最后一个元素取值取决于预测三元组中不属于24种关系的个数
真实和预测样本的格式如下:
[
{
"text": "诺日巴尕日保组原指灰色灰绿色厚层中-细粒岩屑长石砂岩长石石英砂岩长石砂岩偶夹粉砂岩,粘土岩及泥晶灰岩组成,仅见双壳类化石,与上覆九十道班组为连续沉积。",
"triple_list": [
[
"A,
"岩性",
"B"
],
[
"A",
"岩性",
"B"
]
]
},
...
rest 499
]
真实和与预测样本的顺序完全一致,可以使用索引定位
'''
import json
import matplotlib.pyplot as plt
import numpy as np
import set_font
set_font.set_font()
def evaluate_triples(gold_path, pred_path):
# 读取真实和预测的三元组数据
with open(gold_path, 'r', encoding='utf-8') as f:
gold_data = json.load(f)
with open(pred_path, 'r', encoding='utf-8') as f:
pred_data = json.load(f)
# 定义24种关系
relations = [
"出露于", "位于", "整合接触", "不整合接触", "假整合接触", "断层接触", "分布形态", "大地构造位置",
"地层区划", "出露地层", "岩性", "厚度", "面积", "坐标", "长度", "含有", "所属年代", "行政区划",
"发育", "古生物", "海拔", "属于", "吞噬", "侵入"
]
# 初始化错误统计和混乱程度统计
error_stats = [0] * 24
confusion_stats = [0] * 25
gold_counts = [0] * 24
pred_counts = [0] * 24
correct_counts = [0] * 24 # 新增:正确提取的统计
# 遍历每个句子
for gold, pred in zip(gold_data, pred_data):
gold_triples = gold['triple_list']
pred_triples = pred['triple_list']
# 统计每种关系的错误
for i, relation in enumerate(relations):
gold_count = sum(1 for triple in gold_triples if triple[1] == relation)
pred_count = sum(1 for triple in pred_triples if triple[1] == relation)
gold_counts[i] += gold_count
pred_counts[i] += pred_count
# 计算正确提取数
if gold_count > 0 and pred_count > 0:
correct_counts[i] += min(gold_count, pred_count)
if gold_count > 0 and pred_count > 0:
error_stats[i] += min(pred_count / gold_count, 1)
elif gold_count > 0 and pred_count == 0:
error_stats[i] += 0
elif gold_count == 0 and pred_count > 0:
confusion_stats[i] += pred_count
# 统计'其他'关系
other_count = sum(1 for triple in pred_triples if triple[1] not in relations)
confusion_stats[24] += other_count
for i in range(len(relations)):
print(f"{relations[i]}: 真实个数 = {gold_counts[i]}")# , 预测个数 = {pred_counts[i]}")
# 误用统计
# print(f"{relations[i]}: 误用个数 = {confusion_stats[i]}")
# 计算正确提取率
correct_rates = [correct / (gold if gold != 0 else 1) for correct, gold in zip(correct_counts, gold_counts)]
return error_stats, confusion_stats, gold_counts, pred_counts, correct_rates
def plot_results(relations, pred_error_stats, gold_error_stats, confusion_stats, correct_rates):
# 计算预测的error_stats和真实样本的error_stats的比值
ratios = [min(pred / (gold if gold != 0 else 1), 1) for pred, gold in zip(pred_error_stats, gold_error_stats)]
# 打印金样本和预测样本的每个分类匹配的得分,及其比值
# for i in range(len(relations)):
# # print(f"{relations[i]}: 真实得分 = {gold_error_stats[i]}, 预测得分 = {pred_error_stats[i]}, 比值 = {ratios[i]}")
# print(f"{relations[i]}: 匹配比值 = {ratios[i]}")
# 输出匹配分数和混乱程度
# print(f"匹配分数:{sum(pred_error_stats)}")
# print(f"混乱程度:{sum(confusion_stats)}")
# 输出正确提取率
print("\n正确提取率:")
for i in range(len(relations)):
print(f"{relations[i]}: {correct_rates[i]:.2f}")
# # 绘制预测/真实比值的条形图
# x = np.arange(len(relations))
# plt.figure(figsize=(12, 6))
# plt.bar(x, ratios, color='skyblue')
# plt.xticks(x, relations, rotation=45, ha='right')
# plt.ylabel('预测/真实 比值')
# plt.title('预测结果与真实结果的比值')
# plt.tight_layout()
# plt.show()
# # 绘制混乱程度的条形图
# x_confusion = np.arange(len(confusion_stats))
# plt.figure(figsize=(12, 6))
# plt.bar(x_confusion, confusion_stats, color='salmon')
# plt.xticks(x_confusion, relations + ['其他'], rotation=45, ha='right')
# plt.ylabel('混乱程度')
# plt.title('预测结果的混乱程度')
# plt.tight_layout()
# plt.show()
# # 绘制正确提取率的条形图
# x_correct = np.arange(len(correct_rates))
# plt.figure(figsize=(12, 6))
# plt.bar(x_correct, correct_rates, color='lightgreen')
# plt.xticks(x_correct, relations, rotation=45, ha='right')
# plt.ylabel('正确提取率')
# plt.title('正确提取率')
# plt.tight_layout()
# plt.show()
relations = [
"出露于", "位于", "整合接触", "不整合接触", "假整合接触", "断层接触", "分布形态", "大地构造位置",
"地层区划", "出露地层", "岩性", "厚度", "面积", "坐标", "长度", "含有", "所属年代", "行政区划",
"发育", "古生物", "海拔", "属于", "吞噬", "侵入"
]
# 使用示例
gold_path = './data/GT_500.json'
model_paths = [
# # gpt-3.5
'./data/GT_500.json',
# 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/gpt-3.5-turbo.json',
# 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/gpt-4o.json',
# 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/gemini-1.5-pro-002.json',
# 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/claude-3-5-haiku-20241022.json',
# 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/deepseek-ai/deepseek-V3.json',
# 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/deepseek-ai/deepseek-R1.json',
# 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/meta-llama/Meta-Llama-3.1-405B-Instruct.json',
# 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/Qwen/Qwen2.5-72B-Instruct.json',
# 'F:/GeoLLM/output/output_result/Task1/nomal/two_shot/deepseek-ai/deepseek-V3.json',
# 'F:/GeoLLM/output/output_result/Task1/knn/three_shot/deepseek-ai/deepseek-R1.json',
# 'F:/GeoLLM/output/output_result/Task1/Knowledge-guided/one_shot/deepseek-ai/deepseek-R1.json'
# 'F:/GeoLLM/output/Knowledge-guided_rerun/one_shot/deepseek-ai/deepseek-V3.json'
# 'F:/GeoLLM/output/Knowledge-guided_rerun/one_shot/gpt-3.5-turbo.json',
# 'F:/GeoLLM/output/Knowledge-guided_rerun/one_shot/gpt-3.5-turbo_0407.json'
]
for model_path in model_paths:
pred_path = model_path
print(pred_path)
# error_stats, confusion_stats, gold_counts, pred_counts = evaluate_triples(gold_path, pred_path)
# # 打印预测样本的每个分类匹配的个数
# for i in range(len(relations)):
# print(relations[i], gold_counts[i], pred_counts[i])
# # 打印金样本和预测样本的每个分类匹配的得分
# for i in range(len(relations)):
# print(relations[i], error_stats[i])
# plot_results(relations, error_stats, gold_counts, confusion_stats)
# 计算预测样本的得分
pred_error_stats, confusion_stats, _, _, correct_rates = evaluate_triples(gold_path, pred_path)
# 计算真实样本的得分
gold_error_stats, _, _, _, _ = evaluate_triples(gold_path, gold_path)
plot_results(relations, pred_error_stats, gold_error_stats, confusion_stats, correct_rates)
# # 预测样本
# model_paths = [
# # # gpt-3.5
# './data/GT_500.json',
# # 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/deepseek-ai/deepseek-V3.json', # 零样本
# # 'F:/GeoLLM/output/output_result/Task1/nomal/one_shot/gpt-3p5-turbo.json', # 单样本
# # 'F:/GeoLLM/output/output_result/Task1/nomal/two_shot/gpt-3p5-turbo.json', # 双样本
# # 'F:/GeoLLM/output/output_result/Task1/nomal/three_shot/gpt-3p5-turbo.json', # 三样本
# # 'F:/GeoLLM/output/output_result/Task1/knn/one_shot/gpt-3p5-turbo.json', # KNN单样本
# # 'F:/GeoLLM/output/output_result/Task1/knn/two_shot/gpt-3p5-turbo.json', # KNN双样本
# # 'F:/GeoLLM/output/output_result/Task1/knn/three_shot/gpt-3p5-turbo.json', # KNN三样本
# # 'F:/GeoLLM/output/Knowledge-guided_rerun/one_shot/deepseek-ai/deepseek-V3.json', # 知识引导单样本
# ]
# for model_path in model_paths:
# # 假设你想要选择第一个路径作为预测样本路径
# pred_path = model_path
# # 使用示例
# error_stats, confusion_stats = evaluate_triples(gold_path, pred_path)
# print("匹配统计结果:\n")
# for i in range(len(relations)):
# print(relations[i], error_stats[i])
# # 使用内置的 sum() 函数计算总和
# print("匹配分数:", sum(error_stats))
# print('--------------------------------')
# print("误用统计:\n")
# for i in range(len(relations)):
# print(relations[i], confusion_stats[i])
# print("混乱程度总和:", sum(confusion_stats)) |