| | import argparse |
| | import json |
| | import logging |
| | import os |
| | import pprint |
| | from collections import Counter, defaultdict, namedtuple |
| | from dataclasses import dataclass |
| | from itertools import chain |
| | from typing import Any, Callable, Dict, List, Set, Tuple |
| |
|
| | import numpy as np |
| | import torch |
| | from BERT_rationale_benchmark.utils import (Annotation, Evidence, |
| | annotations_from_jsonl, |
| | load_documents, |
| | load_flattened_documents, |
| | load_jsonl) |
| | from scipy.stats import entropy |
| | from sklearn.metrics import (accuracy_score, auc, average_precision_score, |
| | classification_report, precision_recall_curve, |
| | roc_auc_score) |
| |
|
# Verbose root-logger setup: DEBUG level, each record prefixed with elapsed
# milliseconds since interpreter start and the emitting thread's name.
logging.basicConfig(
    level=logging.DEBUG, format="%(relativeCreated)6d %(threadName)s %(message)s"
)
| |
|
| |
|
| | |
@dataclass(eq=True, frozen=True)
class Rationale:
    """A single evidence span, the token interval [start_token, end_token),
    tied to one (annotation, document) pair. Frozen and hashable so spans
    can be collected into sets and compared directly."""

    ann_id: str
    docid: str
    start_token: int
    end_token: int

    def to_token_level(self) -> List["Rationale"]:
        """Explode this span into one single-token Rationale per position."""
        return [
            Rationale(self.ann_id, self.docid, tok, tok + 1)
            for tok in range(self.start_token, self.end_token)
        ]

    @classmethod
    def from_annotation(cls, ann: "Annotation") -> List["Rationale"]:
        """Collect one Rationale per evidence span of a gold annotation."""
        return [
            Rationale(ann.annotation_id, ev.docid, ev.start_token, ev.end_token)
            for ev_group in ann.evidences
            for ev in ev_group
        ]

    @classmethod
    def from_instance(cls, inst: dict) -> List["Rationale"]:
        """Collect one Rationale per hard rationale prediction of a model output."""
        return [
            Rationale(
                inst["annotation_id"],
                rat["docid"],
                pred["start_token"],
                pred["end_token"],
            )
            for rat in inst["rationales"]
            for pred in rat.get("hard_rationale_predictions", [])
        ]
| |
|
| |
|
@dataclass(eq=True, frozen=True)
class PositionScoredDocument:
    """Per-position soft scores paired with gold truth bits for one
    (annotation, document) pair."""

    ann_id: str
    docid: str
    scores: Tuple[float]
    truths: Tuple[bool]

    @classmethod
    def from_results(
        cls,
        instances: List[dict],
        annotations: List["Annotation"],
        docs: Dict[str, List[Any]],
        use_tokens: bool = True,
    ) -> List["PositionScoredDocument"]:
        """Creates a paired list of annotation ids/docids/predictions/truth values"""
        # Build a truth mask per (annotation_id, docid): True wherever any
        # gold evidence span covers that position.
        truth_masks = dict()
        for ann in annotations:
            for ev in chain.from_iterable(ann.evidences):
                key = (ann.annotation_id, ev.docid)
                if key not in truth_masks:
                    truth_masks[key] = [False] * len(docs[ev.docid])
                if use_tokens:
                    lo, hi = ev.start_token, ev.end_token
                else:
                    lo, hi = ev.start_sentence, ev.end_sentence
                for pos in range(lo, hi):
                    truth_masks[key][pos] = True
        field = (
            "soft_rationale_predictions" if use_tokens else "soft_sentence_predictions"
        )
        paired = []
        for inst in instances:
            for rat in inst["rationales"]:
                docid = rat["docid"]
                scores = rat[field]
                key = (inst["annotation_id"], docid)
                assert len(scores) == len(docs[docid])
                if key in truth_masks:
                    assert len(scores) == len(truth_masks[key])
                else:
                    # No gold evidence for this pair: all-negative truth.
                    truth_masks[key] = [False] * len(docs[docid])
                paired.append(
                    PositionScoredDocument(
                        inst["annotation_id"],
                        docid,
                        tuple(scores),
                        tuple(truth_masks[key]),
                    )
                )
        return paired
