Spaces:
Running
Running
| """Generate t-SNE visualizations from XP matrix CSV outputs. | |
| Input CSV format is the matrix CSV produced by evaluation/run.py where each | |
| cell is formatted as "mean (ci_lower, ci_upper)". | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import csv | |
| import re | |
| from pathlib import Path | |
| from typing import Literal | |
| from matplotlib.lines import Line2D | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| try: | |
| from sklearn.manifold import TSNE | |
| except ImportError as exc: # pragma: no cover | |
| raise ImportError( | |
| "plot_xp_csv_tsne.py requires scikit-learn. Install with `pip install scikit-learn`." | |
| ) from exc | |
# Captures the leading "mean" of a cell formatted as "mean (ci_lower, ci_upper)".
# Accepts an optional sign, plain decimals ("12", "0.5", ".5", "5.") and an
# optional exponent ("1.5e-3"). Without the exponent part a value such as
# "1e-05" would be silently read as 1.0 instead of 0.00001.
MEAN_RE = re.compile(r"^\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)")
# Hex plot color for each agent category. The "other" entry is also what
# plot_embedding falls back to for an unrecognized category.
AGENT_TYPE_COLORS = dict(
    heuristic="#1f77b4",
    ippo="#ff7f0e",
    comedi="#2ca02c",
    lbrdiv="#d62728",
    brdiv="#9467bd",
    other="#7f7f7f",
)
def infer_agent_type(label: str) -> str:
    """Classify an agent label into an agent-type category.

    Matching is a case-insensitive substring search on ``label``. Learned-agent
    names are checked first, then the heuristic markers; anything that matches
    nothing is reported as ``"other"``.
    """
    lowered = label.lower()

    # Order matters: "lbrdiv" contains "brdiv", so the longer name goes first.
    for learned_type in ("ippo", "comedi", "lbrdiv", "brdiv"):
        if learned_type in lowered:
            return learned_type

    heuristic_markers = (
        "seq_agent",
        "independent_agent",
        "onion_agent",
        "plate_agent",
        "heuristic",
    )
    for marker in heuristic_markers:
        if marker in lowered:
            return "heuristic"

    return "other"
def parse_matrix(csv_path: Path) -> tuple[list[str], list[str], np.ndarray]:
    """Read a matrix CSV and return its row labels, column labels and mean values.

    The first CSV row is the header; its first cell is ignored and the rest are
    the column labels. Every following row starts with its row label, followed
    by one cell per column. Only the leading number of each cell (matched by
    ``MEAN_RE``) is kept.

    Blank lines are skipped.

    Returns:
        ``(row_labels, col_labels, values)`` where ``values`` is a float array
        with one row per data row and one column per column label.

    Raises:
        ValueError: if there is no header plus at least one data row, if a data
            row does not have exactly one cell per header column, or if a cell
            does not start with a parseable number.
    """
    with csv_path.open(newline="") as f:
        # csv.reader yields [] for blank lines (e.g. trailing newlines); keeping
        # them would crash on row[0] below with an unhelpful IndexError.
        rows = [row for row in csv.reader(f) if row]
    if len(rows) < 2 or len(rows[0]) < 2:
        raise ValueError(f"CSV must contain a header and at least one row: {csv_path}")

    col_labels = rows[0][1:]
    n_cols = len(col_labels)
    row_labels: list[str] = []
    values: list[list[float]] = []
    for row in rows[1:]:
        cells = row[1:]
        # Reject ragged rows up front so labels and matrix columns cannot
        # silently get out of step.
        if len(cells) != n_cols:
            raise ValueError(
                f"Row {row[0]!r} has {len(cells)} value(s) but the header has "
                f"{n_cols} column(s) in {csv_path}"
            )
        parsed_row = []
        for cell in cells:
            match = MEAN_RE.match(cell)
            if not match:
                raise ValueError(f"Could not parse mean from cell {cell!r} in {csv_path}")
            parsed_row.append(float(match.group(1)))
        row_labels.append(row[0])
        values.append(parsed_row)
    return row_labels, col_labels, np.asarray(values, dtype=float)
| def _choose_perplexity(n_samples: int, requested: float | None) -> float: | |
| if n_samples < 3: | |
| raise ValueError("Need at least 3 samples for t-SNE.") | |
| max_valid = max(2.0, float(n_samples - 1)) | |
| if requested is not None: | |
| if requested >= n_samples: | |
| raise ValueError( | |
| f"perplexity ({requested}) must be less than number of samples ({n_samples})." | |
| ) | |
| return float(requested) | |
| return min(max_valid, max(5.0, float(n_samples) / 3.0)) | |
def run_tsne(features: np.ndarray, perplexity: float | None, seed: int) -> np.ndarray:
    """Embed the rows of ``features`` into two dimensions with t-SNE.

    The perplexity is resolved by ``_choose_perplexity`` from the number of rows
    in ``features`` and the optional ``perplexity`` request. ``seed`` is passed
    to the estimator as ``random_state``.

    Returns:
        The result of ``TSNE.fit_transform(features)``.
    """
    effective_perplexity = _choose_perplexity(features.shape[0], perplexity)
    tsne_options = {
        "n_components": 2,
        "perplexity": effective_perplexity,
        "init": "pca",
        "learning_rate": "auto",
        "random_state": seed,
    }
    embedder = TSNE(**tsne_options)
    return embedder.fit_transform(features)
def plot_embedding(
    coords: np.ndarray,
    labels: list[str],
    title: str,
    out_path: Path,
    *,
    show_point_labels: bool = True,
    show_density: bool = False,
    point_size: float = 52,
    point_alpha: float = 0.95,
    figsize: tuple[float, float] = (10.0, 8.0),
    dpi: int = 220,
    title_fontsize: int = 14,
    axis_fontsize: int = 12,
    tick_fontsize: int = 10,
    legend_fontsize: int = 10,
    legend_title_fontsize: int = 11,
) -> None:
    """Draw a 2-D embedding as a scatter plot colored by agent type and save it.

    Args:
        coords: Array whose first two columns are the x/y position of each point.
        labels: One label per point; the label decides the point's color (via
            ``infer_agent_type``) and is the annotation text.
        title: Figure title.
        out_path: Image file to write; missing parent directories are created.
        show_point_labels: Annotate every point with its label.
        show_density: Draw a faint hexbin density layer beneath the points.
        point_size: Marker size passed to the scatter call.
        point_alpha: Marker opacity.
        figsize: Figure size passed to matplotlib.
        dpi: Resolution used when saving.
        title_fontsize, axis_fontsize, tick_fontsize, legend_fontsize,
        legend_title_fontsize: Font sizes of the respective text elements.

    The figure is always closed, even if drawing or saving raises, so a failure
    does not leave an open figure behind.
    """
    agent_types = [infer_agent_type(label) for label in labels]
    point_colors = [
        AGENT_TYPE_COLORS.get(agent_type, AGENT_TYPE_COLORS["other"]) for agent_type in agent_types
    ]

    # Use an explicit figure/axes pair instead of pyplot's implicit "current
    # figure" so that exactly this figure is drawn on, saved and closed.
    fig, ax = plt.subplots(figsize=figsize)
    try:
        # Light density layer gives cluster structure without hiding category colors.
        if show_density:
            ax.hexbin(
                coords[:, 0],
                coords[:, 1],
                gridsize=28,
                cmap="Greys",
                bins="log",
                mincnt=1,
                alpha=0.25,
                linewidths=0,
                zorder=1,
            )
        ax.scatter(
            coords[:, 0],
            coords[:, 1],
            s=point_size,
            c=point_colors,
            alpha=point_alpha,
            zorder=2,
        )
        if show_point_labels:
            for label, point in zip(labels, coords):
                ax.annotate(label, (point[0], point[1]), fontsize=8, alpha=0.9)

        ax.set_title(title, fontsize=title_fontsize)
        ax.set_xlabel("t-SNE 1", fontsize=axis_fontsize)
        ax.set_ylabel("t-SNE 2", fontsize=axis_fontsize)
        ax.tick_params(axis="both", labelsize=tick_fontsize)

        # Legend lists only the categories present, in a fixed display order.
        legend_order = ("heuristic", "ippo", "comedi", "lbrdiv", "brdiv", "other")
        present = set(agent_types)
        handles = [
            Line2D(
                [0],
                [0],
                marker="o",
                color="w",
                label=agent_type,
                markerfacecolor=AGENT_TYPE_COLORS[agent_type],
                markersize=8,
            )
            for agent_type in legend_order
            if agent_type in present
        ]
        if handles:
            ax.legend(
                handles=handles,
                title="Agent type",
                loc="best",
                fontsize=legend_fontsize,
                title_fontsize=legend_title_fontsize,
            )

        fig.tight_layout()
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, dpi=dpi, bbox_inches="tight")
    finally:
        plt.close(fig)
def process_csv(
    csv_path: Path,
    out_dir: Path | None,
    perplexity: float | None,
    seed: int,
    embedding: Literal["both", "rows", "cols"],
    show_point_labels: bool,
    show_density: bool,
    point_size: float,
    point_alpha: float,
    figsize: tuple[float, float],
    dpi: int,
    title_fontsize: int,
    axis_fontsize: int,
    tick_fontsize: int,
    legend_fontsize: int,
    legend_title_fontsize: int,
) -> tuple[Path | None, Path | None]:
    """Parse one matrix CSV and write the requested t-SNE plot(s).

    The row embedding runs t-SNE on the matrix as parsed; the column embedding
    runs it on the transposed matrix. Images are written to ``out_dir`` (or the
    CSV's own directory when ``out_dir`` is ``None``) as
    ``<stem>_rows_tsne.png`` and ``<stem>_cols_tsne.png``.

    Returns:
        ``(row_path, col_path)``; an entry is ``None`` when that embedding was
        not requested.
    """
    row_labels, col_labels, matrix = parse_matrix(csv_path)

    # Both embeddings are computed before any figure is drawn.
    row_coords = None
    if embedding in {"both", "rows"}:
        row_coords = run_tsne(matrix, perplexity=perplexity, seed=seed)
    col_coords = None
    if embedding in {"both", "cols"}:
        col_coords = run_tsne(matrix.T, perplexity=perplexity, seed=seed)

    destination_dir = csv_path.parent if out_dir is None else out_dir
    stem = csv_path.stem
    row_out = destination_dir / f"{stem}_rows_tsne.png"
    col_out = destination_dir / f"{stem}_cols_tsne.png"

    # Styling shared by both plots.
    style = {
        "show_point_labels": show_point_labels,
        "show_density": show_density,
        "point_size": point_size,
        "point_alpha": point_alpha,
        "figsize": figsize,
        "dpi": dpi,
        "title_fontsize": title_fontsize,
        "axis_fontsize": axis_fontsize,
        "tick_fontsize": tick_fontsize,
        "legend_fontsize": legend_fontsize,
        "legend_title_fontsize": legend_title_fontsize,
    }
    jobs = (
        (row_coords, row_labels, f"t-SNE (rows): {stem}", row_out),
        (col_coords, col_labels, f"t-SNE (columns): {stem}", col_out),
    )
    for coords, point_labels, plot_title, destination in jobs:
        if coords is not None:
            plot_embedding(coords, point_labels, plot_title, destination, **style)

    return (
        row_out if row_coords is not None else None,
        col_out if col_coords is not None else None,
    )
def iter_input_csvs(input_path: Path):
    """Yield the CSV file(s) designated by ``input_path``.

    If ``input_path`` is a file it is yielded as-is. Otherwise it is treated as
    a directory: its ``*.csv`` entries are yielded in sorted order, skipping
    any whose stem ends in ``_tidy``.
    """
    if not input_path.is_file():
        candidates = sorted(input_path.glob("*.csv"))
        yield from (path for path in candidates if not path.stem.endswith("_tidy"))
    else:
        yield input_path
def main():
    """Command-line entry point: parse options and render t-SNE plots.

    Every CSV yielded by ``iter_input_csvs`` for the given input path is passed
    to ``process_csv``, and the path of each written image is printed. With
    ``--publication`` the embedding, label, density and point-size settings are
    overridden and the DPI, figure size and font sizes are raised to fixed
    minimums.

    Raises:
        ValueError: if the input path does not exist or no image was generated.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "input_path",
        type=Path,
        help="Path to one XP matrix CSV or directory containing CSVs.",
    )
    parser.add_argument(
        "--out-dir",
        type=Path,
        default=None,
        help="Directory to write output PNGs. Defaults to CSV directory.",
    )
    parser.add_argument(
        "--perplexity",
        type=float,
        default=None,
        help="t-SNE perplexity. Must be < number of samples.",
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--embedding",
        choices=("both", "rows", "cols"),
        default="both",
        help="Which embedding(s) to render.",
    )
    parser.add_argument(
        "--hide-point-labels",
        action="store_true",
        help="Do not annotate every point label on the plot.",
    )
    parser.add_argument(
        "--show-density",
        action="store_true",
        help="Overlay a light hexbin density map behind points.",
    )
    # Plain numeric styling options: (flag, type, default), in help order.
    numeric_options = (
        ("--point-size", float, 52.0),
        ("--point-alpha", float, 0.95),
        ("--fig-width", float, 10.0),
        ("--fig-height", float, 8.0),
        ("--dpi", int, 220),
        ("--title-fontsize", int, 14),
        ("--axis-fontsize", int, 12),
        ("--tick-fontsize", int, 10),
        ("--legend-fontsize", int, 10),
        ("--legend-title-fontsize", int, 11),
    )
    for flag, value_type, default_value in numeric_options:
        parser.add_argument(flag, type=value_type, default=default_value)
    parser.add_argument(
        "--publication",
        action="store_true",
        help="Preset for publication figures: cols only, no per-point labels, larger text and higher DPI.",
    )
    args = parser.parse_args()

    input_path = args.input_path
    if not input_path.exists():
        raise ValueError(f"Input path does not exist: {input_path}")

    settings = {
        "embedding": args.embedding,
        "show_point_labels": not args.hide_point_labels,
        "show_density": args.show_density,
        "point_size": args.point_size,
        "point_alpha": args.point_alpha,
        "dpi": args.dpi,
        "title_fontsize": args.title_fontsize,
        "axis_fontsize": args.axis_fontsize,
        "tick_fontsize": args.tick_fontsize,
        "legend_fontsize": args.legend_fontsize,
        "legend_title_fontsize": args.legend_title_fontsize,
    }
    fig_width = args.fig_width
    fig_height = args.fig_height

    if args.publication:
        # Hard overrides of the preset.
        settings.update(
            embedding="cols",
            show_point_labels=False,
            show_density=True,
            point_size=68.0,
        )
        # Minimums: user-supplied values above these are kept.
        minimums = (
            ("dpi", 320),
            ("title_fontsize", 18),
            ("axis_fontsize", 16),
            ("tick_fontsize", 13),
            ("legend_fontsize", 12),
            ("legend_title_fontsize", 13),
        )
        for key, floor in minimums:
            settings[key] = max(settings[key], floor)
        fig_width = max(fig_width, 12.0)
        fig_height = max(fig_height, 9.0)

    settings["figsize"] = (fig_width, fig_height)

    generated = 0
    for csv_path in iter_input_csvs(input_path):
        written_paths = process_csv(
            csv_path,
            out_dir=args.out_dir,
            perplexity=args.perplexity,
            seed=args.seed,
            **settings,
        )
        for written in written_paths:
            if written is not None:
                print(written)
                generated += 1

    if generated == 0:
        raise ValueError(f"No matrix CSV files found in: {input_path}")
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()