Commit ba6a59b · 1 parent: 3360bee
Add main evaluation method

relation_extraction.py CHANGED (+77 −5)
@@ -15,6 +15,7 @@
 
 import evaluate
 import datasets
+import numpy as np
 
 
 # TODO: Add BibTeX citation
@@ -86,10 +87,81 @@ class relation_extraction(evaluate.Metric):
         # TODO: Download external resources if needed
         pass
 
-    def _compute(self,
+    def _compute(self, pred_relations, gt_relations, mode="strict", relation_types=[]):
         """Returns the scores"""
         # TODO: Compute the different scores of the module
-
-
-
-
+
+        assert mode in ["strict", "boundaries"]
+
+        # construct relation_types from ground truth if not given
+        if len(relation_types) == 0:
+            for triplets in gt_relations:
+                for triplet in triplets:
+                    relation = triplet["type"]
+                    if relation not in relation_types:
+                        relation_types.append(relation)
+
+        scores = {rel: {"tp": 0, "fp": 0, "fn": 0} for rel in relation_types + ["ALL"]}
+
+        # Count GT relations and Predicted relations
+        n_sents = len(gt_relations)
+        n_rels = sum([len([rel for rel in sent]) for sent in gt_relations])
+        n_found = sum([len([rel for rel in sent]) for sent in pred_relations])
+
+        # Count TP, FP and FN per type
+        for pred_sent, gt_sent in zip(pred_relations, gt_relations):
+            for rel_type in relation_types:
+                # strict mode takes argument types into account
+                if mode == "strict":
+                    pred_rels = {(rel["head"], rel["head_type"], rel["tail"], rel["tail_type"]) for rel in pred_sent if
+                                 rel["type"] == rel_type}
+                    gt_rels = {(rel["head"], rel["head_type"], rel["tail"], rel["tail_type"]) for rel in gt_sent if
+                               rel["type"] == rel_type}
+
+                # boundaries mode only takes argument spans into account
+                elif mode == "boundaries":
+                    pred_rels = {(rel["head"], rel["tail"]) for rel in pred_sent if rel["type"] == rel_type}
+                    gt_rels = {(rel["head"], rel["tail"]) for rel in gt_sent if rel["type"] == rel_type}
+
+                scores[rel_type]["tp"] += len(pred_rels & gt_rels)
+                scores[rel_type]["fp"] += len(pred_rels - gt_rels)
+                scores[rel_type]["fn"] += len(gt_rels - pred_rels)
+
+        # Compute per entity Precision / Recall / F1
+        for rel_type in scores.keys():
+            if scores[rel_type]["tp"]:
+                scores[rel_type]["p"] = 100 * scores[rel_type]["tp"] / (scores[rel_type]["fp"] + scores[rel_type]["tp"])
+                scores[rel_type]["r"] = 100 * scores[rel_type]["tp"] / (scores[rel_type]["fn"] + scores[rel_type]["tp"])
+            else:
+                scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0
+
+            if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0:
+                scores[rel_type]["f1"] = 2 * scores[rel_type]["p"] * scores[rel_type]["r"] / (
+                    scores[rel_type]["p"] + scores[rel_type]["r"])
+            else:
+                scores[rel_type]["f1"] = 0
+
+        # Compute micro F1 Scores
+        tp = sum([scores[rel_type]["tp"] for rel_type in relation_types])
+        fp = sum([scores[rel_type]["fp"] for rel_type in relation_types])
+        fn = sum([scores[rel_type]["fn"] for rel_type in relation_types])
+
+        if tp:
+            precision = 100 * tp / (tp + fp)
+            recall = 100 * tp / (tp + fn)
+            f1 = 2 * precision * recall / (precision + recall)
+
+        else:
+            precision, recall, f1 = 0, 0, 0
+
+        scores["ALL"]["p"] = precision
+        scores["ALL"]["r"] = recall
+        scores["ALL"]["f1"] = f1
+        scores["ALL"]["tp"] = tp
+        scores["ALL"]["fp"] = fp
+        scores["ALL"]["fn"] = fn
+
+        # Compute Macro F1 Scores
+        scores["ALL"]["Macro_f1"] = np.mean([scores[ent_type]["f1"] for ent_type in relation_types])
+        scores["ALL"]["Macro_p"] = np.mean([scores[ent_type]["p"] for ent_type in relation_types])
+        scores["ALL"]["Macro_r"] = np.mean([scores[ent_type]["r"] for ent_type in relation_types])
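
On the aggregation itself: the micro scores pool tp/fp/fn counts across relation types before applying the precision/recall/F1 formulas, while the macro scores average the per-type values, so rare types weigh equally. A small sketch of that difference, using hypothetical counts:

    # Hypothetical per-type counts: one frequent, well-predicted type
    # and one rare, poorly recalled type.
    counts = {"works_for": {"tp": 8, "fp": 2, "fn": 0},
              "based_in": {"tp": 1, "fp": 0, "fn": 9}}

    def prf(tp, fp, fn):
        # Same conventions as the metric: percentages, zero when undefined.
        p = 100 * tp / (tp + fp) if tp else 0
        r = 100 * tp / (tp + fn) if tp else 0
        f1 = 2 * p * r / (p + r) if p + r else 0
        return p, r, f1

    macro_f1 = sum(prf(**c)[2] for c in counts.values()) / len(counts)
    micro_f1 = prf(sum(c["tp"] for c in counts.values()),
                   sum(c["fp"] for c in counts.values()),
                   sum(c["fn"] for c in counts.values()))[2]
    print(round(micro_f1, 1), round(macro_f1, 1))  # 62.1 53.5

The rare type drags the macro average down while barely moving the pooled micro figure, which is exactly the trade-off the two aggregations are meant to expose.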