| |
|
| |
|
| | def _f1(_p, _r): |
| | if _p == 0 or _r == 0: |
| | return 0 |
| | return 2 * _p * _r / (_p + _r) |
| |
|
| |
|
| | def _keyed_rationale_from_list( |
| | rats: List[Rationale], |
| | ) -> Dict[Tuple[str, str], Rationale]: |
| | ret = defaultdict(set) |
| | for r in rats: |
| | ret[(r.ann_id, r.docid)].add(r) |
| | return ret |
| |
|
| |
|
def partial_match_score(
    truth: List[Rationale], pred: List[Rationale], thresholds: List[float]
) -> List[Dict[str, Any]]:
    """Computes a partial match F1

    Computes an instance-level (annotation) micro- and macro-averaged F1 score.
    True Positives are computed by using intersection-over-union and
    thresholding the resulting intersection-over-union fraction.

    Micro-average results are computed by ignoring instance level distinctions
    in the TP calculation (and recall, and precision, and finally the F1 of
    those numbers). Macro-average results are computed first by measuring
    instance (annotation + document) precisions and recalls, averaging those,
    and finally computing an F1 of the resulting average.
    """

    # Group gold and predicted spans by (annotation_id, docid).
    ann_to_rat = _keyed_rationale_from_list(truth)
    pred_to_rat = _keyed_rationale_from_list(pred)

    # Denominators: number of predicted / gold spans per instance key.
    num_classifications = {k: len(v) for k, v in pred_to_rat.items()}
    num_truth = {k: len(v) for k, v in ann_to_rat.items()}
    # For every predicted span, record its best IoU against any gold span
    # of the same instance key.
    ious = defaultdict(dict)
    for k in set(ann_to_rat.keys()) | set(pred_to_rat.keys()):
        for p in pred_to_rat.get(k, []):
            best_iou = 0.0
            for t in ann_to_rat.get(k, []):
                # IoU over the token-index sets of the two spans.
                num = len(
                    set(range(p.start_token, p.end_token))
                    & set(range(t.start_token, t.end_token))
                )
                denom = len(
                    set(range(p.start_token, p.end_token))
                    | set(range(t.start_token, t.end_token))
                )
                iou = 0 if denom == 0 else num / denom
                if iou > best_iou:
                    best_iou = iou
            ious[k][p] = best_iou
    scores = []
    for threshold in thresholds:
        # A predicted span counts as a true positive at this threshold when
        # its best IoU reaches the threshold.
        threshold_tps = dict()
        for k, vs in ious.items():
            threshold_tps[k] = sum(int(x >= threshold) for x in vs.values())
        micro_r = (
            sum(threshold_tps.values()) / sum(num_truth.values())
            if sum(num_truth.values()) > 0
            else 0
        )
        micro_p = (
            sum(threshold_tps.values()) / sum(num_classifications.values())
            if sum(num_classifications.values()) > 0
            else 0
        )
        # _f1 is symmetric in its arguments, so the (r, p) order is harmless.
        micro_f1 = _f1(micro_r, micro_p)
        # Per-instance recalls/precisions, then averaged (macro).
        macro_rs = list(
            threshold_tps.get(k, 0.0) / n if n > 0 else 0 for k, n in num_truth.items()
        )
        macro_ps = list(
            threshold_tps.get(k, 0.0) / n if n > 0 else 0
            for k, n in num_classifications.items()
        )
        macro_r = sum(macro_rs) / len(macro_rs) if len(macro_rs) > 0 else 0
        macro_p = sum(macro_ps) / len(macro_ps) if len(macro_ps) > 0 else 0
        macro_f1 = _f1(macro_r, macro_p)
        scores.append(
            {
                "threshold": threshold,
                "micro": {"p": micro_p, "r": micro_r, "f1": micro_f1},
                "macro": {"p": macro_p, "r": macro_r, "f1": macro_f1},
            }
        )
    return scores
| |
|
| |
|
def score_hard_rationale_predictions(
    truth: List["Rationale"], pred: List["Rationale"]
) -> Dict[str, Dict[str, float]]:
    """Computes instance (annotation)-level micro/macro averaged F1s.

    Args:
        truth: gold rationale spans.
        pred: predicted rationale spans.

    Returns:
        {"instance_micro": {...}, "instance_macro": {...}}, each inner dict
        carrying precision "p", recall "r", and "f1".
    """
    scores = dict()
    truth = set(truth)
    pred = set(pred)
    # Micro scores pool all spans across instances before computing P/R.
    # Guard the divisions so empty inputs yield 0 instead of raising
    # ZeroDivisionError (previously a crash).
    micro_prec = len(truth & pred) / len(pred) if pred else 0
    micro_rec = len(truth & pred) / len(truth) if truth else 0
    scores["instance_micro"] = {
        "p": micro_prec,
        "r": micro_rec,
        "f1": _f1(micro_prec, micro_rec),
    }

    # Macro scores average per-(annotation, document) precision/recall.
    ann_to_rat = _keyed_rationale_from_list(truth)
    pred_to_rat = _keyed_rationale_from_list(pred)
    instances_to_scores = dict()
    for k in set(ann_to_rat.keys()) | set(pred_to_rat.keys()):
        predicted = pred_to_rat.get(k, set())
        gold = ann_to_rat.get(k, set())
        instance_prec = len(gold & predicted) / len(predicted) if predicted else 0
        instance_rec = len(gold & predicted) / len(gold) if gold else 0
        instances_to_scores[k] = {
            "p": instance_prec,
            "r": instance_rec,
            "f1": _f1(instance_prec, instance_rec),
        }

    # Note: a leftover debug print of macro_f1 and the argsort of per-instance
    # F1s used to pollute stdout here; it has been removed.
    count = len(instances_to_scores)
    macro_prec = (
        sum(inst["p"] for inst in instances_to_scores.values()) / count if count else 0
    )
    macro_rec = (
        sum(inst["r"] for inst in instances_to_scores.values()) / count if count else 0
    )
    macro_f1 = (
        sum(inst["f1"] for inst in instances_to_scores.values()) / count if count else 0
    )

    scores["instance_macro"] = {
        "p": macro_prec,
        "r": macro_rec,
        "f1": macro_f1,
    }
    return scores
