Spaces:
Paused
Paused
| """Tools for visualising embedding spaces using UMAP.""" | |
| from __future__ import annotations | |
| import argparse | |
| from pathlib import Path | |
| from typing import List | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from matplotlib.lines import Line2D | |
| from mpl_toolkits.mplot3d import Axes3D # noqa: F401 - needed for 3D projections | |
| from umap import UMAP | |
| from pipeline.storage import load_corpus, load_embeddings | |
# Default locations of the index artefacts. These are relative paths, so they
# resolve against the current working directory of the process.
DEFAULT_INDEX_DIR = Path("index")
# Fallback value for the ``--corpus`` CLI option.
DEFAULT_CORPUS_PATH = DEFAULT_INDEX_DIR / "corpus.json"
# Fallback value for the ``--embeddings`` CLI option.
DEFAULT_EMBEDDINGS_PATH = DEFAULT_INDEX_DIR / "embeddings.npy"
def parse_args(argv: List[str] | None = None) -> argparse.Namespace:
    """Build the visualiser's argument parser and parse ``argv`` with it.

    Parameters
    ----------
    argv: List[str] | None, default None
        Explicit argument list, mainly useful in tests. ``None`` is passed
        through to ``ArgumentParser.parse_args``, which then reads the
        process's own command line.

    Returns
    -------
    argparse.Namespace
        Namespace exposing ``embeddings``, ``corpus``, ``dims``, ``output``
        and ``show``.
    """
    cli = argparse.ArgumentParser(description="Visualise SPECTER2 embeddings in 2D or 3D.")

    # Options are declared as data so that every flag's settings sit side by
    # side; they are registered in exactly this order.
    option_table = (
        (
            "--embeddings",
            {
                "type": Path,
                "default": DEFAULT_EMBEDDINGS_PATH,
                "help": "Path to embeddings.npy (default: index/embeddings.npy)",
            },
        ),
        (
            "--corpus",
            {
                "type": Path,
                "default": DEFAULT_CORPUS_PATH,
                "help": "Path to corpus.json metadata (default: index/corpus.json)",
            },
        ),
        (
            "--dims",
            {
                "type": int,
                "choices": (2, 3),
                "default": 2,
                "help": "Number of UMAP dimensions (2 or 3, default: 2)",
            },
        ),
        (
            "--output",
            {
                "type": Path,
                "default": None,
                "help": "Optional output path for the generated plot (default: derived from embeddings path)",
            },
        ),
        (
            "--show",
            {
                "action": "store_true",
                "help": "Display the plot interactively after saving",
            },
        ),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    return cli.parse_args(argv)
def plot_embeddings(
    embeddings_path: Path,
    corpus_path: Path,
    dims: int = 2,
    output_path: Path | None = None,
    show: bool = False,
) -> Path:
    """Create a UMAP projection of the embeddings and save the resulting plot.

    Points are coloured by the first entry of each corpus record's
    ``"categories"`` value; records without categories are labelled
    ``"unknown"``.

    Parameters
    ----------
    embeddings_path: Path
        Location of the `embeddings.npy` file.
    corpus_path: Path
        Location of the `corpus.json` file.
    dims: int, default 2
        Number of UMAP dimensions. Must be 2 or 3.
    output_path: Path | None, default None
        Optional destination for the saved figure. When omitted, the figure is
        written next to the embeddings file as ``embedding_plot_{dims}d.png``.
    show: bool, default False
        Whether to display the plot after saving.

    Returns
    -------
    Path
        The path to the saved figure.

    Raises
    ------
    ValueError
        If ``dims`` is not 2 or 3, or if the number of embedding rows differs
        from the number of corpus records.

    Notes
    -----
    Side effects: creates the output's parent directory if needed, writes the
    image file, and prints a confirmation line to stdout.
    """
    # Fail fast, before any loading or fitting. Only 2D and 3D plots can be
    # drawn; any other value would otherwise fall through to the 3D branch of
    # the figure builder and yield a mislabelled plot or a late IndexError.
    if dims not in (2, 3):
        raise ValueError(f"dims must be 2 or 3, got {dims!r}")

    embeddings = load_embeddings(embeddings_path)
    corpus = load_corpus(corpus_path)
    if embeddings.shape[0] != len(corpus):
        raise ValueError(
            "Embeddings and corpus lengths do not match. Ensure the inputs originate from the same build run."
        )

    # A fixed random_state keeps the layout reproducible between runs.
    reducer = UMAP(n_components=dims, n_neighbors=15, min_dist=0.1, random_state=42)
    coordinates = reducer.fit_transform(embeddings)

    # NOTE(review): assumes "categories" is a list of labels. If it were a plain
    # string, category[0] would be its first character. Confirm the schema.
    categories = [metadata.get("categories", []) for metadata in corpus]
    primary_labels = [category[0] if category else "unknown" for category in categories]

    # Sorting the labels makes the label -> colour index mapping deterministic.
    label_to_index = {label: idx for idx, label in enumerate(sorted(set(primary_labels)))}
    colour_indices = np.array([label_to_index[label] for label in primary_labels])

    fig = _create_figure(coordinates, colour_indices, primary_labels, label_to_index, dims)

    derived_output = embeddings_path.with_name(f"embedding_plot_{dims}d.png")
    output = output_path or derived_output
    output.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(output, dpi=200, bbox_inches="tight")

    if show:
        plt.show()
    else:
        # Nothing will display the figure, so release it from pyplot's registry.
        plt.close(fig)

    print(f"Saved {dims}D embedding visualisation to {output}")
    return output
def _create_figure(
    coordinates: np.ndarray,
    colour_indices: np.ndarray,
    labels: List[str],
    label_to_index: dict[str, int],
    dims: int,
) -> plt.Figure:
    """Draw the scatter plot of the reduced coordinates and return its figure.

    Parameters
    ----------
    coordinates: np.ndarray
        UMAP-reduced coordinates; columns 0 and 1 are always plotted, and
        column 2 as well when the plot is not 2D.
    colour_indices: np.ndarray
        Integer colour index per sample, passed to the scatter as ``c``.
    labels: List[str]
        Primary category labels aligned with the coordinates.
    label_to_index: dict[str, int]
        Mapping from label names to integer colour indices.
    dims: int
        Dimensionality of the visualisation. ``2`` gives a flat plot; any
        other value takes the 3D path.

    Returns
    -------
    plt.Figure
        The generated matplotlib figure.
    """
    # Reset matplotlib's rc settings so the plot starts from default styling.
    plt.rcdefaults()
    figure = plt.figure(figsize=(10, 8))

    planar = dims == 2
    if planar:
        axes = figure.add_subplot(111)
    else:
        axes = figure.add_subplot(111, projection="3d")

    # Gather the coordinate columns once so both layouts share a single
    # scatter call; the third column is only read for the 3D layout.
    axis_columns = [coordinates[:, 0], coordinates[:, 1]]
    if not planar:
        axis_columns.append(coordinates[:, 2])
    points = axes.scatter(*axis_columns, c=colour_indices, cmap="tab20", s=20, alpha=0.85)

    axes.set_xlabel("UMAP 1")
    axes.set_ylabel("UMAP 2")
    if not planar:
        axes.set_zlabel("UMAP 3")
    axes.set_title(f"SPECTER2 Embeddings ({dims}D UMAP)")

    # One legend marker per distinct label, coloured through the scatter's own
    # norm and colormap so the legend matches the plotted points.
    legend_entries = [
        Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            label=name,
            markerfacecolor=points.cmap(points.norm(label_to_index[name])),
            markersize=8,
        )
        for name in sorted(set(labels))
    ]

    # The legend is only drawn when there are at most 12 entries.
    if len(legend_entries) <= 12:
        axes.legend(
            handles=legend_entries,
            title="Primary Category",
            bbox_to_anchor=(1.05, 1),
            loc="upper left",
        )

    return figure
def main(argv: List[str] | None = None) -> None:
    """Run the visualisation CLI: parse the options, then build the plot.

    Parameters
    ----------
    argv: List[str] | None, default None
        Optional argument list to use instead of the process's command line,
        for programmatic invocation.
    """
    options = vars(parse_args(argv))
    plot_embeddings(
        embeddings_path=options["embeddings"],
        corpus_path=options["corpus"],
        dims=options["dims"],
        output_path=options["output"],
        show=options["show"],
    )
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":  # pragma: no cover - CLI entry point
    main()