File size: 4,288 Bytes
badcf3c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import json
import numpy as np
from metrics.graph_matching import (
get_triple_match_f1,
get_graph_match_accuracy,
get_bert_score,
get_bleu_rouge,
split_to_edges,
get_tokens,
get_ged
)
def load_data(gold_path, pred_path):
    """Load gold and predicted triples and align them by source text.

    Only texts that appear in the prediction file are evaluated. For each
    prediction item, in file order, the gold item with an identical ``'text'``
    value is looked up. Both items' ``'triple_list'`` values are then appended
    at the same index of the two output lists. Prediction items whose text has
    no gold counterpart are skipped. If several gold items share a text, the
    first one in the gold file is used.

    Args:
        gold_path: Path to a UTF-8 JSON file containing a list of objects with
            ``'text'`` and ``'triple_list'`` keys (reference annotations).
        pred_path: Path to a UTF-8 JSON file in the same format (model output).

    Returns:
        A ``(gold_graphs, pred_graphs)`` tuple of equally long lists.
        ``gold_graphs[i]`` and ``pred_graphs[i]`` belong to the same text.
    """
    with open(gold_path, 'r', encoding='utf-8') as f:
        gold_data = json.load(f)
    with open(pred_path, 'r', encoding='utf-8') as f:
        pred_data = json.load(f)

    # Index the gold items by text once, which is O(n + m), instead of
    # rescanning the whole gold list for every prediction, which is O(n * m).
    # setdefault keeps the FIRST gold item for a duplicated text, matching
    # the first-match-wins behaviour of a linear scan.
    gold_by_text = {}
    for gold_item in gold_data:
        gold_by_text.setdefault(gold_item['text'], gold_item)

    gold_graphs = []
    pred_graphs = []
    for pred_item in pred_data:
        gold_item = gold_by_text.get(pred_item['text'])
        if gold_item is None:
            # No reference annotation for this text, so it is not evaluated.
            continue
        gold_graphs.append(gold_item['triple_list'])
        pred_graphs.append(pred_item['triple_list'])
    return gold_graphs, pred_graphs
def _report_prf(title, precisions, recalls, f1s):
    """Print the mean precision/recall/F1 under *title* and return them as a dict."""
    # Each score container is reduced with .mean() exactly once, then reused
    # for both printing and the returned summary.
    summary = {
        'precision': precisions.mean(),
        'recall': recalls.mean(),
        'f1': f1s.mean(),
    }
    print(title)
    print(f"- Precision: {summary['precision']:.4f}")
    print(f"- Recall: {summary['recall']:.4f}")
    print(f"- F1: {summary['f1']:.4f}")
    return summary


def evaluate_triples(gold_graphs, pred_graphs, *,
                     metrics=('triple_match', 'bert_score')):
    """Compute, print and return evaluation metrics for predicted triples.

    ``gold_graphs`` and ``pred_graphs`` must be aligned index-by-index, as
    produced by ``load_data``: entry ``i`` of each list is the triple list
    for the same source text.

    Args:
        gold_graphs: Reference triple lists.
        pred_graphs: Predicted triple lists, in the same order.
        metrics: Keyword-only. Which metrics to compute, as any subset of:

            - ``'triple_match'``: exact triple-match precision/recall/F1.
            - ``'graph_acc'``: graph match accuracy.
            - ``'bert_score'``: BERTScore over the graphs' edges (semantic
              similarity).
            - ``'bleu_rouge'``: BLEU and ROUGE over the graphs' edges.
            - ``'ged'``: mean graph edit distance.

            Metrics are always computed and printed in the order listed
            above, regardless of the order given. The default reproduces the
            two metrics this script has always reported.

    Returns:
        A dict containing only the requested metrics:

        - ``'triple_match'``, ``'bert_score'``, ``'bleu'`` and ``'rouge'``
          map to ``{'precision', 'recall', 'f1'}`` dicts.
        - ``'graph_acc'`` and ``'ged'`` map to a single value.

    Raises:
        ValueError: If ``metrics`` names an unsupported metric.
    """
    supported = ('triple_match', 'graph_acc', 'bert_score', 'bleu_rouge', 'ged')
    requested = set(metrics)
    unknown = requested.difference(supported)
    if unknown:
        raise ValueError(
            f"Unsupported metric(s): {sorted(unknown)}; choose from {list(supported)}"
        )

    print("开始评估...")
    print("=" * 50)
    results = {}

    # 1. Triple Match F1
    if 'triple_match' in requested:
        precision, recall, f1 = get_triple_match_f1(gold_graphs, pred_graphs)
        print("Triple Match")
        print(f"精确率: {precision:.4f}, 召回率: {recall:.4f}, F1: {f1:.4f}")
        results['triple_match'] = {
            'precision': precision,
            'recall': recall,
            'f1': f1,
        }

    # 2. Graph Match Accuracy
    if 'graph_acc' in requested:
        # The (pred, gold) argument order is carried over from the previously
        # commented-out call. NOTE(review): confirm it against
        # get_graph_match_accuracy's signature.
        graph_acc = get_graph_match_accuracy(pred_graphs, gold_graphs)
        print(f"图匹配准确率: {graph_acc:.10f}")
        results['graph_acc'] = graph_acc

    # BERTScore and BLEU/ROUGE both compare the graphs edge by edge. Split
    # them only if one of those metrics was requested.
    if requested & {'bert_score', 'bleu_rouge'}:
        gold_edges = split_to_edges(gold_graphs)
        pred_edges = split_to_edges(pred_graphs)

    # 3. BERT Score
    if 'bert_score' in requested:
        p_bs, r_bs, f1_bs = get_bert_score(gold_edges, pred_edges)
        results['bert_score'] = _report_prf("BERT Score:", p_bs, r_bs, f1_bs)

    # 4. BLEU & ROUGE
    if 'bleu_rouge' in requested:
        gold_tokens, pred_tokens = get_tokens(gold_edges, pred_edges)
        p_rouge, r_rouge, f1_rouge, p_bleu, r_bleu, f1_bleu = get_bleu_rouge(
            gold_tokens, pred_tokens, gold_edges, pred_edges
        )
        results['bleu'] = _report_prf("\nBLEU分数:", p_bleu, r_bleu, f1_bleu)
        results['rouge'] = _report_prf("\nROUGE分数:", p_rouge, r_rouge, f1_rouge)

    # 5. Graph edit distance (GED)
    if 'ged' in requested:
        pairs = list(zip(gold_graphs, pred_graphs))
        total_ged = sum(get_ged(gold, pred) for gold, pred in pairs)
        # Average over the pairs actually compared. For empty input, return
        # NaN instead of raising ZeroDivisionError.
        avg_ged = total_ged / len(pairs) if pairs else float('nan')
        print(f"\n平均图编辑距离: {avg_ged:.4f}")
        results['ged'] = avg_ged

    return results
if __name__ == '__main__':
    import argparse

    # The defaults are the paths this script always used, so running it
    # without arguments behaves exactly as before.
    parser = argparse.ArgumentParser(
        description='Evaluate predicted triples against gold annotations.'
    )
    parser.add_argument(
        '--gold',
        default='./data/train_triples.json',
        help='gold-annotation JSON file (default: %(default)s)',
    )
    parser.add_argument(
        '--pred',
        default='./output/gpt.json',
        help='prediction JSON file (default: %(default)s)',
    )
    args = parser.parse_args()

    # Load the data
    gold_graphs, pred_graphs = load_data(args.gold, args.pred)
    # Evaluate and print the results
    results = evaluate_triples(gold_graphs, pred_graphs)