| |
|
| |
|
def _auprc(truth: Dict[Any, List[bool]], preds: Dict[Any, List[float]]) -> float:
    """Area under the precision-recall curve, macro-averaged over keys.

    Args:
        truth: per-key token-level gold labels.
        preds: per-key token-level soft scores; must cover every key in truth.

    Returns:
        Mean AUPRC across keys, or 0.0 when there are no predictions.
    """
    if len(preds) == 0:
        return 0.0
    # Bug fix: `truth.keys() and preds.keys()` merely evaluates to
    # `preds.keys()` (boolean `and` returns its second operand), so the old
    # assert compared the two lengths instead of checking coverage. `&`
    # computes the intended key intersection, asserting every truth key has
    # a prediction.
    assert len(truth.keys() & preds.keys()) == len(truth.keys())
    aucs = []
    for k, true in truth.items():
        pred = preds[k]
        true = [int(t) for t in true]
        precision, recall, _ = precision_recall_curve(true, pred)
        aucs.append(auc(recall, precision))
    return np.average(aucs)
| |
|
| |
|
| | def _score_aggregator( |
| | truth: Dict[Any, List[bool]], |
| | preds: Dict[Any, List[float]], |
| | score_function: Callable[[List[float], List[float]], float], |
| | discard_single_class_answers: bool, |
| | ) -> float: |
| | if len(preds) == 0: |
| | return 0.0 |
| | assert len(truth.keys() and preds.keys()) == len(truth.keys()) |
| | scores = [] |
| | for k, true in truth.items(): |
| | pred = preds[k] |
| | if (all(true) or all(not x for x in true)) and discard_single_class_answers: |
| | continue |
| | true = [int(t) for t in true] |
| | scores.append(score_function(true, pred)) |
| | return np.average(scores) |
| |
|
| |
|
def score_soft_tokens(paired_scores: List["PositionScoredDocument"]) -> Dict[str, float]:
    """Token-level soft-rationale metrics (AUPRC, AP, ROC-AUC) averaged per document."""
    truth = dict()
    pred = dict()
    for doc in paired_scores:
        key = (doc.ann_id, doc.docid)
        truth[key] = doc.truths
        pred[key] = doc.scores
    return {
        "auprc": _auprc(truth, pred),
        "average_precision": _score_aggregator(
            truth, pred, average_precision_score, True
        ),
        "roc_auc_score": _score_aggregator(truth, pred, roc_auc_score, True),
    }
| |
|
| |
|
| | def _instances_aopc( |
| | instances: List[dict], thresholds: List[float], key: str |
| | ) -> Tuple[float, List[float]]: |
| | dataset_scores = [] |
| | for inst in instances: |
| | kls = inst["classification"] |
| | beta_0 = inst["classification_scores"][kls] |
| | instance_scores = [] |
| | for score in filter( |
| | lambda x: x["threshold"] in thresholds, |
| | sorted(inst["thresholded_scores"], key=lambda x: x["threshold"]), |
| | ): |
| | beta_k = score[key][kls] |
| | delta = beta_0 - beta_k |
| | instance_scores.append(delta) |
| | assert len(instance_scores) == len(thresholds) |
| | dataset_scores.append(instance_scores) |
| | dataset_scores = np.array(dataset_scores) |
| | |
| | |
| | |
| | |
| | final_score = np.average(dataset_scores) |
| | position_scores = np.average(dataset_scores, axis=0).tolist() |
| |
|
| | return final_score, position_scores |
| |
|
| |
|
def compute_aopc_scores(instances: List[dict], aopc_thresholds: List[float]):
    """AOPC comprehensiveness/sufficiency aggregates over the given thresholds.

    When `aopc_thresholds` is None, the threshold grid is derived from the
    instances' own "thresholded_scores" entries.
    """
    if aopc_thresholds is None:
        observed = chain.from_iterable(
            [x["threshold"] for x in y["thresholded_scores"]] for y in instances
        )
        aopc_thresholds = sorted(set(observed))
    comp_score, comp_points = _instances_aopc(
        instances, aopc_thresholds, "comprehensiveness_classification_scores"
    )
    suff_score, suff_points = _instances_aopc(
        instances, aopc_thresholds, "sufficiency_classification_scores"
    )
    return (aopc_thresholds, comp_score, comp_points, suff_score, suff_points)
