from typing import Iterable class VisualizationDataRecord: r""" A data record for storing attribution relevant information """ __slots__ = [ "word_attributions", "pred_prob", "pred_class", "true_class", "attr_class", "attr_score", "raw_input_ids", "convergence_score", ] def __init__( self, word_attributions, pred_prob, pred_class, true_class, attr_class, attr_score, raw_input_ids, convergence_score, ) -> None: self.word_attributions = word_attributions self.pred_prob = pred_prob self.pred_class = pred_class self.true_class = true_class self.attr_class = attr_class self.attr_score = attr_score self.raw_input_ids = raw_input_ids self.convergence_score = convergence_score def _get_color(attr): # clip values to prevent CSS errors (Values should be from [-1,1]) attr = max(-1, min(1, attr)) if attr > 0: hue = 120 sat = 75 lig = 100 - int(50 * attr) else: hue = 0 sat = 75 lig = 100 - int(-40 * attr) return "hsl({}, {}%, {}%)".format(hue, sat, lig) def format_special_tokens(token): if token.startswith("<") and token.endswith(">"): return "#" + token.strip("<>") return token def format_word_importances(words, importances): if importances is None or len(importances) == 0: return "" assert len(words) <= len(importances) tags = [""] for word, importance in zip(words, importances[: len(words)]): word = format_special_tokens(word) color = _get_color(importance) unwrapped_tag = ' {word}\ '.format( color=color, word=word ) tags.append(unwrapped_tag) tags.append("") return "".join(tags) def visualize_text( datarecords: Iterable[VisualizationDataRecord], legend: bool = True ): dom = [""] rows = [ "" ] for datarecord in datarecords: rows.append( "".join( [ "", format_word_importances( datarecord.raw_input_ids, datarecord.word_attributions ), "", ] ) ) if legend: dom.append( '
' ) dom.append("Legend: ") for value, label in zip([-1, 0, 1], ["Natural Text", "Neutral", "Synthetic Text"]): dom.append( ' {label} '.format( value=_get_color(value), label=label ) ) dom.append("
") dom.append("".join(rows)) dom.append("
Word Importance
") dom = "".join(dom) # html = HTML("".join(dom)) return dom