from typing import Iterable class VisualizationDataRecord: r""" A data record for storing attribution relevant information """ __slots__ = [ "word_attributions", "pred_prob", "pred_class", "true_class", "attr_class", "attr_score", "raw_input_ids", "convergence_score", ] def __init__( self, word_attributions, pred_prob, pred_class, true_class, attr_class, attr_score, raw_input_ids, convergence_score, ) -> None: self.word_attributions = word_attributions self.pred_prob = pred_prob self.pred_class = pred_class self.true_class = true_class self.attr_class = attr_class self.attr_score = attr_score self.raw_input_ids = raw_input_ids self.convergence_score = convergence_score def _get_color(attr): # clip values to prevent CSS errors (Values should be from [-1,1]) attr = max(-1, min(1, attr)) if attr > 0: hue = 120 sat = 75 lig = 100 - int(50 * attr) else: hue = 0 sat = 75 lig = 100 - int(-40 * attr) return "hsl({}, {}%, {}%)".format(hue, sat, lig) def format_special_tokens(token): if token.startswith("<") and token.endswith(">"): return "#" + token.strip("<>") return token def format_word_importances(words, importances): if importances is None or len(importances) == 0: return "
| Word Importance | " ] for datarecord in datarecords: rows.append( "".join( [ "
|---|