| from collections import defaultdict |
|
|
| import numpy as np |
| import torch |
| from seqeval.metrics.v1 import _prf_divide |
|
|
|
|
| def extract_tp_actual_correct(y_true, y_pred): |
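    """Count predicted, true-positive, and gold spans per entity type.

    Both inputs are flat lists of [type_name, (start, end), sample_idx] entries
    (see flatten_for_eval). Returns (pred_sum, tp_sum, true_sum, target_names),
    with the count arrays aligned to the sorted entity type names.
    """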
| entities_true = defaultdict(set) |
| entities_pred = defaultdict(set) |
|
|
| for type_name, (start, end), idx in y_true: |
| entities_true[type_name].add((start, end, idx)) |
| for type_name, (start, end), idx in y_pred: |
| entities_pred[type_name].add((start, end, idx)) |
|
|
| target_names = sorted(set(entities_true.keys()) | set(entities_pred.keys())) |
|
|
    tp_sum, pred_sum, true_sum = [], [], []
    for type_name in target_names:
        entities_true_type = entities_true.get(type_name, set())
        entities_pred_type = entities_pred.get(type_name, set())
        tp_sum.append(len(entities_true_type & entities_pred_type))
        pred_sum.append(len(entities_pred_type))
        true_sum.append(len(entities_true_type))
    tp_sum = np.array(tp_sum, dtype=np.int32)
    pred_sum = np.array(pred_sum, dtype=np.int32)
    true_sum = np.array(true_sum, dtype=np.int32)
|
|
| return pred_sum, tp_sum, true_sum, target_names |
|
|
|
|
| def flatten_for_eval(y_true, y_pred): |
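    """Flatten per-sample span lists into single lists, tagging each span with its sample index."""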
| all_true = [] |
| all_pred = [] |
|
|
| for i, (true, pred) in enumerate(zip(y_true, y_pred)): |
| all_true.extend([t + [i] for t in true]) |
| all_pred.extend([p + [i] for p in pred]) |
|
|
| return all_true, all_pred |
|
|
|
|
| def compute_prf(y_true, y_pred, average='micro'): |
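    """Compute precision, recall, and F1 over typed spans.

    Only micro averaging yields meaningful scalars as written, since the return
    statement takes the first element of each metric array.
    """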
| y_true, y_pred = flatten_for_eval(y_true, y_pred) |
|
|
| pred_sum, tp_sum, true_sum, target_names = extract_tp_actual_correct(y_true, y_pred) |
|
|
| if average == 'micro': |
| tp_sum = np.array([tp_sum.sum()]) |
| pred_sum = np.array([pred_sum.sum()]) |
| true_sum = np.array([true_sum.sum()]) |
|
|
| precision = _prf_divide( |
| numerator=tp_sum, |
| denominator=pred_sum, |
| metric='precision', |
| modifier='predicted', |
| average=average, |
| warn_for=('precision', 'recall', 'f-score'), |
| zero_division='warn' |
| ) |
|
|
| recall = _prf_divide( |
| numerator=tp_sum, |
| denominator=true_sum, |
| metric='recall', |
| modifier='true', |
| average=average, |
| warn_for=('precision', 'recall', 'f-score'), |
| zero_division='warn' |
| ) |
|
|
    # Avoid division by zero when precision + recall is 0 (f_score stays 0 in that case).
    denominator = precision + recall
| denominator[denominator == 0.] = 1 |
| f_score = 2 * (precision * recall) / denominator |
|
|
| return {'precision': precision[0], 'recall': recall[0], 'f_score': f_score[0]} |
|
|
|
|
| class Evaluator: |
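    """Span-level NER evaluator comparing gold and predicted (start, end, label) spans."""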
| def __init__(self, all_true, all_outs): |
| self.all_true = all_true |
| self.all_outs = all_outs |
|
|
| def get_entities_fr(self, ents): |
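        """Convert (start, end, label) triples into [label, (start, end)] entries."""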
| all_ents = [] |
| for s, e, lab in ents: |
| all_ents.append([lab, (s, e)]) |
| return all_ents |
|
|
| def transform_data(self): |
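        """Reformat gold and predicted spans into the layout expected by compute_prf."""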
| all_true_ent = [] |
| all_outs_ent = [] |
        for true_spans, pred_spans in zip(self.all_true, self.all_outs):
            all_true_ent.append(self.get_entities_fr(true_spans))
            all_outs_ent.append(self.get_entities_fr(pred_spans))
| return all_true_ent, all_outs_ent |
|
|
| @torch.no_grad() |
| def evaluate(self): |
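        """Run span-level evaluation and return a formatted summary string plus the micro F1."""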
| all_true_typed, all_outs_typed = self.transform_data() |
        metrics = compute_prf(all_true_typed, all_outs_typed)
        precision, recall, f1 = metrics['precision'], metrics['recall'], metrics['f_score']
| output_str = f"P: {precision:.2%}\tR: {recall:.2%}\tF1: {f1:.2%}\n" |
| return output_str, f1 |
|
|
|
|
| def is_nested(idx1, idx2): |
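    """Return True if one span fully contains the other (identical boundaries count as nested)."""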
| |
| return (idx1[0] <= idx2[0] and idx1[1] >= idx2[1]) or (idx2[0] <= idx1[0] and idx2[1] >= idx1[1]) |
|
|
|
|
| def has_overlapping(idx1, idx2): |
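    """Return True if the two spans overlap at all (conflict test for flat NER)."""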
    if idx1[:2] == idx2[:2]:
        return True
    # Spans conflict unless one ends before the other starts.
    return not (idx1[0] > idx2[1] or idx2[0] > idx1[1])
|
|
|
|
| def has_overlapping_nested(idx1, idx2): |
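    """Return True if two spans conflict under nested NER.

    Partial overlaps and identical boundaries conflict; properly nested or
    disjoint spans do not.
    """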
| |
    if idx1[:2] == idx2[:2]:
        return True
    # Disjoint or properly nested spans are allowed; only partial overlaps conflict.
    if ((idx1[0] > idx2[1] or idx2[0] > idx1[1]) or is_nested(idx1, idx2)) and idx1 != idx2:
        return False
    return True
|
|
|
|
def greedy_search(spans, flat_ner=True):
    """Greedily decode spans: keep higher-scoring candidates first, dropping any that conflict.

    Each candidate is expected to end with its score; the score is stripped from
    the returned spans.
    """
    # Flat NER forbids any overlap; nested NER only forbids partial (non-nested) overlaps.
    has_ov = has_overlapping if flat_ner else has_overlapping_nested

    new_list = []
    # Visit candidates in order of decreasing score (the last element of each span).
    for b in sorted(spans, key=lambda x: -x[-1]):
        # Keep a candidate only if it does not conflict with an already selected span.
        if not any(has_ov(b[:-1], new) for new in new_list):
            new_list.append(b[:-1])
    # Return the selected spans ordered by start position.
    return sorted(new_list, key=lambda x: x[0])
|
|
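# Minimal usage sketch (illustrative only): the span formats below, per-sample lists of
# (start, end, label) triples for Evaluator and (start, end, label, score) candidates for
# greedy_search, are inferred from the functions above; the labels and offsets are made up.
if __name__ == "__main__":
    gold = [
        [(0, 1, "PER"), (3, 5, "ORG")],
        [(2, 4, "LOC")],
    ]
    pred = [
        [(0, 1, "PER"), (3, 5, "LOC")],
        [(2, 4, "LOC")],
    ]
    summary, micro_f1 = Evaluator(gold, pred).evaluate()
    print(summary)    # micro-averaged P / R / F1 over the toy spans above
    print(micro_f1)   # the same F1 as a float, e.g. for model selection

    # Greedy flat decoding keeps the higher-scoring of any two overlapping candidates.
    candidates = [(0, 3, "PER", 0.9), (2, 5, "ORG", 0.8), (6, 7, "LOC", 0.7)]
    print(greedy_search(candidates, flat_ner=True))  # [(0, 3, 'PER'), (6, 7, 'LOC')]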