File size: 5,325 Bytes
8797abf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
"""
Topological landscapes: simplex heatmaps and Sperner path visualization.
For 3 objectives, plots the 2D simplex (triangle) with label regions and the walk path.
"""
from __future__ import annotations

import warnings

import numpy as np

try:
    import matplotlib.pyplot as plt
    from matplotlib.colors import ListedColormap

    _HAS_MPL = True
except ImportError:
    _HAS_MPL = False
def _simplex_to_xy(weights: np.ndarray) -> tuple[float, float]:
"""Map barycentric (w0, w1, w2) to 2D triangle coords. w0+w1+w2=1."""
w1, w2 = weights[1], weights[2]
x = w1 + w2 * 0.5
y = w2 * (3 ** 0.5) * 0.5
return x, y
def _grid_3simplex(n: int = 30) -> tuple[np.ndarray, np.ndarray]:
"""Return (weights, xy) for a grid on the 3-simplex. weights shape (M, 3)."""
pts = []
for i in range(n + 1):
for j in range(n + 1 - i):
k = n - i - j
w = np.array([i, j, k], dtype=float) / n
pts.append(w)
weights = np.array(pts)
xy = np.array([_simplex_to_xy(w) for w in weights])
return weights, xy
def _lattice_triangles(n: int) -> np.ndarray:
    """Return vertex indices of the n*n small triangles tiling the simplex lattice.

    Indices refer to rows of the arrays returned by ``_grid_3simplex(n)``, whose
    points are ordered by i (outer) then j (inner), so row i holds n + 1 - i points.
    """

    def idx(i: int, j: int) -> int:
        # Row i starts after rows 0..i-1, which hold sum_{r<i} (n + 1 - r)
        # = i*(n + 1) - i*(i - 1)/2 points.
        return i * (n + 1) - (i * (i - 1)) // 2 + j

    triangles = []
    for i in range(n):
        for j in range(n - i):
            # "Upward" cell: (i, j+1) and (i+1, j) always exist because j <= n - i - 1.
            triangles.append([idx(i, j), idx(i, j + 1), idx(i + 1, j)])
            # "Downward" cell needs (i+1, j+1), which exists only when i + j + 2 <= n.
            if i + j + 2 <= n:
                triangles.append([idx(i, j + 1), idx(i + 1, j + 1), idx(i + 1, j)])
    return np.array(triangles, dtype=int)


def plot_simplex_heatmap(
    oracle_label_fn,
    n_grid: int = 40,
    ax=None,
    title: str = "Sperner labeling (winner per point)",
) -> "plt.Axes":
    """
    Plot the 2D simplex colored by the oracle label (0, 1, or 2).

    ``oracle_label_fn`` is evaluated at every point of the ``_grid_3simplex(n_grid)``
    lattice. Each of the n_grid**2 small lattice triangles is then filled with
    the color of the MAXIMUM label among its three vertices. As a result,
    mixed-label cells (including fully labeled Sperner cells) take the color of
    their highest label.

    Args:
        oracle_label_fn: callable ``weights -> int``, where ``weights`` is an
            array of shape (3,) summing to 1; labels are expected in {0, 1, 2}.
        n_grid: lattice subdivisions per simplex edge; must be >= 1.
        ax: matplotlib Axes to draw on; a new 6x5 figure is created if None.
        title: axes title.

    Returns:
        The Axes that was drawn on.

    Raises:
        ImportError: if matplotlib is not installed.
        ValueError: if n_grid < 1.
    """
    if not _HAS_MPL:
        raise ImportError("matplotlib is required for plotting. Install with: pip install matplotlib")
    if n_grid < 1:
        raise ValueError(f"n_grid must be >= 1, got {n_grid}")
    weights, xy = _grid_3simplex(n_grid)
    labels = np.array([oracle_label_fn(w) for w in weights], dtype=int)
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(6, 5))
    triangles = _lattice_triangles(n_grid)
    # shading="flat" takes one color value per triangle; max() is an arbitrary
    # but deterministic way to collapse the three vertex labels into one.
    tri_labels = labels[triangles].max(axis=1)
    cmap = ListedColormap(["#1f77b4", "#ff7f0e", "#2ca02c"])
    ax.tripcolor(xy[:, 0], xy[:, 1], triangles, tri_labels, cmap=cmap, shading="flat", vmin=0, vmax=2)
    # Simplex outline through (0, 0), (1, 0) and the apex (0.5, sqrt(3)/2).
    apex_y = (3 ** 0.5) * 0.5
    ax.plot([0, 1, 0.5, 0], [0, 0, apex_y, 0], "k-", lw=1.5)
    ax.set_xlim(-0.05, 1.05)
    ax.set_ylim(-0.05, apex_y + 0.05)
    ax.set_aspect("equal")
    ax.set_title(title)
    ax.set_xlabel("w1 →")
    ax.set_ylabel("w2 ↑")
    return ax
