# File size: 2,106 Bytes (revision 275cef9)
from __future__ import annotations
import json, os
import numpy as np
def load_and_parse(hf_token: str | None = None) -> dict:
    """Fetch viz_data.json from the Hugging Face dataset repo and parse it.

    Args:
        hf_token: Access token passed to the download call. When it is falsy,
            the value of the ``HF_TOKEN`` environment variable is used instead
            (which may itself be unset, giving ``None``).

    Returns:
        Whatever ``load_from_path`` returns for the downloaded file.
    """
    from huggingface_hub import hf_hub_download

    # An explicit argument wins; otherwise fall back to the environment.
    auth_token = hf_token or os.environ.get("HF_TOKEN")
    local_file = hf_hub_download(
        repo_id="build-small-hackathon/ofa-viz-data",
        filename="viz_data.json",
        repo_type="dataset",
        token=auth_token,
    )
    return load_from_path(local_file)
def load_from_path(path: str) -> dict:
    """Parse a local viz_data.json file.

    Used for testing and local runs, and by ``load_and_parse`` once the file
    has been downloaded.

    Args:
        path: Filesystem path of the JSON file to read.

    Returns:
        The dict built by ``_parse`` from the decoded JSON document.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    # Decode as UTF-8 explicitly rather than relying on the platform's locale
    # default, which is not UTF-8 everywhere.
    with open(path, encoding="utf-8") as f:
        raw = json.load(f)
    return _parse(raw)
def _parse(raw: dict) -> dict:
    """Turn the decoded viz_data.json payload into the viz data dict.

    Args:
        raw: Decoded JSON object. Must contain an ``"embeddings"`` mapping of
            model name to a list of embedding rows; ``"cka"`` and ``"curves"``
            are optional and default to empty dicts.

    Returns:
        A dict with the keys ``stacked`` (all models' rows concatenated along
        axis 0 as float32), ``labels`` (one model name per stacked row),
        ``model_names`` (in payload order), ``teacher_names`` (every model
        name except ``"student"``), ``cka`` and ``curves``.
    """
    per_model = raw["embeddings"]
    names: list[str] = list(per_model)  # dict iteration keeps payload order

    blocks = []
    row_labels = []
    for key, vectors in per_model.items():
        block = np.array(vectors, dtype=np.float32)
        blocks.append(block)
        # One label per row so labels line up with the stacked matrix.
        row_labels += [key] * len(block)

    if blocks:
        stacked = np.concatenate(blocks, axis=0)
    else:
        # No models at all: return a zero-row placeholder matrix.
        stacked = np.empty((0, 1), dtype=np.float32)

    teachers = [key for key in names if key != "student"]

    return {
        "stacked": stacked,
        "labels": row_labels,
        "model_names": names,
        "teacher_names": teachers,
        "cka": raw.get("cka", {}),
        "curves": raw.get("curves", {}),
    }
def fit_umap3d(stacked: np.ndarray, n_neighbors: int = 15):
    """Fit a 3D UMAP on stacked embeddings and return the fitted reducer.

    Args:
        stacked: Array whose first axis indexes samples.
        n_neighbors: Requested neighbourhood size. It is capped at
            ``n_samples - 1`` before being passed to UMAP.

    Returns:
        The ``umap.UMAP`` instance (``n_components=3``, ``random_state=42``)
        after ``fit(stacked)`` has been called on it.

    Raises:
        ValueError: If ``stacked`` has fewer than 2 rows.
    """
    n_samples = stacked.shape[0]
    # Validate before importing umap, so bad input is reported immediately and
    # the error does not depend on the third-party package being importable.
    if n_samples < 2:
        raise ValueError(f"UMAP requires at least 2 samples, got {n_samples}")

    import umap as umap_lib

    n = min(n_neighbors, n_samples - 1)
    # NOTE(review): with exactly 2 samples this gives n_neighbors=1, which
    # UMAP may reject; confirm the real minimum sample count.
    reducer = umap_lib.UMAP(n_components=3, n_neighbors=n, random_state=42, verbose=False)
    reducer.fit(stacked)
    return reducer
def make_empty_viz() -> dict:
    """Build a placeholder viz dict for when viz_data.json is unavailable.

    Returns:
        A dict with the same keys that ``_parse`` produces: ``stacked`` is a
        float32 array of shape ``(0, 1)``, the three name/label entries are
        empty lists, and ``cka`` / ``curves`` are empty dicts. Every call
        returns fresh container objects.
    """
    empty_matrix = np.empty((0, 1), dtype=np.float32)
    placeholder: dict = {"stacked": empty_matrix}
    # Each comprehension creates a distinct empty container per key.
    placeholder.update({key: [] for key in ("labels", "model_names", "teacher_names")})
    placeholder.update({key: {} for key in ("cka", "curves")})
    return placeholder
|