"""Evaluate predicted triples against gold triples.

Only texts that appear in the prediction file are evaluated; each one is
paired with the gold entry that has the same text.

Metrics referenced by this script:
    * Triple Match F1 -- exact-match quality of the triples (enabled).
    * BERT Score -- semantic similarity of the edges (enabled).
    * Graph Match Accuracy -- graph-structure match (currently disabled).
    * BLEU & ROUGE -- text-generation quality (currently disabled).
    * Graph Edit Distance (GED) -- graph-structure difference
      (currently disabled).
"""
import json

import numpy as np

from metrics.graph_matching import (
    get_triple_match_f1,
    get_graph_match_accuracy,
    get_bert_score,
    get_bleu_rouge,
    split_to_edges,
    get_tokens,
    get_ged,
)


def load_data(gold_path, pred_path):
    """Load gold and predicted triples and align them by their text.

    Both files are read as UTF-8 JSON. Each file is expected to hold a
    list of objects with a ``'text'`` key and a ``'triple_list'`` key.
    Predictions whose text has no counterpart in the gold file are
    skipped, so only texts present in the prediction file are evaluated.

    Args:
        gold_path: Path to the gold (reference) JSON file.
        pred_path: Path to the prediction JSON file.

    Returns:
        A ``(gold_graphs, pred_graphs)`` tuple of two equally long lists.
        Entry ``i`` of each list is the ``'triple_list'`` of the same text,
        in the order the texts appear in the prediction file.

    Raises:
        KeyError: If an item lacks the ``'text'`` key, or a matched item
            lacks ``'triple_list'``.
    """
    # Load the gold data.
    with open(gold_path, 'r', encoding='utf-8') as f:
        gold_data = json.load(f)

    # Load the predicted data.
    with open(pred_path, 'r', encoding='utf-8') as f:
        pred_data = json.load(f)

    # Index the gold items by text once, so each prediction is a single
    # dictionary lookup instead of a scan over the whole gold list.
    # ``setdefault`` keeps the FIRST gold item for a duplicated text, which
    # is the item a front-to-back linear search would have returned.
    # NOTE(review): assumes 'text' values are hashable (strings) -- confirm.
    gold_by_text = {}
    for gold_item in gold_data:
        gold_by_text.setdefault(gold_item['text'], gold_item)

    gold_graphs = []
    pred_graphs = []

    # Only evaluate triples whose text appears in the prediction data.
    for pred_item in pred_data:
        gold_item = gold_by_text.get(pred_item['text'])
        if gold_item is None:
            # No gold entry with this text: nothing to compare against.
            continue
        gold_graphs.append(gold_item['triple_list'])
        pred_graphs.append(pred_item['triple_list'])

    return gold_graphs, pred_graphs


def evaluate_triples(gold_graphs, pred_graphs):
    """Compute, print and return the enabled metrics for aligned graphs.

    Args:
        gold_graphs: List of gold triple lists, aligned with ``pred_graphs``.
        pred_graphs: List of predicted triple lists.

    Returns:
        A dict with two entries, ``'triple_match'`` and ``'bert_score'``,
        each a dict holding ``'precision'``, ``'recall'`` and ``'f1'``.
    """
    print("开始评估...")
    print("="*50)

    # 1. Triple Match F1
    precision, recall, f1 = get_triple_match_f1(gold_graphs, pred_graphs)
    print("Triple Match")
    print(f"精确率: {precision:.4f}, 召回率: {recall:.4f}, F1: {f1:.4f}")

    # 2. Graph Match Accuracy (disabled)
    # graph_acc = get_graph_match_accuracy(pred_graphs, gold_graphs)
    # print(f"图匹配准确率: {graph_acc:.10f}")

    # 3. BERT Score
    gold_edges = split_to_edges(gold_graphs)
    pred_edges = split_to_edges(pred_graphs)
    precisions_BS, recalls_BS, f1s_BS = get_bert_score(gold_edges, pred_edges)

    # Reduce each score array once and reuse the value for both the
    # printed report and the returned dict.
    bert_precision = precisions_BS.mean()
    bert_recall = recalls_BS.mean()
    bert_f1 = f1s_BS.mean()

    print("BERT Score:")
    print(f"- Precision: {bert_precision:.4f}")
    print(f"- Recall: {bert_recall:.4f}")
    print(f"- F1: {bert_f1:.4f}")

    # 4. BLEU & ROUGE (disabled)
    # gold_tokens, pred_tokens = get_tokens(gold_edges, pred_edges)
    # p_rouge, r_rouge, f1_rouge, p_bleu, r_bleu, f1_bleu = get_bleu_rouge(
    #     gold_tokens, pred_tokens, gold_edges, pred_edges
    # )
    # print(f"\nBLEU分数:")
    # print(f"- Precision: {p_bleu.mean():.4f}")
    # print(f"- Recall: {r_bleu.mean():.4f}")
    # print(f"- F1: {f1_bleu.mean():.4f}")
    # print(f"\nROUGE分数:")
    # print(f"- Precision: {p_rouge.mean():.4f}")
    # print(f"- Recall: {r_rouge.mean():.4f}")
    # print(f"- F1: {f1_rouge.mean():.4f}")

    # 5. Graph Edit Distance (GED) (disabled)
    # total_ged = 0
    # for gold, pred in zip(gold_graphs, pred_graphs):
    #     ged = get_ged(gold, pred)
    #     total_ged += ged
    # avg_ged = total_ged / len(gold_graphs)
    # print(f"\n平均图编辑距离: {avg_ged:.4f}")

    # Return all enabled metrics.
    return {
        'triple_match': {
            'precision': precision,
            'recall': recall,
            'f1': f1
        },
        # 'graph_acc': graph_acc,
        'bert_score': {
            'precision': bert_precision,
            'recall': bert_recall,
            'f1': bert_f1
        },
        # 'bleu': {
        #     'precision': p_bleu.mean(),
        #     'recall': r_bleu.mean(),
        #     'f1': f1_bleu.mean()
        # },
        # 'rouge': {
        #     'precision': p_rouge.mean(),
        #     'recall': r_rouge.mean(),
        #     'f1': f1_rouge.mean()
        # },
        # 'ged': avg_ged
    }


if __name__ == '__main__':
    # Input file paths.
    gold_path = './data/train_triples.json'
    pred_path = './output/gpt.json'

    # Load and align the data.
    gold_graphs, pred_graphs = load_data(gold_path, pred_path)

    # Evaluate and print the results.
    results = evaluate_triples(gold_graphs, pred_graphs)