from __future__ import annotations import os import re from dataclasses import dataclass from pathlib import Path import streamlit as st PROBE_FILENAME_RE = re.compile( r"^cognitive_map_probe_layer(?P\d+)_(?P[a-z0-9]+)_" r"(?Ppre_reasoning|post_reasoning)_all_(?Pgeneral|size\d+)\.pt$" ) PERSONA_PROBE_DIR_RE = re.compile( r"^(?P[a-z_]+?)(?:_pca(?P\d+))?_layer(?P\d+)$" ) DEFAULT_PROBE_REPO = "project-telos/cognitive_map_probes" DEFAULT_LOCAL_PROBE_DIR = os.environ.get("PERSONA_PROBES_DIR", "artifacts/probes") @dataclass(frozen=True) class ProbeFileMetadata: filename: str layer: int | None model_type: str location: str | None scope: str | None label: str model_name: str | None = None attribute_name: str | None = None def model_probe_dir_name(model_name: str) -> str: return model_name.replace("/", "__") def parse_probe_filename(filename: str) -> ProbeFileMetadata: path = Path(filename) match = PROBE_FILENAME_RE.match(path.name) if match: layer = int(match.group("layer")) model_type = match.group("model_type") location = match.group("location") scope = match.group("scope") scope_label = scope.replace("size", "size ") return ProbeFileMetadata( filename=filename, layer=layer, model_type=model_type, location=location, scope=scope, label=( f"Layer {layer} - {model_type.upper()} - " f"{location.replace('_', ' ')} - {scope_label}" ), ) parent_match = PERSONA_PROBE_DIR_RE.match(path.parent.name) if parent_match and path.name in {"probe.json", "weights.safetensors"}: layer = int(parent_match.group("layer")) probe_kind = parent_match.group("probe_kind") pca = parent_match.group("pca") scope = f"pca{pca}" if pca else None attribute = path.parent.parent.name or "attribute" model_name = path.parts[0].replace("__", "/") if len(path.parts) >= 5 else None label = f"{attribute} - layer {layer} - {probe_kind}" if pca: label += f" (pca{pca})" return ProbeFileMetadata( filename=filename, layer=layer, model_type=probe_kind, location=None, scope=scope, label=label, model_name=model_name, attribute_name=attribute, ) return ProbeFileMetadata( filename=filename, layer=None, model_type="unknown", location=None, scope=None, label=path.stem.replace("_", " "), ) @st.cache_data(show_spinner=False, ttl=300) def list_probe_files(repo_id: str) -> list[str]: from huggingface_hub import list_repo_files return _dedupe_probe_entries(list_repo_files(repo_id, repo_type="model")) @st.cache_data(show_spinner=False, ttl=30) def list_local_probe_files(root_dir: str) -> list[str]: root = Path(root_dir).expanduser() if not root.is_dir(): return [] files = _dedupe_probe_entries( [ str(path.relative_to(root)) for path in root.rglob("*") if path.is_file() and path.name in {"probe.pt", "probe.json", "weights.safetensors"} ] ) return sorted(files, key=_probe_sort_key) @st.cache_data(show_spinner=False, ttl=300) def download_probe_file(repo_id: str, filename: str) -> str: from huggingface_hub import hf_hub_download return hf_hub_download(repo_id, filename, repo_type="model") @st.cache_data(show_spinner=False, ttl=300) def download_probe_json_and_weights(repo_id: str, filename: str) -> tuple[str, str]: from huggingface_hub import hf_hub_download metadata_path = hf_hub_download(repo_id, filename, repo_type="model") weights_name = str(Path(filename).with_name("weights.safetensors")) weights_path = hf_hub_download(repo_id, weights_name, repo_type="model") return metadata_path, weights_path def _probe_sort_key(filename: str) -> tuple[str, str, int, str]: metadata = parse_probe_filename(filename) return ( metadata.model_name or "", metadata.attribute_name or "", metadata.layer if metadata.layer is not None else 10**9, filename, ) def _dedupe_probe_entries(files: list[str]) -> list[str]: by_dir: dict[str, set[str]] = {} standalone: list[str] = [] for filename in files: path = Path(filename) if path.name in {"probe.pt", "probe.json", "weights.safetensors"}: by_dir.setdefault(str(path.parent), set()).add(path.name) elif filename.endswith(".pt"): standalone.append(filename) entries = list(standalone) for directory, names in by_dir.items(): selected = ( "probe.json" if "probe.json" in names else "probe.pt" if "probe.pt" in names else "weights.safetensors" ) entries.append(str(Path(directory) / selected)) return sorted(entries, key=_probe_sort_key)