Qingpeng Kong
clean initial state
3e72399
"""
Small helpers for attribution plotting.

Functions:
- get_color_scale: returns a diverging Plotly color scale (a list of
  [position, color] stops) suitable for positive/negative attributions
  (e.g., red ↔ blue).
- format_feature_label: collapses whitespace and truncates long feature text
  for axis/legend display.
- create_legend: builds a minimal Plotly legend configuration dict titled by
  method and order.
- matplotlib_to_plotly: converts a Matplotlib figure into a Plotly figure that
  displays it as an image.
"""
from typing import List, Dict
import numpy as np
import plotly.graph_objects as go
_DEFAULT_SCALES: Dict[str, List] = {
"rdbu": [
[0.0, "#67001f"],
[0.111, "#b2182b"],
[0.222, "#d6604d"],
[0.333, "#f4a582"],
[0.444, "#fddbc7"],
[0.5, "#f7f7f7"],
[0.555, "#d1e5f0"],
[0.666, "#92c5de"],
[0.777, "#4393c3"],
[0.888, "#2166ac"],
[1.0, "#053061"],
],
"coolwarm": [
[0.0, "#3b4cc0"],
[0.2, "#6282ea"],
[0.4, "#9fbfff"],
[0.5, "#d7d7d7"],
[0.6, "#f7b799"],
[0.8, "#ee6a24"],
[1.0, "#b40426"],
],
}
def get_color_scale(method: str = "RdBu") -> List:
"""Return a color scale suitable for attribution visualization."""
if not method:
return _DEFAULT_SCALES["rdbu"]
key = method.lower()
if key in {"shapley", "banzhaf", "influence"}:
key = "rdbu"
return _DEFAULT_SCALES.get(key, _DEFAULT_SCALES["rdbu"])
# def format_feature_label(feature: str, max_length: int = 20) -> str:
# """Truncate long feature text for display."""
# feature = feature or ""
# if len(feature) <= max_length:
# return feature
# if max_length <= 3:
# return feature[:max_length]
# return f"{feature[:max_length - 3]}..."
def format_feature_label(text: str, max_length: int = 60) -> str:
text = " ".join(text.split())
if len(text) <= max_length:
return text
return text[: max_length - 1] + "…"
def create_legend(method: str, order: int = 1) -> Dict:
    """Build a Plotly legend layout dict titled with the method and order.

    The title reads ``"<Method> order-<order>"`` (method title-cased), or just
    ``"order-<order>"`` when *method* is empty or ``None``. The legend is laid
    out horizontally. It is anchored by its bottom-right corner at
    ``x=1, y=1.02`` and has a fully transparent background.
    """
    title_text = f"order-{order}"
    if method:
        title_text = f"{method.title()} {title_text}"
    return dict(
        title=dict(text=title_text),
        orientation="h",
        yanchor="bottom",
        y=1.02,
        xanchor="right",
        x=1,
        bgcolor="rgba(0,0,0,0)",
    )
def matplotlib_to_plotly(fig, *, title: str | None = None, height: int | None = None) -> go.Figure:
    """Convert a Matplotlib figure into a Plotly figure that shows it as an image.

    The Matplotlib figure is rendered to its RGBA pixel buffer and the alpha
    channel is dropped. The RGB pixels are then wrapped in a ``go.Image`` trace
    with both axes hidden. The Matplotlib figure is always closed via
    ``plt.close``, even when rendering fails.

    Args:
        fig: The Matplotlib figure to convert. It is closed by this call.
        title: Optional title for the Plotly figure. When set, the top margin
            is enlarged (50 instead of 10) to make room for it.
        height: Optional height passed to the Plotly layout.

    Returns:
        A new ``go.Figure`` containing a single image trace.
    """
    from matplotlib import pyplot as plt

    try:
        fig.canvas.draw()
        rgba = np.asarray(fig.canvas.buffer_rgba())
        # Copy the RGB channels so the image no longer aliases the canvas's
        # render buffer, which belongs to the figure being closed.
        image = np.array(rgba[..., :3])
    finally:
        # Close even if drawing fails, so the figure is not left open.
        plt.close(fig)

    plotly_fig = go.Figure(go.Image(z=image))
    plotly_fig.update_xaxes(visible=False)
    plotly_fig.update_yaxes(visible=False)
    plotly_fig.update_layout(
        title=title,
        margin=dict(l=0, r=0, t=50 if title else 10, b=0),
        height=height,
    )
    return plotly_fig