| ''' | |
| 用于统计500条三元组提取中的犯错类型 | |
| 规则如下: | |
| 正确被真实三元组记录的24种关系如下: | |
| "出露于"、"位于"、"整合接触"、"不整合接触"、"假整合接触"、"断层接触"、"分布形态"、"大地构造位置"、 | |
| "地层区划"、"出露地层"、"岩性"、"厚度"、"面积"、"坐标"、"长度"、"含有"、"所属年代"、"行政区划"、 | |
| "发育"、"古生物"、"海拔"、"属于"、"吞噬"、"侵入"。 | |
| 统计规则如下: | |
| 逐句子(500条)统计,犯错误类型 | |
| 每个句子需记录如下信息: | |
| 24种关系的提取明细,可能遇到如下情况, | |
| 1、真实和预测三元组中都有对应关系,统计各自某一关系的个数,该条句子的该种关系记录为(预测样本中包含该关系的个数/真实样本中包含该关系的个数,大于1则记作1) | |
| 2、真实三元组中有对应关系,但是预测三元组中没有,该条句子的该种关系记录为0 | |
| 3、真实三元组中没有对应关系,但是预测三元组中有,该条句子的该种关系记录为(预测样本中包含该关系的个数,取负值) | |
| 3扩充、真实三元组中没有对应关系,但是预测三元组中有,但不属于24种关系,则新建一列'其他',该条句子的其他关系记录为(预测样本中包含其他关系的个数,取负值) | |
| 每条句子的评价输出结果应该是长度为25的向量,前24个元素取值取决于预测和真实三元组中包含该关系的个数,最后一个元素取值取决于预测三元组中不属于24种关系的个数 | |
| 真实和预测样本的格式如下: | |
| [ | |
| { | |
| "text": "诺日巴尕日保组原指灰色灰绿色厚层中-细粒岩屑长石砂岩长石石英砂岩长石砂岩偶夹粉砂岩,粘土岩及泥晶灰岩组成,仅见双壳类化石,与上覆九十道班组为连续沉积。", | |
| "triple_list": [ | |
| [ | |
| "A", | |
| "岩性", | |
| "B" | |
| ], | |
| [ | |
| "A", | |
| "岩性", | |
| "B" | |
| ] | |
| ] | |
| }, | |
| ... | |
| rest 499 | |
| ] | |
| 真实和预测样本的顺序完全一致,可以使用索引定位 | |
| ''' | |
| import json | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import set_font | |
| set_font.set_font() | |
def _tally_relations(triples, relation_index):
    """Count, in a single pass, how often each known relation occurs.

    Args:
        triples: Iterable of triples; the relation label is read from
            position 1 of each triple.
        relation_index: Mapping from relation label to its slot in the
            returned count list.

    Returns:
        A ``(counts, other)`` pair. ``counts[i]`` is the number of triples
        whose relation maps to slot ``i``; ``other`` is the number of triples
        whose relation is not a key of ``relation_index``.

    Raises:
        IndexError: If a triple has fewer than two elements.
    """
    counts = [0] * len(relation_index)
    other = 0
    for triple in triples:
        predicate = triple[1]
        try:
            slot = relation_index.get(predicate)
        except TypeError:
            # An unhashable predicate (e.g. a nested list) can never equal a
            # relation label, so it is counted like any other unknown one.
            slot = None
        if slot is None:
            other += 1
        else:
            counts[slot] += 1
    return counts, other