| |
|
| |
|
def score_classifications(
    instances: List[dict],
    annotations: List[Annotation],
    docs: Dict[str, List[str]],
    aopc_thresholds: List[float],
) -> Dict[str, float]:
    """Scores the final classification task plus faithfulness diagnostics.

    Computes accuracy and per-class P/R/F1 for the predicted labels, and,
    when the corresponding optional fields are present on the instances,
    comprehensiveness/sufficiency deltas, entropy/KL variants of those,
    and AOPC aggregates.
    """
    def compute_kl(cls_scores_, faith_scores_):
        # KL divergence between the perturbed and original class score
        # distributions, with values aligned on a shared key order.
        keys = list(cls_scores_.keys())
        cls_scores_ = [cls_scores_[k] for k in keys]
        faith_scores_ = [faith_scores_[k] for k in keys]
        return entropy(faith_scores_, cls_scores_)

    # Assign each label string an integer id for the sklearn metrics.
    # NOTE(review): the order comes from a set, so target_names ordering is
    # nondeterministic across runs — confirm this is acceptable downstream.
    labels = list(set(x.classification for x in annotations))
    label_to_int = {l: i for i, l in enumerate(labels)}
    key_to_instances = {inst["annotation_id"]: inst for inst in instances}
    truth = []
    predicted = []
    for ann in annotations:
        truth.append(label_to_int[ann.classification])
        inst = key_to_instances[ann.annotation_id]
        predicted.append(label_to_int[inst["classification"]])
    classification_scores = classification_report(
        truth, predicted, output_dict=True, target_names=labels, digits=3
    )
    accuracy = accuracy_score(truth, predicted)
    # Comprehensiveness: drop in the predicted class's score after removing
    # the rationale tokens.
    if "comprehensiveness_classification_scores" in instances[0]:
        comprehensiveness_scores = [
            x["classification_scores"][x["classification"]]
            - x["comprehensiveness_classification_scores"][x["classification"]]
            for x in instances
        ]
        comprehensiveness_score = np.average(comprehensiveness_scores)
    else:
        comprehensiveness_score = None
        comprehensiveness_scores = None

    # Sufficiency: drop in the predicted class's score when only the
    # rationale tokens are kept.
    if "sufficiency_classification_scores" in instances[0]:
        sufficiency_scores = [
            x["classification_scores"][x["classification"]]
            - x["sufficiency_classification_scores"][x["classification"]]
            for x in instances
        ]
        sufficiency_score = np.average(sufficiency_scores)
    else:
        sufficiency_score = None
        sufficiency_scores = None

    # Entropy/KL variants of comprehensiveness computed over the full class
    # distribution rather than just the predicted class.
    if "comprehensiveness_classification_scores" in instances[0]:
        comprehensiveness_entropies = [
            entropy(list(x["classification_scores"].values()))
            - entropy(list(x["comprehensiveness_classification_scores"].values()))
            for x in instances
        ]
        comprehensiveness_entropy = np.average(comprehensiveness_entropies)
        comprehensiveness_kl = np.average(
            list(
                compute_kl(
                    x["classification_scores"],
                    x["comprehensiveness_classification_scores"],
                )
                for x in instances
            )
        )
    else:
        comprehensiveness_entropies = None
        comprehensiveness_kl = None
        comprehensiveness_entropy = None

    # Same entropy/KL variants for sufficiency.
    if "sufficiency_classification_scores" in instances[0]:
        sufficiency_entropies = [
            entropy(list(x["classification_scores"].values()))
            - entropy(list(x["sufficiency_classification_scores"].values()))
            for x in instances
        ]
        sufficiency_entropy = np.average(sufficiency_entropies)
        sufficiency_kl = np.average(
            list(
                compute_kl(
                    x["classification_scores"], x["sufficiency_classification_scores"]
                )
                for x in instances
            )
        )
    else:
        sufficiency_entropies = None
        sufficiency_kl = None
        sufficiency_entropy = None

    # AOPC aggregates over the per-threshold perturbation scores.
    if "thresholded_scores" in instances[0]:
        (
            aopc_thresholds,
            aopc_comprehensiveness_score,
            aopc_comprehensiveness_points,
            aopc_sufficiency_score,
            aopc_sufficiency_points,
        ) = compute_aopc_scores(instances, aopc_thresholds)
    else:
        (
            aopc_thresholds,
            aopc_comprehensiveness_score,
            aopc_comprehensiveness_points,
            aopc_sufficiency_score,
            aopc_sufficiency_points,
        ) = (None, None, None, None, None)
    if "tokens_to_flip" in instances[0]:
        token_percentages = []
        for ann in annotations:
            # Fraction of all evidence-document tokens that had to be flipped.
            docids = set(ev.docid for ev in chain.from_iterable(ann.evidences))
            inst = key_to_instances[ann.annotation_id]
            tokens = inst["tokens_to_flip"]
            doc_lengths = sum(len(docs[d]) for d in docids)
            token_percentages.append(tokens / doc_lengths)
        token_percentages = np.average(token_percentages)
        # NOTE(review): token_percentages is computed here but never added
        # to the returned dict below — confirm whether it should be reported.
    else:
        token_percentages = None

    return {
        "accuracy": accuracy,
        "prf": classification_scores,
        "comprehensiveness": comprehensiveness_score,
        "sufficiency": sufficiency_score,
        "comprehensiveness_entropy": comprehensiveness_entropy,
        "comprehensiveness_kl": comprehensiveness_kl,
        "sufficiency_entropy": sufficiency_entropy,
        "sufficiency_kl": sufficiency_kl,
        "aopc_thresholds": aopc_thresholds,
        "comprehensiveness_aopc": aopc_comprehensiveness_score,
        "comprehensiveness_aopc_points": aopc_comprehensiveness_points,
        "sufficiency_aopc": aopc_sufficiency_score,
        "sufficiency_aopc_points": aopc_sufficiency_points,
    }
