"""
Small helpers for attribution plotting.

Functions:
- get_color_scale: returns a diverging color scale (a list of
  ``[position, color]`` stops) suitable for positive/negative attributions
  (e.g., red ↔ blue).
- format_feature_label: collapses whitespace in and truncates long feature text
  for axis/legend display.
- create_legend: builds a minimal legend configuration payload based on method/order.
- matplotlib_to_plotly: rasterizes a Matplotlib figure into a Plotly image figure.
"""
|
|
| from typing import List, Dict |
|
|
| import numpy as np |
| import plotly.graph_objects as go |
# Named diverging color scales as ``[position, color]`` stops with positions
# increasing from 0.0 to 1.0 and a neutral color at the 0.5 midpoint.
# Stops are placed symmetrically about 0.5 so attributions of equal magnitude
# but opposite sign land on mirror-image positions.
_DEFAULT_SCALES: Dict[str, List] = {
    # Dark red (0.0) -> near-white (0.5) -> dark blue (1.0).
    "rdbu": [
        [0.0, "#67001f"],
        [0.111, "#b2182b"],
        [0.222, "#d6604d"],
        [0.333, "#f4a582"],
        [0.444, "#fddbc7"],
        [0.5, "#f7f7f7"],
        [0.556, "#d1e5f0"],
        [0.667, "#92c5de"],
        [0.778, "#4393c3"],
        [0.889, "#2166ac"],
        [1.0, "#053061"],
    ],
    # Blue (0.0) -> light gray (0.5) -> red (1.0); note this runs in the
    # opposite direction to "rdbu".
    "coolwarm": [
        [0.0, "#3b4cc0"],
        [0.2, "#6282ea"],
        [0.4, "#9fbfff"],
        [0.5, "#d7d7d7"],
        [0.6, "#f7b799"],
        [0.8, "#ee6a24"],
        [1.0, "#b40426"],
    ],
}


def get_color_scale(method: str = "RdBu") -> List:
    """Return a diverging color scale suitable for attribution visualization.

    Args:
        method: Either the name of a known scale (``"RdBu"`` or
            ``"coolwarm"``, case-insensitive, surrounding whitespace ignored)
            or the name of an attribution method (``"shapley"``,
            ``"banzhaf"``, ``"influence"``), which maps to the red/blue scale.
            Empty, ``None`` or unknown names fall back to the red/blue scale.

    Returns:
        A new list of ``[position, color]`` stops with increasing positions
        in ``[0, 1]`` (the nested-list form used for Plotly colorscales).
        Each call returns a fresh copy, so callers may modify the result
        without affecting the module-level defaults.
    """
    key = method.strip().lower() if method else "rdbu"
    if key in {"shapley", "banzhaf", "influence"}:
        key = "rdbu"

    scale = _DEFAULT_SCALES.get(key, _DEFAULT_SCALES["rdbu"])
    # Copy every stop: returning the stored lists would let one caller's
    # in-place edit leak into every later call.
    return [list(stop) for stop in scale]
|
|
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
def format_feature_label(text: str, max_length: int = 60) -> str:
    """Collapse whitespace in a feature label and truncate it for display.

    Runs of whitespace (spaces, tabs, newlines) are collapsed to single
    spaces and stripped from both ends. If the result is longer than
    ``max_length`` characters, it is cut and terminated with a single
    ellipsis character, so the returned label never exceeds ``max_length``.

    Args:
        text: Feature label to format. Non-string values (e.g. numeric
            feature ids) are converted with ``str()`` first.
        max_length: Maximum length of the returned label, ellipsis included.
            Values below 1 yield an empty string for non-empty input.

    Returns:
        The normalized, possibly truncated label.
    """
    normalized = " ".join(str(text).split())
    if len(normalized) <= max_length:
        return normalized
    if max_length < 1:
        # normalized[: max_length - 1] would slice from the end here
        # (e.g. [:-1]) and return a label longer than requested.
        return ""
    return normalized[: max_length - 1] + "…"
|
|
def create_legend(method: str, order: int = 1) -> Dict:
    """Build a legend layout dict titled with the attribution method and order.

    The title reads ``"<Method> order-<order>"`` (method passed through
    ``str.title``), or just ``"order-<order>"`` when ``method`` is empty/None.
    The legend is horizontal, anchored by its bottom edge at y=1.02 and its
    right edge at x=1, with a fully transparent background.

    Args:
        method: Attribution method name shown in the title; may be empty.
        order: Interaction order shown in the title.

    Returns:
        A new dict in the shape of a Plotly ``layout.legend`` configuration.
    """
    order_text = f"order-{order}"
    title_text = order_text
    if method:
        title_text = f"{method.title()} {order_text}"

    legend = dict(
        title=dict(text=title_text),
        orientation="h",
        yanchor="bottom",
        y=1.02,
        xanchor="right",
        x=1,
        bgcolor="rgba(0,0,0,0)",
    )
    return legend
|
|
|
|
def matplotlib_to_plotly(fig, *, title: str | None = None, height: int | None = None) -> go.Figure:
    """Rasterize a Matplotlib figure and wrap it in a Plotly ``Image`` figure.

    The Matplotlib figure is drawn, its RGBA canvas buffer is read, and the
    alpha channel is discarded. The figure is then closed with ``plt.close``.
    It is closed even if drawing or reading the buffer fails, so callers must
    not reuse ``fig`` afterwards.

    Args:
        fig: Matplotlib figure whose canvas provides ``buffer_rgba()``
            (presumably an Agg-based canvas — confirm for other backends).
        title: Optional Plotly layout title; when truthy, the top margin is
            enlarged from 10 to 50 px to make room for it.
        height: Optional Plotly layout height in pixels; ``None`` leaves the
            Plotly default.

    Returns:
        A ``go.Figure`` holding a single ``go.Image`` trace with both axes hidden.
    """
    from matplotlib import pyplot as plt

    try:
        fig.canvas.draw()
        rgba = np.asarray(fig.canvas.buffer_rgba())
        # Copy the RGB channels into their own contiguous array instead of
        # keeping a strided view into the canvas's RGBA buffer.
        image = np.ascontiguousarray(rgba[..., :3])
    finally:
        # Close even on failure so the figure is not left open in pyplot.
        plt.close(fig)

    plotly_fig = go.Figure(go.Image(z=image))
    plotly_fig.update_xaxes(visible=False)
    plotly_fig.update_yaxes(visible=False)
    plotly_fig.update_layout(
        title=title,
        margin=dict(l=0, r=0, t=50 if title else 10, b=0),
        height=height,
    )
    return plotly_fig
|
|