File size: 5,325 Bytes
8797abf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""
Topological landscapes: simplex heatmaps and Sperner path visualization.
For 3 objectives, plots the 2D simplex (triangle) with label regions and the walk path.
"""
from __future__ import annotations

import warnings

import numpy as np

try:
    import matplotlib.pyplot as plt
    from matplotlib.colors import ListedColormap
    _HAS_MPL = True
except ImportError:
    _HAS_MPL = False


def _simplex_to_xy(weights: np.ndarray) -> tuple[float, float]:
    """Map barycentric (w0, w1, w2) to 2D triangle coords. w0+w1+w2=1."""
    w1, w2 = weights[1], weights[2]
    x = w1 + w2 * 0.5
    y = w2 * (3 ** 0.5) * 0.5
    return x, y


def _grid_3simplex(n: int = 30) -> tuple[np.ndarray, np.ndarray]:
    """Return (weights, xy) for a grid on the 3-simplex. weights shape (M, 3)."""
    pts = []
    for i in range(n + 1):
        for j in range(n + 1 - i):
            k = n - i - j
            w = np.array([i, j, k], dtype=float) / n
            pts.append(w)
    weights = np.array(pts)
    xy = np.array([_simplex_to_xy(w) for w in weights])
    return weights, xy


def _grid_triangles(n: int) -> np.ndarray:
    """
    Return the (n * n, 3) int array of vertex indices triangulating ``_grid_3simplex(n)``.

    Grid point (i, j) -- weights (i, j, n - i - j) / n -- is stored at flat
    index ``i * (n + 1) - i * (i - 1) // 2 + j`` because row i holds
    n + 1 - i points. Every cell contributes an "upward" triangle and, except
    for the last cell of each row, a "downward" one.
    """
    def flat(i: int, j: int) -> int:
        return i * (n + 1) - (i * (i - 1)) // 2 + j

    triangles = []
    for i in range(n):
        for j in range(n - i):
            triangles.append((flat(i, j), flat(i, j + 1), flat(i + 1, j)))
            if i + j + 2 <= n:
                triangles.append((flat(i, j + 1), flat(i + 1, j + 1), flat(i + 1, j)))
    # reshape keeps the (0, 3) shape when there are no triangles.
    return np.array(triangles, dtype=int).reshape(-1, 3)


def plot_simplex_heatmap(
    oracle_label_fn,
    n_grid: int = 40,
    ax=None,
    title: str = "Sperner labeling (winner per point)",
) -> "plt.Axes":
    """
    Plot the 2D simplex colored by the oracle label (0, 1, or 2) at each grid point.

    Args:
        oracle_label_fn: callable ``weights -> int`` where ``weights`` is an
            array of shape (3,) summing to 1; it is called once per grid point,
            i.e. (n_grid + 1) * (n_grid + 2) // 2 times. Expected to return
            0, 1 or 2 (the color scale is fixed to vmin=0, vmax=2).
        n_grid: number of subdivisions per simplex edge; must be >= 1.
        ax: Matplotlib axes to draw on; a new 6x5 figure is created when None.
        title: axes title.

    Returns:
        The axes that were drawn on.

    Raises:
        ImportError: matplotlib is not installed.
        ValueError: ``n_grid < 1``.
    """
    if not _HAS_MPL:
        raise ImportError("matplotlib is required for plotting. Install with: pip install matplotlib")
    if n_grid < 1:
        raise ValueError(f"n_grid must be >= 1, got {n_grid}")
    weights, xy = _grid_3simplex(n_grid)
    labels = np.array([oracle_label_fn(w) for w in weights], dtype=int)
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(6, 5))
    triangles = _grid_triangles(n_grid)
    # Each small triangle takes the largest label among its three corners, so a
    # fully-labelled (0/1/2) cell is drawn with the label-2 color.
    tri_labels = labels[triangles].max(axis=1)
    cmap = ListedColormap(["#1f77b4", "#ff7f0e", "#2ca02c"])
    # facecolors= states explicitly that the values are per triangle rather
    # than letting tripcolor infer it from the array length.
    ax.tripcolor(
        xy[:, 0], xy[:, 1], triangles,
        facecolors=tri_labels, cmap=cmap, shading="flat", vmin=0, vmax=2,
    )
    top = (3 ** 0.5) * 0.5
    # Triangle outline
    ax.plot([0, 1, 0.5, 0], [0, 0, top, 0], "k-", lw=1.5)
    ax.set_xlim(-0.05, 1.05)
    ax.set_ylim(-0.05, top + 0.05)
    ax.set_aspect("equal")
    ax.set_title(title)
    ax.set_xlabel("w1 →")
    ax.set_ylabel("w2 ↑")
    return ax


def plot_sperner_path(
    history: list | np.ndarray,
    ax=None,
    title: str = "Sperner walk path",
    simplex_heatmap_oracle=None,
    n_grid_heatmap: int = 25,
) -> "plt.Axes":
    """
    Draw the path of the Sperner walk on the 2D simplex (for 3 objectives only).

    Args:
        history: weight vectors visited by the walk, shape (T, 3), or a single
            vector of shape (3,). Rows are expected (not checked) to be
            non-negative and sum to 1.
        ax: Matplotlib axes to draw on; a new 6x5 figure is created when None.
        title: axes title.
        simplex_heatmap_oracle: optional callable ``weights -> label``. When
            given, the background is the label heatmap drawn by
            plot_simplex_heatmap; otherwise only the simplex outline is drawn.
        n_grid_heatmap: grid resolution passed to the heatmap.

    Returns:
        The axes that were drawn on.

    Raises:
        ImportError: matplotlib is not installed.
        ValueError: ``history`` is empty or is not shaped (T, 3) / (3,).
    """
    if not _HAS_MPL:
        raise ImportError("matplotlib is required for plotting. Install with: pip install matplotlib")
    history = np.asarray(history)
    # An empty history would otherwise surface as a misleading column-count
    # error, or as an IndexError when marking the start point.
    if history.size == 0:
        raise ValueError("plot_sperner_path requires a non-empty history")
    if history.ndim == 1:
        history = history.reshape(1, -1)
    # Checking ndim keeps scalars (no shape[1]) and higher-rank arrays (which
    # would be mapped row-by-row into nonsense) out of the plotting code.
    if history.ndim != 2 or history.shape[1] != 3:
        raise ValueError("plot_sperner_path requires 3 objectives (history columns = 3)")
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(6, 5))
    if simplex_heatmap_oracle is not None:
        plot_simplex_heatmap(simplex_heatmap_oracle, n_grid=n_grid_heatmap, ax=ax, title=title)
    else:
        top = (3 ** 0.5) * 0.5
        ax.plot([0, 1, 0.5, 0], [0, 0, top, 0], "k-", lw=1.5)
        ax.set_xlim(-0.05, 1.05)
        ax.set_ylim(-0.05, top + 0.05)
        ax.set_aspect("equal")
        ax.set_title(title)
    xy = np.array([_simplex_to_xy(w) for w in history])
    ax.plot(xy[:, 0], xy[:, 1], "r-", lw=2, alpha=0.9, label="path")
    ax.scatter(xy[0, 0], xy[0, 1], c="green", s=80, zorder=5, label="start")
    ax.scatter(xy[-1, 0], xy[-1, 1], c="red", s=80, zorder=5, label="end")
    ax.legend(loc="upper right")
    return ax


def plot_sperner_path_from_solver(solver) -> "plt.Axes | None":
    """
    Plot the last Sperner path stored on ``solver._path_history``.

    For 3 objectives only. ``_path_history`` is expected to be a list of
    weight arrays (presumably populated by ``solve()``); it is read with
    ``getattr``, so a solver without the attribute is treated like one with
    an empty history.

    Returns:
        The axes returned by plot_sperner_path, or None when no path history
        is available (a UserWarning is emitted in that case).

    Raises:
        ImportError: matplotlib is not installed.
        ValueError: the history is not a 2-D array with 3 columns.
    """
    if not _HAS_MPL:
        raise ImportError("matplotlib is required for plotting. Install with: pip install matplotlib")
    path = getattr(solver, "_path_history", None)
    if path is None or len(path) == 0:
        # A warning (rather than print) lets callers filter or escalate it.
        warnings.warn("No path history found. Did the solver run?", stacklevel=2)
        return None
    path = np.asarray(path)
    # ndim check first: a 1-D history has no shape[1] and would raise IndexError.
    if path.ndim != 2 or path.shape[1] != 3:
        raise ValueError("plot_sperner_path_from_solver supports only 3 objectives")
    return plot_sperner_path(path, title="Sperner walk (from solver)")