| |
|
| |
|
def verify_instance(instance: dict, docs: Dict[str, list], thresholds: Set[float]):
    """Validate a single prediction instance; all problems are logged.

    Checks that every referenced docid exists and is used at most once, that
    hard rationale spans are in-bounds and mutually non-overlapping, that
    soft predictions cover every token, and that the classification-related
    fields are well-typed and consistently present.

    Args:
        instance: one decoded results-file record.
        docs: docid -> token list for every preprocessed document.
        thresholds: expected threshold set for "thresholded_scores" entries
            (None when no instance carries thresholded scores).

    Returns:
        True if the instance failed validation, False otherwise.
    """
    error = False
    docids = []
    for rat in instance["rationales"]:
        docid = rat["docid"]
        # Bug fix: this used to read `docid not in docid`, which is always
        # False (a string always contains itself), so missing documents were
        # never reported.
        if docid not in docs:
            error = True
            logging.info(
                f'Error! For instance annotation={instance["annotation_id"]}, docid={docid} could not be found as a preprocessed document! Gave up on additional processing.'
            )
            continue
        # Bug fix: docids was never populated, so the uniqueness check below
        # could never trigger.
        docids.append(docid)
        doc_length = len(docs[docid])
        for h1 in rat.get("hard_rationale_predictions", []):
            # Hard spans within one document must not overlap one another.
            for h2 in rat.get("hard_rationale_predictions", []):
                if h1 == h2:
                    continue
                if (
                    len(
                        set(range(h1["start_token"], h1["end_token"]))
                        & set(range(h2["start_token"], h2["end_token"]))
                    )
                    > 0
                ):
                    logging.info(
                        f'Error! For instance annotation={instance["annotation_id"]}, docid={docid} {h1} and {h2} overlap!'
                    )
                    error = True
            if h1["start_token"] > doc_length:
                logging.info(
                    f'Error! For instance annotation={instance["annotation_id"]}, docid={docid} received an impossible tokenspan: {h1} for a document of length {doc_length}'
                )
                error = True
            if h1["end_token"] > doc_length:
                logging.info(
                    f'Error! For instance annotation={instance["annotation_id"]}, docid={docid} received an impossible tokenspan: {h1} for a document of length {doc_length}'
                )
                error = True
        # Soft predictions, when provided, must score every token.
        soft_rationale_predictions = rat.get("soft_rationale_predictions", [])
        if (
            len(soft_rationale_predictions) > 0
            and len(soft_rationale_predictions) != doc_length
        ):
            logging.info(
                f'Error! For instance annotation={instance["annotation_id"]}, docid={docid} expected classifications for {doc_length} tokens but have them for {len(soft_rationale_predictions)} tokens instead!'
            )
            error = True

    # Each document may be rationalized at most once per instance.
    docids = Counter(docids)
    for docid, count in docids.items():
        if count > 1:
            error = True
            # Bug fix: the message was missing its f-prefix, so the
            # placeholders were never interpolated.
            logging.info(
                f'Error! For instance annotation={instance["annotation_id"]}, docid={docid} appear {count} times, may only appear once!'
            )

    classification = instance.get("classification", "")
    if not isinstance(classification, str):
        logging.info(
            f'Error! For instance annotation={instance["annotation_id"]}, classification field {classification} is not a string!'
        )
        error = True
    classification_scores = instance.get("classification_scores", dict())
    if not isinstance(classification_scores, dict):
        logging.info(
            f'Error! For instance annotation={instance["annotation_id"]}, classification_scores field {classification_scores} is not a dict!'
        )
        error = True
    comprehensiveness_classification_scores = instance.get(
        "comprehensiveness_classification_scores", dict()
    )
    if not isinstance(comprehensiveness_classification_scores, dict):
        logging.info(
            f'Error! For instance annotation={instance["annotation_id"]}, comprehensiveness_classification_scores field {comprehensiveness_classification_scores} is not a dict!'
        )
        error = True
    sufficiency_classification_scores = instance.get(
        "sufficiency_classification_scores", dict()
    )
    if not isinstance(sufficiency_classification_scores, dict):
        logging.info(
            f'Error! For instance annotation={instance["annotation_id"]}, sufficiency_classification_scores field {sufficiency_classification_scores} is not a dict!'
        )
        error = True
    if ("classification" in instance) != ("classification_scores" in instance):
        logging.info(
            f'Error! For instance annotation={instance["annotation_id"]}, when providing a classification, you must also provide classification scores!'
        )
        error = True
    if ("comprehensiveness_classification_scores" in instance) and not (
        "classification" in instance
    ):
        # Bug fix: the message direction was reversed relative to the check.
        logging.info(
            f'Error! For instance annotation={instance["annotation_id"]}, when providing a comprehensiveness_classification_score, you must also provide a classification!'
        )
        error = True
    if ("sufficiency_classification_scores" in instance) and not (
        "classification_scores" in instance
    ):
        logging.info(
            f'Error! For instance annotation={instance["annotation_id"]}, when providing a sufficiency_classification_score, you must also provide a classification score!'
        )
        error = True
    if "thresholded_scores" in instance:
        instance_thresholds = set(
            x["threshold"] for x in instance["thresholded_scores"]
        )
        if instance_thresholds != thresholds:
            error = True
            # Bug fix: missing f-prefix.
            logging.info(
                f'Error: {instance["thresholded_scores"]} has thresholds that differ from previous thresholds: {thresholds}'
            )
        if (
            "comprehensiveness_classification_scores" not in instance
            or "sufficiency_classification_scores" not in instance
            or "classification" not in instance
            or "classification_scores" not in instance
        ):
            error = True
            # Bug fix: missing f-prefix.
            logging.info(
                f"Error: {instance} must have comprehensiveness_classification_scores, sufficiency_classification_scores, classification, and classification_scores defined when including thresholded scores"
            )
        if not all(
            "sufficiency_classification_scores" in x
            for x in instance["thresholded_scores"]
        ):
            error = True
            # Bug fix: missing f-prefix.
            logging.info(
                f"Error: {instance} must have sufficiency_classification_scores for every threshold"
            )
        if not all(
            "comprehensiveness_classification_scores" in x
            for x in instance["thresholded_scores"]
        ):
            error = True
            # Bug fix: missing f-prefix.
            logging.info(
                f"Error: {instance} must have comprehensiveness_classification_scores for every threshold"
            )
    return error
