Nexa_Labs / visualize /plot_embeddings.py
Allanatrix's picture
Upload 57 files
d8328bf verified
"""Tools for visualising embedding spaces using UMAP."""
from __future__ import annotations
import argparse
from pathlib import Path
from typing import List
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D
from mpl_toolkits.mplot3d import Axes3D # noqa: F401 - needed for 3D projections
from umap import UMAP
from pipeline.storage import load_corpus, load_embeddings
# Default on-disk layout: the corpus metadata and the embedding matrix live
# side by side inside a relative ``index/`` directory.
DEFAULT_INDEX_DIR = Path("index")
DEFAULT_CORPUS_PATH = DEFAULT_INDEX_DIR.joinpath("corpus.json")
DEFAULT_EMBEDDINGS_PATH = DEFAULT_INDEX_DIR.joinpath("embeddings.npy")
def parse_args(argv: List[str] | None = None) -> argparse.Namespace:
    """Build the visualiser's argument parser and parse ``argv`` with it.

    Parameters
    ----------
    argv: List[str] | None, default None
        Explicit argument list, mainly useful in tests. When ``None``,
        argparse falls back to the process command line.

    Returns
    -------
    argparse.Namespace
        Namespace carrying ``embeddings``, ``corpus``, ``dims``, ``output``
        and ``show``.
    """
    cli = argparse.ArgumentParser(description="Visualise SPECTER2 embeddings in 2D or 3D.")

    # One (flag, add_argument keyword arguments) pair per option, listed in
    # the order they should be registered and shown in ``--help``.
    option_table = (
        (
            "--embeddings",
            {
                "type": Path,
                "default": DEFAULT_EMBEDDINGS_PATH,
                "help": "Path to embeddings.npy (default: index/embeddings.npy)",
            },
        ),
        (
            "--corpus",
            {
                "type": Path,
                "default": DEFAULT_CORPUS_PATH,
                "help": "Path to corpus.json metadata (default: index/corpus.json)",
            },
        ),
        (
            "--dims",
            {
                "type": int,
                "choices": (2, 3),
                "default": 2,
                "help": "Number of UMAP dimensions (2 or 3, default: 2)",
            },
        ),
        (
            "--output",
            {
                "type": Path,
                "default": None,
                "help": "Optional output path for the generated plot (default: derived from embeddings path)",
            },
        ),
        (
            "--show",
            {
                "action": "store_true",
                "help": "Display the plot interactively after saving",
            },
        ),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    return cli.parse_args(argv)
def plot_embeddings(
    embeddings_path: Path,
    corpus_path: Path,
    dims: int = 2,
    output_path: Path | None = None,
    show: bool = False,
) -> Path:
    """Create a UMAP projection and save the resulting plot.

    Parameters
    ----------
    embeddings_path: Path
        Location of the `embeddings.npy` file.
    corpus_path: Path
        Location of the `corpus.json` file.
    dims: int, default 2
        Number of UMAP dimensions (2 or 3).
    output_path: Path | None, default None
        Optional destination for the saved figure. When omitted, the figure
        is written next to the embeddings file as ``embedding_plot_{dims}d.png``.
    show: bool, default False
        Whether to display the plot after saving.

    Returns
    -------
    Path
        The path to the saved figure.

    Raises
    ------
    ValueError
        If ``dims`` is not 2 or 3, or if the number of embedding rows differs
        from the number of corpus entries.
    """
    # Fail fast, before any file I/O or the UMAP fit: the figure builder draws
    # a 2D plot for dims == 2 and a 3D plot for everything else, so other
    # values would either crash late on a missing third column or produce a
    # plot whose title misreports its dimensionality.
    if dims not in (2, 3):
        raise ValueError(f"dims must be 2 or 3, got {dims!r}")

    embeddings = load_embeddings(embeddings_path)
    corpus = load_corpus(corpus_path)
    if embeddings.shape[0] != len(corpus):
        raise ValueError(
            "Embeddings and corpus lengths do not match. Ensure the inputs originate from the same build run."
        )

    # Fixed random_state so repeated runs on the same input use the same seed.
    reducer = UMAP(n_components=dims, n_neighbors=15, min_dist=0.1, random_state=42)
    coordinates = reducer.fit_transform(embeddings)

    # Colour every point by its first listed category; entries with a missing
    # or empty "categories" field fall back to "unknown".
    # NOTE(review): assumes each "categories" value is a sequence of label
    # strings (not a single string) - confirm against the corpus schema.
    categories = [metadata.get("categories", []) for metadata in corpus]
    primary_labels = [category[0] if category else "unknown" for category in categories]
    # Sorting gives every label a stable integer colour index.
    label_to_index = {label: idx for idx, label in enumerate(sorted(set(primary_labels)))}
    colour_indices = np.array([label_to_index[label] for label in primary_labels])

    if output_path is None:
        output = embeddings_path.with_name(f"embedding_plot_{dims}d.png")
    else:
        output = output_path

    fig = _create_figure(coordinates, colour_indices, primary_labels, label_to_index, dims)
    try:
        output.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output, dpi=200, bbox_inches="tight")
    except Exception:
        # pyplot keeps a reference to every figure it creates until it is
        # closed, so release it before propagating the failure.
        plt.close(fig)
        raise

    if show:
        plt.show()
    else:
        plt.close(fig)

    print(f"Saved {dims}D embedding visualisation to {output}")
    return output