def evaluate_triples(gold_path, pred_path):
    """Compare predicted triples against gold triples, relation by relation.

    Both files must contain a JSON list of objects with a ``triple_list``
    key. Entries are paired by position; pairing stops at the end of the
    shorter file. The gold count of every relation is printed as a side
    effect.

    Args:
        gold_path: Path to the JSON file holding the gold triples.
        pred_path: Path to the JSON file holding the predicted triples.

    Returns:
        A 5-tuple ``(error_stats, confusion_stats, gold_counts, pred_counts,
        correct_rates)``:

        * ``error_stats`` (24 items): per relation, the sum over sentences of
          ``min(pred_count / gold_count, 1)`` for sentences where both the
          gold and the predicted triples contain the relation.
        * ``confusion_stats`` (25 items): per relation, the number of
          predicted triples in sentences whose gold triples lack that
          relation; the last slot counts predicted triples whose relation is
          not one of the 24 known labels. Counts are stored as positive
          numbers.
        * ``gold_counts`` / ``pred_counts`` (24 items each): total number of
          gold / predicted triples per relation.
        * ``correct_rates`` (24 items): per relation, the sum over sentences
          of ``min(gold_count, pred_count)`` divided by the gold total (or by
          1 when the gold total is 0).

    Raises:
        KeyError: If an entry has no ``triple_list`` key.
        IndexError: If a triple has fewer than two elements.
    """
    with open(gold_path, 'r', encoding='utf-8') as f:
        gold_data = json.load(f)
    with open(pred_path, 'r', encoding='utf-8') as f:
        pred_data = json.load(f)

    # The 24 relation labels; their order fixes the slot order of every
    # returned vector.
    relations = [
        "出露于", "位于", "整合接触", "不整合接触", "假整合接触", "断层接触", "分布形态", "大地构造位置",
        "地层区划", "出露地层", "岩性", "厚度", "面积", "坐标", "长度", "含有", "所属年代", "行政区划",
        "发育", "古生物", "海拔", "属于", "吞噬", "侵入"
    ]
    relation_index = {name: slot for slot, name in enumerate(relations)}
    n_relations = len(relations)

    error_stats = [0] * n_relations
    confusion_stats = [0] * (n_relations + 1)  # extra last slot: "other"
    gold_counts = [0] * n_relations
    pred_counts = [0] * n_relations
    correct_counts = [0] * n_relations

    for gold, pred in zip(gold_data, pred_data):
        gold_per_relation, _ = _tally_relations(gold['triple_list'], relation_index)
        pred_per_relation, other_count = _tally_relations(pred['triple_list'], relation_index)

        for i, (gold_count, pred_count) in enumerate(zip(gold_per_relation, pred_per_relation)):
            gold_counts[i] += gold_count
            pred_counts[i] += pred_count
            if gold_count > 0 and pred_count > 0:
                correct_counts[i] += min(gold_count, pred_count)
                # Over-prediction is capped so one sentence adds at most 1.
                error_stats[i] += min(pred_count / gold_count, 1)
            elif gold_count == 0 and pred_count > 0:
                # Relation predicted for a sentence whose gold triples lack it.
                confusion_stats[i] += pred_count

        # Predicted relations outside the known label set.
        confusion_stats[n_relations] += other_count

    for name, gold_total in zip(relations, gold_counts):
        print(f"{name}: 真实个数 = {gold_total}")

    # Guard the denominator so relations absent from the gold data yield 0.
    correct_rates = [correct / (gold if gold != 0 else 1) for correct, gold in zip(correct_counts, gold_counts)]
    return error_stats, confusion_stats, gold_counts, pred_counts, correct_rates
def plot_results(relations, pred_error_stats, gold_error_stats, confusion_stats, correct_rates):
    """Print the correct-extraction rate of every relation.

    Despite the name, no figure is drawn: all plotting code below is
    commented out, so the only output is the printed rate table.

    Args:
        relations: Relation labels, one per printed row.
        pred_error_stats: Per-relation match scores of the predictions.
        gold_error_stats: Per-relation match scores of the gold data.
        confusion_stats: Per-relation confusion counts; only referenced by
            the disabled plotting code.
        correct_rates: Per-relation correct-extraction rates, indexed in the
            same order as ``relations``.
    """
    # Prediction/gold score ratio, capped at 1; a zero gold score is replaced
    # by 1 to avoid dividing by zero. Only the disabled code below uses it.
    ratios = []
    for pred_score, gold_score in zip(pred_error_stats, gold_error_stats):
        denominator = gold_score if gold_score != 0 else 1
        ratios.append(min(pred_score / denominator, 1))

    # Disabled: per-relation match scores and their ratio.
    # for i in range(len(relations)):
    #     # print(f"{relations[i]}: 真实得分 = {gold_error_stats[i]}, 预测得分 = {pred_error_stats[i]}, 比值 = {ratios[i]}")
    #     print(f"{relations[i]}: 匹配比值 = {ratios[i]}")

    # Disabled: overall match score and confusion total.
    # print(f"匹配分数:{sum(pred_error_stats)}")
    # print(f"混乱程度:{sum(confusion_stats)}")

    # Correct-extraction rate, two decimals per relation.
    print("\n正确提取率:")
    for idx, name in enumerate(relations):
        print(f"{name}: {correct_rates[idx]:.2f}")

    # Disabled: bar chart of the prediction/gold ratio.
    # x = np.arange(len(relations))
    # plt.figure(figsize=(12, 6))
    # plt.bar(x, ratios, color='skyblue')
    # plt.xticks(x, relations, rotation=45, ha='right')
    # plt.ylabel('预测/真实 比值')
    # plt.title('预测结果与真实结果的比值')
    # plt.tight_layout()
    # plt.show()

    # Disabled: bar chart of the confusion counts.
    # x_confusion = np.arange(len(confusion_stats))
    # plt.figure(figsize=(12, 6))
    # plt.bar(x_confusion, confusion_stats, color='salmon')
    # plt.xticks(x_confusion, relations + ['其他'], rotation=45, ha='right')
    # plt.ylabel('混乱程度')
    # plt.title('预测结果的混乱程度')
    # plt.tight_layout()
    # plt.show()

    # Disabled: bar chart of the correct-extraction rate.
    # x_correct = np.arange(len(correct_rates))
    # plt.figure(figsize=(12, 6))
    # plt.bar(x_correct, correct_rates, color='lightgreen')
    # plt.xticks(x_correct, relations, rotation=45, ha='right')
    # plt.ylabel('正确提取率')
    # plt.title('正确提取率')
    # plt.tight_layout()
    # plt.show()