| |
|
| |
|
def verify_instances(instances: List[dict], docs: Dict[str, list]):
    """Validate a full result set; raises ValueError if anything is malformed.

    Ensures annotation ids are unique, each instance passes verify_instance,
    and every optional field is provided either for all instances or none.
    """
    annotation_ids = list(x["annotation_id"] for x in instances)
    key_counter = Counter(annotation_ids)
    multi_occurrence_annotation_ids = list(
        filter(lambda kv: kv[1] > 1, key_counter.items())
    )
    error = False
    if len(multi_occurrence_annotation_ids) > 0:
        error = True
        logging.info(
            f"Error in instances: {len(multi_occurrence_annotation_ids)} appear multiple times in the annotations file: {multi_occurrence_annotation_ids}"
        )
    failed_validation = set()
    instances_with_classification = list()
    instances_with_soft_rationale_predictions = list()
    instances_with_soft_sentence_predictions = list()
    instances_with_comprehensiveness_classifications = list()
    instances_with_sufficiency_classifications = list()
    instances_with_thresholded_scores = list()
    # All instances must share the threshold grid of the first instance.
    if "thresholded_scores" in instances[0]:
        thresholds = set(x["threshold"] for x in instances[0]["thresholded_scores"])
    else:
        thresholds = None
    for instance in instances:
        instance_error = verify_instance(instance, docs, thresholds)
        if instance_error:
            error = True
            failed_validation.add(instance["annotation_id"])
        if instance.get("classification", None) is not None:
            instances_with_classification.append(instance)
        if instance.get("comprehensiveness_classification_scores", None) is not None:
            instances_with_comprehensiveness_classifications.append(instance)
        if instance.get("sufficiency_classification_scores", None) is not None:
            instances_with_sufficiency_classifications.append(instance)
        has_soft_rationales = []
        has_soft_sentences = []
        for rat in instance["rationales"]:
            if rat.get("soft_rationale_predictions", None) is not None:
                has_soft_rationales.append(rat)
            if rat.get("soft_sentence_predictions", None) is not None:
                has_soft_sentences.append(rat)
        if len(has_soft_rationales) > 0:
            instances_with_soft_rationale_predictions.append(instance)
            if len(has_soft_rationales) != len(instance["rationales"]):
                error = True
                # Bug fix: this read instance["annotation"], a key that does
                # not exist, so a partial-coverage error raised KeyError.
                logging.info(
                    f'Error: instance {instance["annotation_id"]} has soft rationales for some but not all reported documents!'
                )
        if len(has_soft_sentences) > 0:
            instances_with_soft_sentence_predictions.append(instance)
            if len(has_soft_sentences) != len(instance["rationales"]):
                error = True
                # Bug fix: same instance["annotation"] KeyError as above.
                logging.info(
                    f'Error: instance {instance["annotation_id"]} has soft sentences for some but not all reported documents!'
                )
        if "thresholded_scores" in instance:
            instances_with_thresholded_scores.append(instance)
    logging.info(
        f"Error in instances: {len(failed_validation)} instances fail validation: {failed_validation}"
    )
    # Optional fields must be present for every instance or for none.
    all_or_none = [
        (instances_with_classification, "a classification"),
        (instances_with_soft_sentence_predictions, "a sentence prediction"),
        (instances_with_soft_rationale_predictions, "a soft rationale prediction"),
        (
            instances_with_comprehensiveness_classifications,
            "a comprehensiveness classification",
        ),
        (instances_with_sufficiency_classifications, "a sufficiency classification"),
        (instances_with_thresholded_scores, "thresholded scores"),
    ]
    for subset, description in all_or_none:
        if len(subset) != 0 and len(subset) != len(instances):
            error = True
            logging.info(
                f"Either all {len(instances)} must have {description} or none may, instead {len(subset)} do!"
            )
    if error:
        raise ValueError(
            "Some instances are invalid, please fix your formatting and try again"
        )
