File size: 2,360 Bytes
2bc3168 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 | from __future__ import annotations
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
from datasets import Dataset, load_dataset
def load_image_split(
    dataset_name: str,
    split: str,
    max_samples: int | None = None,
) -> Dataset:
    """Load one split of a dataset, optionally truncated to its leading rows.

    Args:
        dataset_name: Identifier handed to ``load_dataset``.
        split: Name of the split to load.
        max_samples: Upper bound on the number of rows kept. ``None`` and
            non-positive values leave the split untruncated.

    Returns:
        The loaded split, restricted to its first ``max_samples`` rows when a
        positive cap is given (never more rows than the split contains).
    """
    loaded = load_dataset(dataset_name, split=split)

    # Only a strictly positive cap triggers truncation.
    if max_samples is None or not (max_samples > 0):
        return loaded

    row_count = min(max_samples, len(loaded))
    return loaded.select(range(row_count))
def detect_label_names(dataset: Dataset, label_column: str) -> list[str] | None:
    """Return the class names attached to a label column, if it has any.

    Args:
        dataset: Object whose ``features`` mapping is consulted.
        label_column: Key of the label feature inside ``dataset.features``.

    Returns:
        A new list copied from the feature's ``names`` attribute, or ``None``
        when the column has no feature entry or the feature exposes no
        ``names`` attribute.
    """
    column_feature = dataset.features.get(label_column)
    if column_feature is None or not hasattr(column_feature, "names"):
        return None
    return list(column_feature.names)
def label_to_name(label: Any, label_names: list[str] | None) -> str:
    """Translate a label into its class name, falling back to ``str(label)``.

    Args:
        label: Raw label value; only ``int`` / ``np.integer`` instances are
            treated as indices into ``label_names``.
        label_names: Class names indexed by integer label, or ``None``.

    Returns:
        ``label_names[label]`` when names are available and the label is an
        integer within ``[0, len(label_names))``; otherwise ``str(label)``.
    """
    # No (or empty) name table: nothing to look up.
    if not label_names:
        return str(label)

    # Non-integer labels are never used as indices.
    if not isinstance(label, (int, np.integer)):
        return str(label)

    index = int(label)
    # Negative or too-large indices are reported verbatim rather than wrapped.
    if 0 <= index < len(label_names):
        return label_names[index]
    return str(label)
def embeddings_to_frame(
    embeddings: np.ndarray,
    labels: list[Any],
    label_names: list[str] | None,
    sample_ids: list[str],
    model_name: str,
    dataset_name: str,
    split: str,
) -> pd.DataFrame:
    """Assemble per-sample embeddings and metadata into a single DataFrame.

    Args:
        embeddings: Array iterated along its first axis; each item is cast to
            ``float`` and stored as a plain Python list.
        labels: Raw label per sample.
        label_names: Optional class names used to derive ``label_name``.
        sample_ids: Identifier per sample.
        model_name: Value placed in the ``model_name`` column of every row.
        dataset_name: Value placed in the ``dataset_name`` column of every row.
        split: Value placed in the ``split`` column of every row.

    Returns:
        A DataFrame with the columns ``sample_id``, ``label``, ``label_name``,
        ``dataset_name``, ``split``, ``model_name`` and ``embedding``.
    """
    readable_labels = [label_to_name(raw_label, label_names) for raw_label in labels]
    # Plain lists of Python floats rather than numpy rows in the column.
    vectors = [row.astype(float).tolist() for row in embeddings]

    columns = {
        "sample_id": sample_ids,
        "label": labels,
        "label_name": readable_labels,
        "dataset_name": dataset_name,
        "split": split,
        "model_name": model_name,
        "embedding": vectors,
    }
    return pd.DataFrame(columns)
def save_embeddings(df: pd.DataFrame, output_dir: str | Path) -> Path:
    """Write ``df`` to ``<output_dir>/embeddings.parquet``.

    The output directory (including missing parents) is created if needed.

    Args:
        df: Frame to persist; the index is not written.
        output_dir: Directory that will hold the parquet file.

    Returns:
        Path of the written parquet file.
    """
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    destination = target_dir.joinpath("embeddings.parquet")
    df.to_parquet(destination, index=False)
    return destination
def read_embeddings(path: str | Path) -> pd.DataFrame:
    """Load an embeddings parquet file and check its required columns.

    Args:
        path: Location of the parquet file.

    Returns:
        The loaded DataFrame, unchanged.

    Raises:
        ValueError: If any of ``sample_id``, ``label`` or ``embedding`` is
            absent from the file's columns.
    """
    frame = pd.read_parquet(path)

    absent = {"sample_id", "label", "embedding"}.difference(frame.columns)
    if not absent:
        return frame
    # Sorted so the error message is deterministic.
    raise ValueError(f"Embedding file is missing required columns: {sorted(absent)}")
def embedding_matrix(df: pd.DataFrame) -> np.ndarray:
    """Stack the ``embedding`` column of ``df`` into one float32 array.

    Args:
        df: Frame whose ``embedding`` column holds one vector per row.

    Returns:
        The row-wise stack of all embeddings cast to ``float32``, or an empty
        ``(0, 0)`` float32 array when ``df`` is empty.
    """
    if not df.empty:
        # Entries may be lists or arrays; normalise each one before stacking.
        per_row = df["embedding"].map(np.asarray).to_list()
        stacked = np.vstack(per_row)
        return stacked.astype(np.float32)
    return np.empty((0, 0), dtype=np.float32)
|