# realTextGPT / visualization.py
# Last change: commit b2a66a5 by asankaran ("updated visualization")
import html
from typing import Iterable
class VisualizationDataRecord:
    """Container for the attribution data needed to visualise one example.

    Every constructor argument is stored unchanged on the attribute of the
    same name, and ``__slots__`` restricts instances to exactly these
    attributes (no per-instance ``__dict__``). Of these fields, the rendering
    code in this module reads only ``raw_input_ids`` and
    ``word_attributions``.
    """

    __slots__ = [
        "word_attributions",
        "pred_prob",
        "pred_class",
        "true_class",
        "attr_class",
        "attr_score",
        "raw_input_ids",
        "convergence_score",
    ]

    def __init__(
        self,
        word_attributions,
        pred_prob,
        pred_class,
        true_class,
        attr_class,
        attr_score,
        raw_input_ids,
        convergence_score,
    ) -> None:
        # Store every argument verbatim, in declaration order.
        (
            self.word_attributions,
            self.pred_prob,
            self.pred_class,
            self.true_class,
            self.attr_class,
            self.attr_score,
            self.raw_input_ids,
            self.convergence_score,
        ) = (
            word_attributions,
            pred_prob,
            pred_class,
            true_class,
            attr_class,
            attr_score,
            raw_input_ids,
            convergence_score,
        )
def _get_color(attr):
# clip values to prevent CSS errors (Values should be from [-1,1])
attr = max(-1, min(1, attr))
if attr > 0:
hue = 120
sat = 75
lig = 100 - int(50 * attr)
else:
hue = 0
sat = 75
lig = 100 - int(-40 * attr)
return "hsl({}, {}%, {}%)".format(hue, sat, lig)
def format_special_tokens(token):
    """Display a bracketed special token such as ``<s>`` as ``#s``.

    Only tokens that both start with "<" and end with ">" are rewritten;
    anything else is returned unchanged. ``str.strip("<>")`` removes *all*
    leading and trailing angle brackets, so ``<<x>>`` also becomes ``#x``.
    """
    if not (token.startswith("<") and token.endswith(">")):
        return token
    return "#" + token.strip("<>")
def format_word_importances(words, importances):
    """Render tokens as one HTML table cell, each highlighted by its score.

    Args:
        words: Token strings to render, in order. Each is passed through
            ``format_special_tokens`` and then HTML-escaped, so characters
            such as "<", ">" and "&" in the text are shown literally.
        importances: Per-token attribution scores, indexable with ``len()``.
            Must contain at least as many entries as ``words``; any extra
            trailing scores are ignored.

    Returns:
        ``"<td>...</td>"`` containing one ``<mark>`` element per token,
        coloured by ``_get_color``; ``"<td></td>"`` when ``importances`` is
        ``None`` or empty.

    Raises:
        AssertionError: if ``importances`` is shorter than ``words``.
    """
    # len() instead of truthiness — presumably so array-like score containers
    # (numpy/torch) are accepted; confirm against callers before changing.
    if importances is None or len(importances) == 0:
        return "<td></td>"
    assert len(words) <= len(importances), (
        "expected at least one importance per word, got {} importances "
        "for {} words".format(len(importances), len(words))
    )
    tags = ["<td>"]
    # zip stops at the shorter input, so surplus importances are skipped.
    for word, importance in zip(words, importances):
        # Escape after special-token formatting so raw text cannot be parsed
        # as markup (e.g. a "<" in the input would otherwise open a tag).
        word = html.escape(format_special_tokens(word), quote=False)
        color = _get_color(importance)
        tags.append(
            '<mark style="background-color: {color}; opacity:1.0; '
            'line-height:1.75"><font color="black"> {word}'
            "</font></mark>".format(color=color, word=word)
        )
    tags.append("</td>")
    return "".join(tags)
def visualize_text(
    datarecords: Iterable["VisualizationDataRecord"], legend: bool = True
):
    """Render word-level attributions for a batch of records as HTML.

    Args:
        datarecords: Records to render; each contributes one table row built
            from its ``raw_input_ids`` and ``word_attributions``. Despite the
            name, ``raw_input_ids`` must hold token *strings*, since they are
            passed through ``format_special_tokens`` for display.
        legend: When true, emit a colour legend before the table mapping the
            scores -1 / 0 / 1 to "Natural Text" / "Neutral" / "Synthetic Text".

    Returns:
        An HTML fragment string: the optional legend ``<div>`` followed by a
        ``<table>`` with a "Word Importance" header row and one row per record.
    """
    dom = []
    if legend:
        # Keep the legend outside the table: a <div> is not valid table
        # content (browsers would hoist it above the table anyway).
        dom.append(
            '<div style="border-top: 1px solid; margin-top: 5px; '
            'padding-top: 5px; display: inline-block">'
        )
        dom.append("<b>Legend: </b>")
        for value, label in zip(
            [-1, 0, 1], ["Natural Text", "Neutral", "Synthetic Text"]
        ):
            dom.append(
                '<span style="display: inline-block; width: 10px; height: 10px; '
                "border: 1px solid; background-color: "
                '{value}"></span> {label} '.format(
                    value=_get_color(value), label=label
                )
            )
        dom.append("</div>")

    dom.append('<table style="width: 100%">')
    dom.append("<tr><th>Word Importance</th></tr>")
    for datarecord in datarecords:
        dom.append(
            "<tr>"
            + format_word_importances(
                datarecord.raw_input_ids, datarecord.word_attributions
            )
            + "</tr>"
        )
    dom.append("</table>")
    return "".join(dom)