"""Dataset loading, label lookup, and embedding (de)serialisation helpers."""

from __future__ import annotations

from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd
from datasets import Dataset, load_dataset


def load_image_split(
    dataset_name: str,
    split: str,
    max_samples: int | None = None,
) -> Dataset:
    """Load one split of a dataset, optionally truncated to its first rows.

    Args:
        dataset_name: Name passed straight to ``datasets.load_dataset``.
        split: Split name passed as ``split=`` to ``load_dataset``.
        max_samples: Upper bound on the number of rows to keep. ``None``,
            zero, or a negative value keeps the whole split.

    Returns:
        The loaded dataset, limited to the first
        ``min(max_samples, len(dataset))`` rows when a positive cap is given.
    """
    dataset = load_dataset(dataset_name, split=split)
    if max_samples is not None and max_samples > 0:
        # Clamp so that a cap larger than the split does not index past the end.
        dataset = dataset.select(range(min(max_samples, len(dataset))))
    return dataset


def detect_label_names(dataset: Dataset, label_column: str) -> list[str] | None:
    """Return the class names attached to ``label_column``, if any.

    Args:
        dataset: Dataset whose ``features`` mapping is inspected.
        label_column: Name of the column to look up in ``dataset.features``.

    Returns:
        A list copy of the feature's ``names`` attribute, or ``None`` when the
        column is absent or its feature exposes no ``names`` attribute.
    """
    feature = dataset.features.get(label_column)
    if feature is not None and hasattr(feature, "names"):
        return list(feature.names)
    return None


def label_to_name(label: Any, label_names: list[str] | None) -> str:
    """Translate a label into a human-readable name.

    Args:
        label: Raw label value.
        label_names: Optional list of names indexed by integer label.

    Returns:
        ``label_names[label]`` when ``label`` is a Python or NumPy integer that
        is a valid non-negative index into a non-empty ``label_names``;
        otherwise ``str(label)``.
    """
    # NOTE(review): bool is a subclass of int, so True/False are treated as the
    # indices 1/0 here -- confirm that is acceptable for boolean labels.
    if (
        label_names
        and isinstance(label, (int, np.integer))
        and 0 <= int(label) < len(label_names)
    ):
        return label_names[int(label)]
    return str(label)


def embeddings_to_frame(
    embeddings: np.ndarray,
    labels: list[Any],
    label_names: list[str] | None,
    sample_ids: list[str],
    model_name: str,
    dataset_name: str,
    split: str,
) -> pd.DataFrame:
    """Assemble embeddings and their metadata into a single DataFrame.

    Args:
        embeddings: Iterable of per-sample embeddings; iterating it yields one
            embedding per row. Rows may be NumPy arrays or plain sequences of
            numbers.
        labels: One raw label per sample.
        label_names: Optional names used to resolve each label via
            ``label_to_name``.
        sample_ids: One identifier per sample.
        model_name: Value stored in the ``model_name`` column.
        dataset_name: Value stored in the ``dataset_name`` column.
        split: Value stored in the ``split`` column.

    Returns:
        A DataFrame with the columns ``sample_id``, ``label``, ``label_name``,
        ``dataset_name``, ``split``, ``model_name`` and ``embedding``; each
        ``embedding`` cell holds the row converted to Python floats.
    """
    return pd.DataFrame(
        {
            "sample_id": sample_ids,
            "label": labels,
            "label_name": [label_to_name(label, label_names) for label in labels],
            "dataset_name": dataset_name,
            "split": split,
            "model_name": model_name,
            # np.asarray (rather than calling .astype on the row) also accepts
            # rows that are plain lists/tuples, not only ndarrays.
            "embedding": [
                np.asarray(embedding, dtype=float).tolist()
                for embedding in embeddings
            ],
        }
    )


def save_embeddings(df: pd.DataFrame, output_dir: str | Path) -> Path:
    """Write ``df`` to ``<output_dir>/embeddings.parquet``.

    The output directory (and any missing parents) is created first.

    Args:
        df: Frame to persist; written without its index.
        output_dir: Directory that will contain ``embeddings.parquet``.

    Returns:
        Path of the parquet file that was written.
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    parquet_path = output_path / "embeddings.parquet"
    df.to_parquet(parquet_path, index=False)
    return parquet_path


def read_embeddings(path: str | Path) -> pd.DataFrame:
    """Read an embeddings parquet file and check its required columns.

    Args:
        path: Location of the parquet file.

    Returns:
        The loaded DataFrame.

    Raises:
        ValueError: If any of ``sample_id``, ``label`` or ``embedding`` is
            missing from the file's columns.
    """
    df = pd.read_parquet(path)
    required = {"sample_id", "label", "embedding"}
    missing = required.difference(df.columns)
    if missing:
        raise ValueError(
            f"Embedding file is missing required columns: {sorted(missing)}"
        )
    return df


def embedding_matrix(df: pd.DataFrame) -> np.ndarray:
    """Stack the ``embedding`` column into one float32 matrix.

    Args:
        df: Frame with an ``embedding`` column holding one vector per row.

    Returns:
        A float32 array with one row per DataFrame row, or an empty array of
        shape ``(0, 0)`` when ``df`` is empty.
    """
    if df.empty:
        return np.empty((0, 0), dtype=np.float32)
    rows = df["embedding"].map(np.asarray).to_list()
    return np.vstack(rows).astype(np.float32)