Spaces:
Sleeping
Sleeping
from typing import Dict, Tuple
def classify_predictions(gold: dict, pred: dict, union=False) -> Tuple[int, int, int]:
    """Count true positives, false positives and false negatives for one example.

    Phrases are compared as sets, so a phrase repeated under the same tag is
    counted once.

    Args:
        gold: Mapping from tag to an iterable of gold phrases.
        pred: Mapping from tag to an iterable of predicted phrases.
        union: If True, disregard the tag of each phrase and compare the union
            of all gold phrases against the union of all predicted phrases.
            If False, compare phrases tag by tag and sum the counts over every
            tag present in either mapping.

    Returns:
        A tuple ``(n_tp, n_fp, n_fn)``.
    """
    if union:
        # A single comparison over all phrases; `or ()` lets a tag whose value
        # is None count as "no phrases" instead of raising TypeError.
        comparisons = [
            (
                {phrase for phrases in gold.values() for phrase in phrases or ()},
                {phrase for phrases in pred.values() for phrase in phrases or ()},
            )
        ]
    else:
        # One comparison per tag; a tag missing from one side is an empty set.
        comparisons = [
            (set(gold.get(tag) or ()), set(pred.get(tag) or ()))
            for tag in set(gold.keys()).union(pred.keys())
        ]

    n_tp = sum(len(gold_set & pred_set) for gold_set, pred_set in comparisons)
    n_fp = sum(len(pred_set - gold_set) for gold_set, pred_set in comparisons)
    n_fn = sum(len(gold_set - pred_set) for gold_set, pred_set in comparisons)
    return n_tp, n_fp, n_fn
def _precision_recall_f1(n_tp, n_fp, n_fn):
    """Return (precision, recall, f1), each rounded to 4 places, from raw counts."""
    n_predicted = n_tp + n_fp
    n_gold = n_tp + n_fn
    precision = n_tp / n_predicted if n_predicted > 0 else 0
    recall = n_tp / n_gold if n_gold > 0 else 0
    # F1 is taken straight from the counts, 2TP / (2TP + FP + FN), rather than
    # from the already-rounded precision and recall; rounding the operands
    # first can shift the 4th decimal or collapse a tiny F1 to exactly 0.
    f1_denominator = n_predicted + n_gold
    f1 = 2 * n_tp / f1_denominator if f1_denominator > 0 else 0
    return round(precision, 4), round(recall, 4), round(f1, 4)


def compute_metrics(running_time, pred_times, runtype, eval_metrics=None):
    """Build the metrics dictionary for a run.

    Args:
        running_time: Total running time; reported rounded to 4 decimals.
        pred_times: Sized collection of individual prediction times. Their
            mean is reported, or 0 when the collection is empty.
        runtype: Run mode. Quality metrics are only added when this equals
            ``"eval"``.
        eval_metrics: Optional 6-item sequence
            ``(n_tp, n_fp, n_fn, n_tp_union, n_fp_union, n_fn_union)``.

    Returns:
        A dict with ``avg_pred_response_time_per_sentence`` and ``total_time``.
        For an ``"eval"`` run with ``eval_metrics`` given it also contains
        ``precision``, ``recall``, ``f1``, ``union_precision``,
        ``union_recall`` and ``union_f1``, each rounded to 4 decimals and set
        to 0 when its denominator is zero.
    """
    metrics = {
        "avg_pred_response_time_per_sentence": (
            round(sum(pred_times) / len(pred_times), 4) if pred_times else 0
        ),
        "total_time": round(running_time, 4),
    }

    if runtype == "eval" and eval_metrics is not None:
        n_tp, n_fp, n_fn, n_tp_union, n_fp_union, n_fn_union = eval_metrics
        precision, recall, f1 = _precision_recall_f1(n_tp, n_fp, n_fn)
        union_precision, union_recall, union_f1 = _precision_recall_f1(
            n_tp_union, n_fp_union, n_fn_union
        )
        metrics.update(
            {
                "precision": precision,
                "recall": recall,
                "f1": f1,
                "union_precision": union_precision,
                "union_recall": union_recall,
                "union_f1": union_f1,
            }
        )

    return metrics