| |
|
| |
|
| | def _has_hard_predictions(results: List[dict]) -> bool: |
| | |
| | return ( |
| | "rationales" in results[0] |
| | and len(results[0]["rationales"]) > 0 |
| | and "hard_rationale_predictions" in results[0]["rationales"][0] |
| | and results[0]["rationales"][0]["hard_rationale_predictions"] is not None |
| | and len(results[0]["rationales"][0]["hard_rationale_predictions"]) > 0 |
| | ) |
| |
|
| |
|
| | def _has_soft_predictions(results: List[dict]) -> bool: |
| | |
| | return ( |
| | "rationales" in results[0] |
| | and len(results[0]["rationales"]) > 0 |
| | and "soft_rationale_predictions" in results[0]["rationales"][0] |
| | and results[0]["rationales"][0]["soft_rationale_predictions"] is not None |
| | ) |
| |
|
| |
|
| | def _has_soft_sentence_predictions(results: List[dict]) -> bool: |
| | |
| | return ( |
| | "rationales" in results[0] |
| | and len(results[0]["rationales"]) > 0 |
| | and "soft_sentence_predictions" in results[0]["rationales"][0] |
| | and results[0]["rationales"][0]["soft_sentence_predictions"] is not None |
| | ) |
| |
|
| |
|
| | def _has_classifications(results: List[dict]) -> bool: |
| | |
| | return "classification" in results[0] and results[0]["classification"] is not None |
| |
|
| |
|
| | def main(): |
| | parser = argparse.ArgumentParser( |
| | description="""Computes rationale and final class classification scores""", |
| | formatter_class=argparse.RawTextHelpFormatter, |
| | ) |
| | parser.add_argument( |
| | "--data_dir", |
| | dest="data_dir", |
| | required=True, |
| | help="Which directory contains a {train,val,test}.jsonl file?", |
| | ) |
| | parser.add_argument( |
| | "--split", |
| | dest="split", |
| | required=True, |
| | help="Which of {train,val,test} are we scoring on?", |
| | ) |
| | parser.add_argument( |
| | "--strict", |
| | dest="strict", |
| | required=False, |
| | action="store_true", |
| | default=False, |
| | help="Do we perform strict scoring?", |
| | ) |
| | parser.add_argument( |
| | "--results", |
| | dest="results", |
| | required=True, |
| | help="""Results File |
| | Contents are expected to be jsonl of: |
| | { |
| | "annotation_id": str, required |
| | # these classifications *must not* overlap |
| | "rationales": List[ |
| | { |
| | "docid": str, required |
| | "hard_rationale_predictions": List[{ |
| | "start_token": int, inclusive, required |
| | "end_token": int, exclusive, required |
| | }], optional, |
| | # token level classifications, a value must be provided per-token |
| | # in an ideal world, these correspond to the hard-decoding above. |
| | "soft_rationale_predictions": List[float], optional. |
| | # sentence level classifications, a value must be provided for every |
| | # sentence in each document, or not at all |
| | "soft_sentence_predictions": List[float], optional. |
| | } |
| | ], |
| | # the classification the model made for the overall classification task |
| | "classification": str, optional |
| | # A probability distribution output by the model. We require this to be normalized. |
| | "classification_scores": Dict[str, float], optional |
| | # The next two fields are measures for how faithful your model is (the |
| | # rationales it predicts are in some sense causal of the prediction), and |
| | # how sufficient they are. We approximate a measure for comprehensiveness by |
| | # asking that you remove the top k%% of tokens from your documents, |
| | # running your models again, and reporting the score distribution in the |
| | # "comprehensiveness_classification_scores" field. |
| | # We approximate a measure of sufficiency by asking exactly the converse |
| | # - that you provide model distributions on the removed k%% tokens. |
| | # 'k' is determined by human rationales, and is documented in our paper. |
| | # You should determine which of these tokens to remove based on some kind |
| | # of information about your model: gradient based, attention based, other |
| | # interpretability measures, etc. |
| | # scores per class having removed k%% of the data, where k is determined by human comprehensive rationales |
| | "comprehensiveness_classification_scores": Dict[str, float], optional |
| | # scores per class having access to only k%% of the data, where k is determined by human comprehensive rationales |
| | "sufficiency_classification_scores": Dict[str, float], optional |
| | # the number of tokens required to flip the prediction - see "Is Attention Interpretable" by Serrano and Smith. |
| | "tokens_to_flip": int, optional |
| | "thresholded_scores": List[{ |
| | "threshold": float, required, |
| | "comprehensiveness_classification_scores": like "classification_scores" |
| | "sufficiency_classification_scores": like "classification_scores" |
| | }], optional. if present, then "classification" and "classification_scores" must be present |
| | } |
| | When providing one of the optional fields, it must be provided for *every* instance. |
| | The classification, classification_score, and comprehensiveness_classification_scores |
| | must together be present for every instance or absent for every instance. |
| | """, |
| | ) |
| | parser.add_argument( |
| | "--iou_thresholds", |
| | dest="iou_thresholds", |
| | required=False, |
| | nargs="+", |
| | type=float, |
| | default=[0.5], |
| | help="""Thresholds for IOU scoring. |
| | |
| | These are used for "soft" or partial match scoring of rationale spans. |
| | A span is considered a match if the size of the intersection of the prediction |
| | and the annotation, divided by the union of the two spans, is larger than |
| | the IOU threshold. This score can be computed for arbitrary thresholds. |
| | """, |
| | ) |
| | parser.add_argument( |
| | "--score_file", |
| | dest="score_file", |
| | required=False, |
| | default=None, |
| | help="Where to write results?", |
| | ) |
| | parser.add_argument( |
| | "--aopc_thresholds", |
| | nargs="+", |
| | required=False, |
| | type=float, |
| | default=[0.01, 0.05, 0.1, 0.2, 0.5], |
| | help="Thresholds for AOPC Thresholds", |
| | ) |
| | args = parser.parse_args() |
| | results = load_jsonl(args.results) |
| | docids = set( |
| | chain.from_iterable( |
| | [rat["docid"] for rat in res["rationales"]] for res in results |
| | ) |
| | ) |
| | docs = load_flattened_documents(args.data_dir, docids) |
| | verify_instances(results, docs) |
| | |
| | annotations = annotations_from_jsonl( |
| | os.path.join(args.data_dir, args.split + ".jsonl") |
| | ) |
| | docids |= set( |
| | chain.from_iterable( |
| | (ev.docid for ev in chain.from_iterable(ann.evidences)) |
| | for ann in annotations |
| | ) |
| | ) |
| |
|
| | has_final_predictions = _has_classifications(results) |
| | scores = dict() |
| | if args.strict: |
| | if not args.iou_thresholds: |
| | raise ValueError( |
| | "--iou_thresholds must be provided when running strict scoring" |
| | ) |
| | if not has_final_predictions: |
| | raise ValueError( |
| | "We must have a 'classification', 'classification_score', and 'comprehensiveness_classification_score' field in order to perform scoring!" |
| | ) |
| | |
| | if _has_hard_predictions(results): |
| | truth = list( |
| | chain.from_iterable(Rationale.from_annotation(ann) for ann in annotations) |
| | ) |
| | pred = list( |
| | chain.from_iterable(Rationale.from_instance(inst) for inst in results) |
| | ) |
| | if args.iou_thresholds is not None: |
| | iou_scores = partial_match_score(truth, pred, args.iou_thresholds) |
| | scores["iou_scores"] = iou_scores |
| | |
| | rationale_level_prf = score_hard_rationale_predictions(truth, pred) |
| | scores["rationale_prf"] = rationale_level_prf |
| | token_level_truth = list( |
| | chain.from_iterable(rat.to_token_level() for rat in truth) |
| | ) |
| | token_level_pred = list( |
| | chain.from_iterable(rat.to_token_level() for rat in pred) |
| | ) |
| | token_level_prf = score_hard_rationale_predictions( |
| | token_level_truth, token_level_pred |
| | ) |
| | scores["token_prf"] = token_level_prf |
| | else: |
| | logging.info("No hard predictions detected, skipping rationale scoring") |
| |
|
| | if _has_soft_predictions(results): |
| | flattened_documents = load_flattened_documents(args.data_dir, docids) |
| | paired_scoring = PositionScoredDocument.from_results( |
| | results, annotations, flattened_documents, use_tokens=True |
| | ) |
| | token_scores = score_soft_tokens(paired_scoring) |
| | scores["token_soft_metrics"] = token_scores |
| | else: |
| | logging.info("No soft predictions detected, skipping rationale scoring") |
| |
|
| | if _has_soft_sentence_predictions(results): |
| | documents = load_documents(args.data_dir, docids) |
| | paired_scoring = PositionScoredDocument.from_results( |
| | results, annotations, documents, use_tokens=False |
| | ) |
| | sentence_scores = score_soft_tokens(paired_scoring) |
| | scores["sentence_soft_metrics"] = sentence_scores |
| | else: |
| | logging.info( |
| | "No sentence level predictions detected, skipping sentence-level diagnostic" |
| | ) |
| |
|
| | if has_final_predictions: |
| | flattened_documents = load_flattened_documents(args.data_dir, docids) |
| | class_results = score_classifications( |
| | results, annotations, flattened_documents, args.aopc_thresholds |
| | ) |
| | scores["classification_scores"] = class_results |
| | else: |
| | logging.info("No classification scores detected, skipping classification") |
| |
|
| | pprint.pprint(scores) |
| |
|
| | if args.score_file: |
| | with open(args.score_file, "w") as of: |
| | json.dump(scores, of, indent=4, sort_keys=True) |
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|