def plot_sperner_path(
    history: list | np.ndarray,
    ax=None,
    title: str = "Sperner walk path",
    simplex_heatmap_oracle=None,
    n_grid_heatmap: int = 25,
) -> "plt.Axes":
    """
    Draw the path of the Sperner walk on the 2D simplex (3 objectives only).

    Args:
        history: weight vectors of shape (3,) (non-negative, summing to 1) in
            visit order. A single 1-D vector is treated as a one-point path.
        ax: matplotlib Axes to draw on; a new 6x5 figure is created if None.
        title: axes title.
        simplex_heatmap_oracle: optional callable ``weights -> label``. When
            given, the background is colored by label via
            ``plot_simplex_heatmap``; otherwise only the simplex outline is drawn.
        n_grid_heatmap: lattice resolution for the optional background heatmap.

    Returns:
        The Axes that was drawn on.

    Raises:
        ImportError: if matplotlib is not installed.
        ValueError: if history is empty or is not a (steps, 3) array.
    """
    if not _HAS_MPL:
        raise ImportError("matplotlib is required for plotting. Install with: pip install matplotlib")
    history = np.asarray(history)
    if history.ndim == 1:
        history = history.reshape(1, -1)
    # Without this check an empty (0, 3) history got past the column check and
    # crashed later with an IndexError on the empty xy array.
    if history.size == 0:
        raise ValueError("plot_sperner_path requires a non-empty history")
    if history.ndim != 2 or history.shape[1] != 3:
        raise ValueError("plot_sperner_path requires 3 objectives (history columns = 3)")
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(6, 5))
    if simplex_heatmap_oracle is not None:
        plot_simplex_heatmap(simplex_heatmap_oracle, n_grid=n_grid_heatmap, ax=ax, title=title)
    else:
        # Bare simplex outline through (0, 0), (1, 0) and the apex (0.5, sqrt(3)/2).
        apex_y = (3 ** 0.5) * 0.5
        ax.plot([0, 1, 0.5, 0], [0, 0, apex_y, 0], "k-", lw=1.5)
        ax.set_xlim(-0.05, 1.05)
        ax.set_ylim(-0.05, apex_y + 0.05)
        ax.set_aspect("equal")
        ax.set_title(title)
    xy = np.array([_simplex_to_xy(w) for w in history])
    ax.plot(xy[:, 0], xy[:, 1], "r-", lw=2, alpha=0.9, label="path")
    # zorder=5 keeps the endpoint markers above the path line and any heatmap.
    ax.scatter(xy[0, 0], xy[0, 1], c="green", s=80, zorder=5, label="start")
    ax.scatter(xy[-1, 0], xy[-1, 1], c="red", s=80, zorder=5, label="end")
    ax.legend(loc="upper right")
    return ax
def plot_sperner_path_from_solver(solver) -> "plt.Axes | None":
    """
    Plot the last Sperner path using ``solver._path_history`` (set after solve()).

    For 3 objectives only. ``solver._path_history`` is read with ``getattr`` and
    is expected to be a sequence of weight vectors of shape (3,). A single 1-D
    weight vector is treated as a one-point path.

    Returns:
        The Axes produced by ``plot_sperner_path``, or None (after emitting a
        UserWarning) when the solver has no recorded path.

    Raises:
        ImportError: if matplotlib is not installed.
        ValueError: if the recorded path does not have 3 objectives.
    """
    if not _HAS_MPL:
        raise ImportError("matplotlib is required for plotting. Install with: pip install matplotlib")
    path = getattr(solver, "_path_history", None)
    if path is None or len(path) == 0:
        # Library code: warn through the warnings module (filterable by callers)
        # instead of printing to stdout.
        warnings.warn("No path history found. Did the solver run?", stacklevel=2)
        return None
    path = np.asarray(path)
    if path.ndim == 1:
        # Previously path.shape[1] raised IndexError for a single 1-D vector.
        path = path.reshape(1, -1)
    if path.ndim != 2 or path.shape[1] != 3:
        raise ValueError("plot_sperner_path_from_solver supports only 3 objectives")
    return plot_sperner_path(path, title="Sperner walk (from solver)")
|