# GeoLLM/output/test2/eval.py
# Pengfa Li
# Upload folder using huggingface_hub
# badcf3c verified
import json
import numpy as np
from metrics.graph_matching import (
get_triple_match_f1,
get_graph_match_accuracy,
get_bert_score,
get_bleu_rouge,
split_to_edges,
get_tokens,
get_ged
)
def load_data(gold_path, pred_path):
    """Load gold and predicted triples and pair them up by source text.

    Both files are read as UTF-8 JSON and iterated as sequences of items,
    each carrying a ``'text'`` key and a ``'triple_list'`` key. Only texts
    that occur in the prediction file are evaluated: every prediction whose
    text also appears in the gold file contributes one aligned pair, and
    predictions without a gold counterpart are skipped silently.

    Args:
        gold_path: Path to the JSON file with the reference items.
        pred_path: Path to the JSON file with the predicted items.

    Returns:
        A ``(gold_graphs, pred_graphs)`` tuple of two equally long lists.
        ``gold_graphs[i]`` and ``pred_graphs[i]`` are the ``'triple_list'``
        values of the gold and predicted item that share the same text, in
        the order of the prediction file.

    Raises:
        OSError: If either file cannot be opened.
        json.JSONDecodeError: If either file is not valid JSON.
        KeyError: If an item lacks the ``'text'`` key, or a matched item
            lacks ``'triple_list'``.
    """
    with open(gold_path, 'r', encoding='utf-8') as f:
        gold_data = json.load(f)
    with open(pred_path, 'r', encoding='utf-8') as f:
        pred_data = json.load(f)

    # Index the gold items by text once instead of rescanning the whole gold
    # list for every prediction. setdefault keeps the FIRST item per text,
    # which is the one a front-to-back linear search would have returned.
    # NOTE: texts are used as dict keys, so they must be hashable
    # (presumably plain strings).
    gold_by_text = {}
    for gold_item in gold_data:
        gold_by_text.setdefault(gold_item['text'], gold_item)

    gold_graphs = []
    pred_graphs = []
    for pred_item in pred_data:
        gold_item = gold_by_text.get(pred_item['text'])
        if gold_item is None:
            # No reference for this text: leave it out of the evaluation.
            continue
        gold_graphs.append(gold_item['triple_list'])
        pred_graphs.append(pred_item['triple_list'])
    return gold_graphs, pred_graphs
def evaluate_triples(gold_graphs, pred_graphs):
    """Score predicted triple graphs against gold graphs and print a report.

    Two metrics are computed and printed:

    * Triple Match: precision, recall and F1 from ``get_triple_match_f1``.
    * BERT Score: both graph lists are passed through ``split_to_edges`` and
      handed to ``get_bert_score``; each of the three returned score
      collections is reduced with ``.mean()``.

    Graph-match accuracy, BLEU/ROUGE and graph edit distance are not
    computed by this function and are absent from the result.

    Args:
        gold_graphs: Reference triple lists.
        pred_graphs: Predicted triple lists, expected to be index-aligned
            with ``gold_graphs`` (as produced by ``load_data``).

    Returns:
        A dict with the entries ``'triple_match'`` and ``'bert_score'``,
        each a dict with the keys ``'precision'``, ``'recall'`` and ``'f1'``.
    """
    separator = "=" * 50
    print("开始评估...")
    print(separator)

    # Triple-level matching.
    tm_precision, tm_recall, tm_f1 = get_triple_match_f1(gold_graphs, pred_graphs)
    print("Triple Match")
    print(f"精确率: {tm_precision:.4f}, 召回率: {tm_recall:.4f}, F1: {tm_f1:.4f}")

    # BERT Score is computed on the edge representation of the graphs.
    edges_gold = split_to_edges(gold_graphs)
    edges_pred = split_to_edges(pred_graphs)
    bs_precisions, bs_recalls, bs_f1s = get_bert_score(edges_gold, edges_pred)
    bert_score = {
        'precision': bs_precisions.mean(),
        'recall': bs_recalls.mean(),
        'f1': bs_f1s.mean(),
    }
    print("BERT Score:")
    for label, key in (("Precision", 'precision'), ("Recall", 'recall'), ("F1", 'f1')):
        print(f"- {label}: {bert_score[key]:.4f}")

    # NOTE: graph-match accuracy, BLEU/ROUGE and graph edit distance are
    # disabled here; their helpers are still imported at the top of the file.
    return {
        'triple_match': {
            'precision': tm_precision,
            'recall': tm_recall,
            'f1': tm_f1,
        },
        'bert_score': bert_score,
    }
if __name__ == '__main__':
    # Script entry point: pair the predictions with the gold triples that
    # share their text, then run the evaluation, which prints its own report.
    gold_graphs, pred_graphs = load_data(
        gold_path='./data/train_triples.json',
        pred_path='./output/gpt.json',
    )
    results = evaluate_triples(gold_graphs, pred_graphs)