topo-align / equilib /plotting.py
omesbah's picture
chore: initial commit for Hugging Face publication
8797abf
"""
Topological landscapes: simplex heatmaps and Sperner path visualization.
For 3 objectives, plots the 2D simplex (triangle) with label regions and the walk path.
"""
from __future__ import annotations
import numpy as np
try:
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
_HAS_MPL = True
except ImportError:
_HAS_MPL = False
def _simplex_to_xy(weights: np.ndarray) -> tuple[float, float]:
"""Map barycentric (w0, w1, w2) to 2D triangle coords. w0+w1+w2=1."""
w1, w2 = weights[1], weights[2]
x = w1 + w2 * 0.5
y = w2 * (3 ** 0.5) * 0.5
return x, y
def _grid_3simplex(n: int = 30) -> tuple[np.ndarray, np.ndarray]:
"""Return (weights, xy) for a grid on the 3-simplex. weights shape (M, 3)."""
pts = []
for i in range(n + 1):
for j in range(n + 1 - i):
k = n - i - j
w = np.array([i, j, k], dtype=float) / n
pts.append(w)
weights = np.array(pts)
xy = np.array([_simplex_to_xy(w) for w in weights])
return weights, xy
def plot_simplex_heatmap(
oracle_label_fn,
n_grid: int = 40,
ax=None,
title: str = "Sperner labeling (winner per point)",
) -> "plt.Axes":
"""
Plot the 2D simplex colored by the oracle label (0, 1, or 2) at each grid point.
oracle_label_fn(weights) -> int, where weights is array of shape (3,) summing to 1.
"""
if not _HAS_MPL:
raise ImportError("matplotlib is required for plotting. Install with: pip install matplotlib")
weights, xy = _grid_3simplex(n_grid)
labels = np.array([oracle_label_fn(w) for w in weights], dtype=int)
if ax is None:
_, ax = plt.subplots(1, 1, figsize=(6, 5))
# Triangulate: simple regular grid of small triangles
tri_x, tri_y = xy[:, 0], xy[:, 1]
n = n_grid
triangles = []
for i in range(n):
for j in range(n - i):
k = n - i - j
if k <= 0:
continue
# indices in pts: row i has indices starting at i*(n+1) - i*(i-1)//2 approx
def idx(pi, pj):
return pi * (n + 1) - (pi * (pi - 1)) // 2 + pj
i0 = idx(i, j)
i1 = idx(i, j + 1)
i2 = idx(i + 1, j)
triangles.append([i0, i1, i2])
if j + 1 + (i + 1) <= n:
i3 = idx(i + 1, j + 1)
triangles.append([i1, i3, i2])
triangles = np.array(triangles)
tri_labels = labels[triangles].max(axis=1)
cmap = ListedColormap(["#1f77b4", "#ff7f0e", "#2ca02c"])
ax.tripcolor(tri_x, tri_y, triangles, tri_labels, cmap=cmap, shading="flat", vmin=0, vmax=2)
# Triangle outline
ax.plot([0, 1, 0.5, 0], [0, 0, (3 ** 0.5) * 0.5, 0], "k-", lw=1.5)
ax.set_xlim(-0.05, 1.05)
ax.set_ylim(-0.05, (3 ** 0.5) * 0.5 + 0.05)
ax.set_aspect("equal")
ax.set_title(title)
ax.set_xlabel("w1 →")
ax.set_ylabel("w2 ↑")
return ax
def plot_sperner_path(
history: list | np.ndarray,
ax=None,
title: str = "Sperner walk path",
simplex_heatmap_oracle=None,
n_grid_heatmap: int = 25,
) -> "plt.Axes":
"""
Draw the path of the Sperner walk on the 2D simplex (for 3 objectives only).
history: list of arrays of shape (3,) with non-negative weights summing to 1.
If simplex_heatmap_oracle is provided (callable weights -> label), the background
is colored by label; otherwise only the path and simplex outline are drawn.
"""
if not _HAS_MPL:
raise ImportError("matplotlib is required for plotting. Install with: pip install matplotlib")
history = np.asarray(history)
if history.ndim == 1:
history = history.reshape(1, -1)
if history.shape[1] != 3:
raise ValueError("plot_sperner_path requires 3 objectives (history columns = 3)")
if ax is None:
_, ax = plt.subplots(1, 1, figsize=(6, 5))
if simplex_heatmap_oracle is not None:
plot_simplex_heatmap(simplex_heatmap_oracle, n_grid=n_grid_heatmap, ax=ax, title=title)
else:
ax.plot([0, 1, 0.5, 0], [0, 0, (3 ** 0.5) * 0.5, 0], "k-", lw=1.5)
ax.set_xlim(-0.05, 1.05)
ax.set_ylim(-0.05, (3 ** 0.5) * 0.5 + 0.05)
ax.set_aspect("equal")
ax.set_title(title)
xy = np.array([_simplex_to_xy(w) for w in history])
ax.plot(xy[:, 0], xy[:, 1], "r-", lw=2, alpha=0.9, label="path")
ax.scatter(xy[0, 0], xy[0, 1], c="green", s=80, zorder=5, label="start")
ax.scatter(xy[-1, 0], xy[-1, 1], c="red", s=80, zorder=5, label="end")
ax.legend(loc="upper right")
return ax
def plot_sperner_path_from_solver(solver) -> "plt.Axes":
"""
Plot the last Sperner path using solver._path_history (set after solve()).
For 3 objectives only. solver must have _path_history (list of weight arrays).
"""
if not _HAS_MPL:
raise ImportError("matplotlib is required for plotting. Install with: pip install matplotlib")
path = getattr(solver, "_path_history", None)
if path is None or len(path) == 0:
print("[WARN] No path history found. Did the solver run?")
return None
path = np.asarray(path)
if path.shape[1] != 3:
raise ValueError("plot_sperner_path_from_solver supports only 3 objectives")
return plot_sperner_path(path, title="Sperner walk (from solver)")