def _create_figure(
    coordinates: np.ndarray,
    colour_indices: np.ndarray,
    labels: List[str],
    label_to_index: dict[str, int],
    dims: int,
) -> plt.Figure:
    """Draw the projected points as a colour-coded scatter plot.

    Parameters
    ----------
    coordinates: np.ndarray
        Reduced coordinates; columns 0 and 1 are always plotted, column 2 is
        additionally used for the 3D view.
    colour_indices: np.ndarray
        Per-sample integer colour index, passed to the scatter call as ``c``.
    labels: List[str]
        Primary category label of every sample; its distinct values become
        the legend entries.
    label_to_index: dict[str, int]
        Lookup from label name to the colour index used for that label.
    dims: int
        ``2`` selects a flat plot; any other value selects the 3D projection.

    Returns
    -------
    plt.Figure
        The figure holding the scatter plot (and legend, when drawn).
    """
    # Reset matplotlib's rc parameters to their defaults before drawing.
    plt.rcdefaults()
    figure = plt.figure(figsize=(10, 8))

    # Styling shared by the 2D and the 3D scatter call.
    marker_style = {"c": colour_indices, "cmap": "tab20", "s": 20, "alpha": 0.85}
    xs, ys = coordinates[:, 0], coordinates[:, 1]

    is_planar = dims == 2
    if is_planar:
        axes = figure.add_subplot(111)
        points = axes.scatter(xs, ys, **marker_style)
    else:
        axes = figure.add_subplot(111, projection="3d")
        points = axes.scatter(xs, ys, coordinates[:, 2], **marker_style)

    axes.set_xlabel("UMAP 1")
    axes.set_ylabel("UMAP 2")
    if not is_planar:
        axes.set_zlabel("UMAP 3")
    axes.set_title(f"SPECTER2 Embeddings ({dims}D UMAP)")

    # One proxy marker per distinct label, coloured by pushing the label's
    # colour index through the scatter's own norm and colormap so the legend
    # matches the plotted points.
    legend_entries = [
        Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            label=name,
            markerfacecolor=points.cmap(points.norm(label_to_index[name])),
            markersize=8,
        )
        for name in sorted(set(labels))
    ]

    # The legend is skipped entirely once there are more than 12 categories.
    if len(legend_entries) <= 12:
        axes.legend(
            handles=legend_entries,
            title="Primary Category",
            bbox_to_anchor=(1.05, 1),
            loc="upper left",
        )
    return figure
def main(argv: List[str] | None = None) -> None:
    """Run the visualiser CLI: parse the options, then render the plot.

    Parameters
    ----------
    argv: List[str] | None, default None
        Argument list to parse instead of the process command line, for
        programmatic invocation.
    """
    options = parse_args(argv)
    # Hand each parsed option to the matching plot_embeddings parameter.
    plot_embeddings(
        options.embeddings,
        options.corpus,
        dims=options.dims,
        output_path=options.output,
        show=options.show,
    )
# Run the CLI only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__": # pragma: no cover - CLI entry point
    main()