# GeoLLM / Mistake_Statistics.py
# Author: Pengfa Li
# Uploaded via huggingface_hub (badcf3c, verified)
'''
用于统计500条三元组提取中的犯错类型
规则如下:
正确被真实三元组记录的24种关系如下:
出露于"、"位于"、"整合接触"、"不整合接触"、"假整合接触"、"断层接触"、"分布形态"、"大地构造位置"、
"地层区划"、"出露地层"、"岩性"、"厚度"、"面积"、"坐标"、"长度"、"含有"、"所属年代"、"行政区划"、
"发育"、"古生物"、"海拔"、"属于"、"吞噬"、"侵入"。
统计规则如下:
逐句子(500条)统计,犯错误类型
每个句子需记录如下信息:
24种关系的提取明细,可能遇到如下情况,
1、真实和预测三元组中都有对应关系,统计各自某一关系的个数,该条句子的该种关系记录为(预测样本中包含该关系的个数/真实样本中包含该关系的个数,大于1则记作1)
2、真实三元组中有对应关系,但是预测三元组中没有,该条句子的该种关系记录为0
3、真实三元组中没有对应关系,但是预测三元组中有,该条句子的该种关系记录为(预测样本中包含该关系的个数,取负值)
3扩充、真实三元组中没有对应关系,但是预测三元组中有,但不属于24种关系,则新建一列'其他',该条句子的其他关系记录为(预测样本中包含其他关系的个数,取负值)
每条句子的评价输出结果应该是长度为25的向量,前24个元素取值取决于预测和真实三元组中包含该关系的个数,最后一个元素取值取决于预测三元组中不属于24种关系的个数
真实和预测样本的格式如下:
[
{
"text": "诺日巴尕日保组原指灰色灰绿色厚层中-细粒岩屑长石砂岩长石石英砂岩长石砂岩偶夹粉砂岩,粘土岩及泥晶灰岩组成,仅见双壳类化石,与上覆九十道班组为连续沉积。",
"triple_list": [
[
"A,
"岩性",
"B"
],
[
"A",
"岩性",
"B"
]
]
},
...
rest 499
]
真实和与预测样本的顺序完全一致,可以使用索引定位
'''
import json
import matplotlib.pyplot as plt
import numpy as np
import set_font
set_font.set_font()
def evaluate_triples(gold_path, pred_path):
    """Compare gold and predicted triples sentence by sentence.

    Both files must be JSON lists of
    ``{"text": ..., "triple_list": [[head, relation, tail], ...]}`` objects,
    with gold and prediction aligned by index (same order, same length).

    Args:
        gold_path: path to the ground-truth JSON file.
        pred_path: path to the prediction JSON file.

    Returns:
        A 5-tuple:
        - error_stats: per-relation matching score, the sum over sentences of
          ``min(pred_count / gold_count, 1)`` where both sides contain the relation.
        - confusion_stats: length-25 misuse counts — for each of the 24 schema
          relations, predictions where the gold sentence lacks the relation;
          index 24 is the "other" bucket for predicted relations outside the schema.
        - gold_counts: total gold occurrences per relation.
        - pred_counts: total predicted occurrences per relation.
        - correct_rates: per-relation fraction of gold triples recovered,
          ``min(pred, gold) / gold`` aggregated over all sentences.
    """
    # Load the aligned gold and prediction files.
    with open(gold_path, 'r', encoding='utf-8') as f:
        gold_data = json.load(f)
    with open(pred_path, 'r', encoding='utf-8') as f:
        pred_data = json.load(f)
    # The 24 relation types in the annotation schema (order defines the index).
    relations = [
        "出露于", "位于", "整合接触", "不整合接触", "假整合接触", "断层接触", "分布形态", "大地构造位置",
        "地层区划", "出露地层", "岩性", "厚度", "面积", "坐标", "长度", "含有", "所属年代", "行政区划",
        "发育", "古生物", "海拔", "属于", "吞噬", "侵入"
    ]
    # Build the membership set once: O(1) lookups instead of an O(24) list
    # scan per predicted triple in the loop below.
    relation_set = set(relations)
    error_stats = [0] * 24
    confusion_stats = [0] * 25   # index 24 is the "other" bucket
    gold_counts = [0] * 24
    pred_counts = [0] * 24
    correct_counts = [0] * 24    # correctly extracted, capped by the gold count
    # Walk the 500 aligned sentence pairs.
    for gold, pred in zip(gold_data, pred_data):
        gold_triples = gold['triple_list']
        pred_triples = pred['triple_list']
        for i, relation in enumerate(relations):
            gold_count = sum(1 for triple in gold_triples if triple[1] == relation)
            pred_count = sum(1 for triple in pred_triples if triple[1] == relation)
            gold_counts[i] += gold_count
            pred_counts[i] += pred_count
            if gold_count > 0 and pred_count > 0:
                # Relation present on both sides: per-sentence score capped at 1.
                correct_counts[i] += min(gold_count, pred_count)
                error_stats[i] += min(pred_count / gold_count, 1)
            elif gold_count == 0 and pred_count > 0:
                # Predicted a schema relation the sentence does not contain.
                confusion_stats[i] += pred_count
            # gold_count > 0 and pred_count == 0 adds nothing to either stat.
        # Predicted relations outside the 24-type schema fall into "other".
        confusion_stats[24] += sum(
            1 for triple in pred_triples if triple[1] not in relation_set
        )
    for i, relation in enumerate(relations):
        print(f"{relation}: 真实个数 = {gold_counts[i]}")
    # Fraction of gold triples recovered per relation (0 when the relation
    # never occurs in the gold data, since correct_counts is then 0 too).
    correct_rates = [
        correct / (gold if gold != 0 else 1)
        for correct, gold in zip(correct_counts, gold_counts)
    ]
    return error_stats, confusion_stats, gold_counts, pred_counts, correct_rates
