File size: 2,893 Bytes
3e72399
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""
Small helpers for attribution plotting.

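Functions:
- get_color_scale: returns a diverging color scale, as a list of Plotly-style
  [position, color] stops, suitable for positive/negative attributions
  (e.g., red ↔ blue).
- format_feature_label: collapses whitespace in feature text and truncates long
  labels with an ellipsis for axis/legend display.
- create_legend: builds a minimal legend configuration payload based on method/order.
- matplotlib_to_plotly: rasterizes a Matplotlib figure into a Plotly image figure.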
"""

from typing import List, Dict

import numpy as np
import plotly.graph_objects as go
# Diverging colour scales expressed as Plotly-style ``[position, colour]`` stops.
# Position 0.0 maps to the low end of the colour axis and 1.0 to the high end.
# The two scales have opposite polarity: "rdbu" runs red -> blue while
# "coolwarm" runs blue -> red. Both use a narrower step around the neutral
# midpoint (0.5) than in the tails.
_DEFAULT_SCALES: Dict[str, List] = {
    "rdbu": [
        [0.0, "#67001f"],
        [0.111, "#b2182b"],
        [0.222, "#d6604d"],
        [0.333, "#f4a582"],
        [0.444, "#fddbc7"],
        [0.5, "#f7f7f7"],
        [0.555, "#d1e5f0"],
        [0.666, "#92c5de"],
        [0.777, "#4393c3"],
        [0.888, "#2166ac"],
        [1.0, "#053061"],
    ],
    "coolwarm": [
        [0.0, "#3b4cc0"],
        [0.2, "#6282ea"],
        [0.4, "#9fbfff"],
        [0.5, "#d7d7d7"],
        [0.6, "#f7b799"],
        [0.8, "#ee6a24"],
        [1.0, "#b40426"],
    ],
}

# Attribution method names (lower-case) that are rendered with the "rdbu" scale.
_METHOD_SCALE_ALIASES = frozenset({"shapley", "banzhaf", "influence"})


def get_color_scale(method: str = "RdBu") -> List:
    """Return a diverging colour scale suitable for attribution visualization.

    Args:
        method: Either a scale name (``"RdBu"`` or ``"coolwarm"``, matched
            case-insensitively) or an attribution method name (``"shapley"``,
            ``"banzhaf"``, ``"influence"``), which maps to the ``"rdbu"``
            scale. An empty value or ``None``, as well as any unknown name,
            falls back to ``"rdbu"``.

    Returns:
        A list of ``[position, colour]`` stops. The list and each stop are
        fresh copies, so callers may modify the result without affecting the
        module's default scales or later calls.
    """
    key = method.lower() if method else "rdbu"
    if key in _METHOD_SCALE_ALIASES:
        key = "rdbu"

    scale = _DEFAULT_SCALES.get(key, _DEFAULT_SCALES["rdbu"])
    # Copy every stop: returning the shared list would let one caller's
    # mutation silently change the colour scale seen by all later callers.
    return [list(stop) for stop in scale]


def format_feature_label(text: str, max_length: int = 60) -> str:
    """Normalise feature text for axis/legend display.

    Runs of whitespace (spaces, tabs, newlines) are collapsed to single spaces.
    If the result is longer than ``max_length`` characters, it is cut and
    terminated with a single-character ellipsis ("…"). The returned label is
    then exactly ``max_length`` characters long.

    Args:
        text: Feature text. ``None`` is treated as an empty string. Other
            non-string values are converted with ``str()``.
        max_length: Maximum length of the returned label. Values ``<= 0``
            yield an empty string.

    Returns:
        The normalised, possibly truncated, label.
    """
    if text is None:
        text = ""
    elif not isinstance(text, str):
        text = str(text)

    text = " ".join(text.split())
    if len(text) <= max_length:
        return text
    if max_length <= 0:
        # Without this guard, text[: max_length - 1] would slice from the end
        # and return almost the whole string (e.g. for max_length == 0).
        return ""
    return text[: max_length - 1] + "…"

def create_legend(method: str, order: int = 1) -> Dict:
    """Build a Plotly legend layout dict titled after the method and order.

    The legend is laid out horizontally and anchored at its bottom-right
    corner just above the top-right of the plot area, with a transparent
    background.

    Args:
        method: Attribution method name, title-cased in the legend title.
            When empty or ``None``, only the order is shown.
        order: Interaction order appended as ``order-<n>``.

    Returns:
        A dict suitable for ``layout.legend``.
    """
    order_text = f"order-{order}"
    if method:
        title_text = f"{method.title()} {order_text}"
    else:
        title_text = order_text

    return dict(
        title=dict(text=title_text),
        orientation="h",
        yanchor="bottom",
        y=1.02,
        xanchor="right",
        x=1,
        bgcolor="rgba(0,0,0,0)",
    )


def matplotlib_to_plotly(fig, *, title: str | None = None, height: int | None = None) -> go.Figure:
    """Rasterize a Matplotlib figure and wrap it in a Plotly ``go.Image`` figure.

    The Matplotlib figure is closed via ``plt.close`` whether or not rendering
    succeeds, so converted figures are not left open in pyplot.

    Args:
        fig: Matplotlib figure whose canvas supports ``draw()`` and
            ``buffer_rgba()`` (as Agg-based canvases do).
        title: Optional Plotly title. When set, a larger top margin (50 px
            instead of 10 px) is reserved for it.
        height: Optional layout height in pixels passed to Plotly. ``None``
            leaves the height to Plotly.

    Returns:
        A Plotly figure showing the rendered RGB pixels with both axes hidden
        and zero left/right/bottom margins.
    """
    from matplotlib import pyplot as plt

    try:
        fig.canvas.draw()
        rgba = np.asarray(fig.canvas.buffer_rgba())
        # Drop the alpha channel and copy the pixels. The slice alone is a view
        # into the canvas' render buffer, which should not be relied on once
        # the figure has been closed.
        image = np.array(rgba[..., :3])
    finally:
        # Close even if drawing fails, so a failed conversion does not leak
        # an open figure.
        plt.close(fig)

    plotly_fig = go.Figure(go.Image(z=image))
    plotly_fig.update_xaxes(visible=False)
    plotly_fig.update_yaxes(visible=False)
    plotly_fig.update_layout(
        title=title,
        margin=dict(l=0, r=0, t=50 if title else 10, b=0),
        height=height,
    )
    return plotly_fig