""" Small helpers for attribution plotting. Functions: - get_color_scale: returns a diverging color scale name or sequence suitable for positive/negative attributions (e.g., red ↔ blue). - format_feature_label: truncates long feature text for axis/legend display. - create_legend: builds a minimal legend configuration payload based on method/order. """ from typing import List, Dict import numpy as np import plotly.graph_objects as go _DEFAULT_SCALES: Dict[str, List] = { "rdbu": [ [0.0, "#67001f"], [0.111, "#b2182b"], [0.222, "#d6604d"], [0.333, "#f4a582"], [0.444, "#fddbc7"], [0.5, "#f7f7f7"], [0.555, "#d1e5f0"], [0.666, "#92c5de"], [0.777, "#4393c3"], [0.888, "#2166ac"], [1.0, "#053061"], ], "coolwarm": [ [0.0, "#3b4cc0"], [0.2, "#6282ea"], [0.4, "#9fbfff"], [0.5, "#d7d7d7"], [0.6, "#f7b799"], [0.8, "#ee6a24"], [1.0, "#b40426"], ], } def get_color_scale(method: str = "RdBu") -> List: """Return a color scale suitable for attribution visualization.""" if not method: return _DEFAULT_SCALES["rdbu"] key = method.lower() if key in {"shapley", "banzhaf", "influence"}: key = "rdbu" return _DEFAULT_SCALES.get(key, _DEFAULT_SCALES["rdbu"]) # def format_feature_label(feature: str, max_length: int = 20) -> str: # """Truncate long feature text for display.""" # feature = feature or "" # if len(feature) <= max_length: # return feature # if max_length <= 3: # return feature[:max_length] # return f"{feature[:max_length - 3]}..." def format_feature_label(text: str, max_length: int = 60) -> str: text = " ".join(text.split()) if len(text) <= max_length: return text return text[: max_length - 1] + "…" def create_legend(method: str, order: int = 1) -> Dict: """Create a legend configuration object.""" label = f"{method.title()} order-{order}" if method else f"order-{order}" return { "title": {"text": label}, "orientation": "h", "yanchor": "bottom", "y": 1.02, "xanchor": "right", "x": 1, "bgcolor": "rgba(0,0,0,0)", } def matplotlib_to_plotly(fig, *, title: str | None = None, height: int | None = None) -> go.Figure: """Convert a Matplotlib figure into a Plotly Image figure.""" from matplotlib import pyplot as plt fig.canvas.draw() buffer = np.asarray(fig.canvas.buffer_rgba()) image = buffer[..., :3] plt.close(fig) plotly_fig = go.Figure(go.Image(z=image)) plotly_fig.update_xaxes(visible=False) plotly_fig.update_yaxes(visible=False) plotly_fig.update_layout( title=title, margin=dict(l=0, r=0, t=50 if title else 10, b=0), height=height, ) return plotly_fig