def plot_results(relations, pred_error_stats, gold_error_stats, confusion_stats, correct_rates):
    """Print per-relation extraction quality to stdout.

    Args:
        relations: ordered list of relation names.
        pred_error_stats: per-relation matching scores of the predictions.
        gold_error_stats: per-relation matching scores of the gold-vs-gold run.
        confusion_stats: per-relation misuse counts (unused in the console output).
        correct_rates: per-relation fraction of gold triples recovered.
    """
    # Prediction-over-gold score ratio, capped at 1; retained for optional
    # visualisations even though the console report below does not use it.
    ratios = []
    for pred_score, gold_score in zip(pred_error_stats, gold_error_stats):
        denominator = gold_score if gold_score != 0 else 1
        ratios.append(min(pred_score / denominator, 1))
    # Report the correct-extraction rate for every relation.
    print("\n正确提取率:")
    for name, rate in zip(relations, correct_rates):
        print(f"{name}: {rate:.2f}")
# The 24 relation types of the annotation schema. The order is significant:
# it defines the index used by every statistics vector in this script.
relations = [
    "出露于", "位于", "整合接触", "不整合接触", "假整合接触", "断层接触",
    "分布形态", "大地构造位置", "地层区划", "出露地层", "岩性", "厚度",
    "面积", "坐标", "长度", "含有", "所属年代", "行政区划", "发育",
    "古生物", "海拔", "属于", "吞噬", "侵入",
]
# Usage example: score each candidate prediction file against the gold file.
gold_path = './data/GT_500.json'
# Candidate prediction files; uncomment the entries to evaluate. Evaluating
# the gold file against itself yields the upper-bound reference scores.
model_paths = [
    # # gpt-3.5
    './data/GT_500.json',
    # 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/gpt-3.5-turbo.json',
    # 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/gpt-4o.json',
    # 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/gemini-1.5-pro-002.json',
    # 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/claude-3-5-haiku-20241022.json',
    # 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/deepseek-ai/deepseek-V3.json',
    # 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/deepseek-ai/deepseek-R1.json',
    # 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/meta-llama/Meta-Llama-3.1-405B-Instruct.json',
    # 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/Qwen/Qwen2.5-72B-Instruct.json',
    # 'F:/GeoLLM/output/output_result/Task1/nomal/two_shot/deepseek-ai/deepseek-V3.json',
    # 'F:/GeoLLM/output/output_result/Task1/knn/three_shot/deepseek-ai/deepseek-R1.json',
    # 'F:/GeoLLM/output/output_result/Task1/Knowledge-guided/one_shot/deepseek-ai/deepseek-R1.json'
    # 'F:/GeoLLM/output/Knowledge-guided_rerun/one_shot/deepseek-ai/deepseek-V3.json'
    # 'F:/GeoLLM/output/Knowledge-guided_rerun/one_shot/gpt-3.5-turbo.json',
    # 'F:/GeoLLM/output/Knowledge-guided_rerun/one_shot/gpt-3.5-turbo_0407.json'
]
# Evaluate every configured prediction file against the gold annotations.
for candidate_path in model_paths:
    pred_path = candidate_path
    print(pred_path)
    # Score the predictions against the gold file.
    pred_error_stats, confusion_stats, _, _, correct_rates = evaluate_triples(gold_path, pred_path)
    # A gold-vs-gold run gives the theoretical best score per relation.
    gold_error_stats, _, _, _, _ = evaluate_triples(gold_path, gold_path)
    plot_results(relations, pred_error_stats, gold_error_stats, confusion_stats, correct_rates)
# # 预测样本
# model_paths = [
# # # gpt-3.5
# './data/GT_500.json',
# # 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/deepseek-ai/deepseek-V3.json', # 零样本
# # 'F:/GeoLLM/output/output_result/Task1/nomal/one_shot/gpt-3p5-turbo.json', # 单样本
# # 'F:/GeoLLM/output/output_result/Task1/nomal/two_shot/gpt-3p5-turbo.json', # 双样本
# # 'F:/GeoLLM/output/output_result/Task1/nomal/three_shot/gpt-3p5-turbo.json', # 三样本
# # 'F:/GeoLLM/output/output_result/Task1/knn/one_shot/gpt-3p5-turbo.json', # KNN单样本
# # 'F:/GeoLLM/output/output_result/Task1/knn/two_shot/gpt-3p5-turbo.json', # KNN双样本
# # 'F:/GeoLLM/output/output_result/Task1/knn/three_shot/gpt-3p5-turbo.json', # KNN三样本
# # 'F:/GeoLLM/output/Knowledge-guided_rerun/one_shot/deepseek-ai/deepseek-V3.json', # 知识引导单样本
# ]
# for model_path in model_paths:
# # 假设你想要选择第一个路径作为预测样本路径
# pred_path = model_path
# # 使用示例
# error_stats, confusion_stats = evaluate_triples(gold_path, pred_path)
# print("匹配统计结果:\n")
# for i in range(len(relations)):
# print(relations[i], error_stats[i])
# # 使用内置的 sum() 函数计算总和
# print("匹配分数:", sum(error_stats))
# print('--------------------------------')
# print("误用统计:\n")
# for i in range(len(relations)):
# print(relations[i], confusion_stats[i])
# print("混乱程度总和:", sum(confusion_stats))