# Relation labels handed to plot_results; the list repeats, in the same order,
# the one defined inside evaluate_triples.
relations = [
    "出露于", "位于", "整合接触", "不整合接触", "假整合接触", "断层接触",
    "分布形态", "大地构造位置", "地层区划", "出露地层", "岩性", "厚度",
    "面积", "坐标", "长度", "含有", "所属年代", "行政区划",
    "发育", "古生物", "海拔", "属于", "吞噬", "侵入",
]

# Example usage: score each prediction file against the gold file.
gold_path = './data/GT_500.json'

# Prediction files to evaluate. Only the gold file itself is enabled;
# uncomment an entry to evaluate that model's output.
model_paths = [
    './data/GT_500.json',
    # 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/gpt-3.5-turbo.json',
    # 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/gpt-4o.json',
    # 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/gemini-1.5-pro-002.json',
    # 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/claude-3-5-haiku-20241022.json',
    # 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/deepseek-ai/deepseek-V3.json',
    # 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/deepseek-ai/deepseek-R1.json',
    # 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/meta-llama/Meta-Llama-3.1-405B-Instruct.json',
    # 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/Qwen/Qwen2.5-72B-Instruct.json',
    # 'F:/GeoLLM/output/output_result/Task1/nomal/two_shot/deepseek-ai/deepseek-V3.json',
    # 'F:/GeoLLM/output/output_result/Task1/knn/three_shot/deepseek-ai/deepseek-R1.json',
    # 'F:/GeoLLM/output/output_result/Task1/Knowledge-guided/one_shot/deepseek-ai/deepseek-R1.json',
    # 'F:/GeoLLM/output/Knowledge-guided_rerun/one_shot/deepseek-ai/deepseek-V3.json',
    # 'F:/GeoLLM/output/Knowledge-guided_rerun/one_shot/gpt-3.5-turbo.json',
    # 'F:/GeoLLM/output/Knowledge-guided_rerun/one_shot/gpt-3.5-turbo_0407.json',
]

for model_path in model_paths:
    pred_path = model_path
    print(pred_path)

    # Disabled debugging output; it predates the five-value return of
    # evaluate_triples.
    # error_stats, confusion_stats, gold_counts, pred_counts = evaluate_triples(gold_path, pred_path)
    # for i in range(len(relations)):
    #     print(relations[i], gold_counts[i], pred_counts[i])
    # for i in range(len(relations)):
    #     print(relations[i], error_stats[i])
    # plot_results(relations, error_stats, gold_counts, confusion_stats)

    # Score the predictions against the gold triples.
    pred_error_stats, confusion_stats, _, _, correct_rates = evaluate_triples(
        gold_path, pred_path
    )

    # Score the gold file against itself to obtain the reference scores.
    # NOTE(review): this call does not depend on model_path, so it is repeated
    # (and its counts re-printed) for every model; consider hoisting it.
    gold_error_stats, _, _, _, _ = evaluate_triples(gold_path, gold_path)

    plot_results(
        relations,
        pred_error_stats,
        gold_error_stats,
        confusion_stats,
        correct_rates,
    )
| # # 预测样本 | |
| # model_paths = [ | |
| # # # gpt-3.5 | |
| # './data/GT_500.json', | |
| # # 'F:/GeoLLM/output/output_result/Task1/nomal/zero_shot/deepseek-ai/deepseek-V3.json', # 零样本 | |
| # # 'F:/GeoLLM/output/output_result/Task1/nomal/one_shot/gpt-3p5-turbo.json', # 单样本 | |
| # # 'F:/GeoLLM/output/output_result/Task1/nomal/two_shot/gpt-3p5-turbo.json', # 双样本 | |
| # # 'F:/GeoLLM/output/output_result/Task1/nomal/three_shot/gpt-3p5-turbo.json', # 三样本 | |
| # # 'F:/GeoLLM/output/output_result/Task1/knn/one_shot/gpt-3p5-turbo.json', # KNN单样本 | |
| # # 'F:/GeoLLM/output/output_result/Task1/knn/two_shot/gpt-3p5-turbo.json', # KNN双样本 | |
| # # 'F:/GeoLLM/output/output_result/Task1/knn/three_shot/gpt-3p5-turbo.json', # KNN三样本 | |
| # # 'F:/GeoLLM/output/Knowledge-guided_rerun/one_shot/deepseek-ai/deepseek-V3.json', # 知识引导单样本 | |
| # ] | |
| # for model_path in model_paths: | |
| # # 假设你想要选择第一个路径作为预测样本路径 | |
| # pred_path = model_path | |
| # # 使用示例 | |
| # error_stats, confusion_stats = evaluate_triples(gold_path, pred_path) | |
| # print("匹配统计结果:\n") | |
| # for i in range(len(relations)): | |
| # print(relations[i], error_stats[i]) | |
| # # 使用内置的 sum() 函数计算总和 | |
| # print("匹配分数:", sum(error_stats)) | |
| # print('--------------------------------') | |
| # print("误用统计:\n") | |
| # for i in range(len(relations)): | |
| # print(relations[i], confusion_stats[i]) | |
| # print("混乱程度总和:", sum(confusion_stats)) |