Spaces:
Sleeping
Sleeping
File size: 5,163 Bytes
b279884 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 | from __future__ import annotations
import os
import re
from dataclasses import dataclass
from pathlib import Path
import streamlit as st
PROBE_FILENAME_RE = re.compile(
r"^cognitive_map_probe_layer(?P<layer>\d+)_(?P<model_type>[a-z0-9]+)_"
r"(?P<location>pre_reasoning|post_reasoning)_all_(?P<scope>general|size\d+)\.pt$"
)
PERSONA_PROBE_DIR_RE = re.compile(
r"^(?P<probe_kind>[a-z_]+?)(?:_pca(?P<pca>\d+))?_layer(?P<layer>\d+)$"
)
DEFAULT_PROBE_REPO = "project-telos/cognitive_map_probes"
DEFAULT_LOCAL_PROBE_DIR = os.environ.get("PERSONA_PROBES_DIR", "artifacts/probes")
@dataclass(frozen=True)
class ProbeFileMetadata:
filename: str
layer: int | None
model_type: str
location: str | None
scope: str | None
label: str
model_name: str | None = None
attribute_name: str | None = None
def model_probe_dir_name(model_name: str) -> str:
return model_name.replace("/", "__")
def parse_probe_filename(filename: str) -> ProbeFileMetadata:
path = Path(filename)
match = PROBE_FILENAME_RE.match(path.name)
if match:
layer = int(match.group("layer"))
model_type = match.group("model_type")
location = match.group("location")
scope = match.group("scope")
scope_label = scope.replace("size", "size ")
return ProbeFileMetadata(
filename=filename,
layer=layer,
model_type=model_type,
location=location,
scope=scope,
label=(
f"Layer {layer} - {model_type.upper()} - "
f"{location.replace('_', ' ')} - {scope_label}"
),
)
parent_match = PERSONA_PROBE_DIR_RE.match(path.parent.name)
if parent_match and path.name in {"probe.json", "weights.safetensors"}:
layer = int(parent_match.group("layer"))
probe_kind = parent_match.group("probe_kind")
pca = parent_match.group("pca")
scope = f"pca{pca}" if pca else None
attribute = path.parent.parent.name or "attribute"
model_name = path.parts[0].replace("__", "/") if len(path.parts) >= 5 else None
label = f"{attribute} - layer {layer} - {probe_kind}"
if pca:
label += f" (pca{pca})"
return ProbeFileMetadata(
filename=filename,
layer=layer,
model_type=probe_kind,
location=None,
scope=scope,
label=label,
model_name=model_name,
attribute_name=attribute,
)
return ProbeFileMetadata(
filename=filename,
layer=None,
model_type="unknown",
location=None,
scope=None,
label=path.stem.replace("_", " "),
)
@st.cache_data(show_spinner=False, ttl=300)
def list_probe_files(repo_id: str) -> list[str]:
from huggingface_hub import list_repo_files
return _dedupe_probe_entries(list_repo_files(repo_id, repo_type="model"))
@st.cache_data(show_spinner=False, ttl=30)
def list_local_probe_files(root_dir: str) -> list[str]:
root = Path(root_dir).expanduser()
if not root.is_dir():
return []
files = _dedupe_probe_entries(
[
str(path.relative_to(root))
for path in root.rglob("*")
if path.is_file()
and path.name in {"probe.pt", "probe.json", "weights.safetensors"}
]
)
return sorted(files, key=_probe_sort_key)
@st.cache_data(show_spinner=False, ttl=300)
def download_probe_file(repo_id: str, filename: str) -> str:
from huggingface_hub import hf_hub_download
return hf_hub_download(repo_id, filename, repo_type="model")
@st.cache_data(show_spinner=False, ttl=300)
def download_probe_json_and_weights(repo_id: str, filename: str) -> tuple[str, str]:
from huggingface_hub import hf_hub_download
metadata_path = hf_hub_download(repo_id, filename, repo_type="model")
weights_name = str(Path(filename).with_name("weights.safetensors"))
weights_path = hf_hub_download(repo_id, weights_name, repo_type="model")
return metadata_path, weights_path
def _probe_sort_key(filename: str) -> tuple[str, str, int, str]:
metadata = parse_probe_filename(filename)
return (
metadata.model_name or "",
metadata.attribute_name or "",
metadata.layer if metadata.layer is not None else 10**9,
filename,
)
def _dedupe_probe_entries(files: list[str]) -> list[str]:
by_dir: dict[str, set[str]] = {}
standalone: list[str] = []
for filename in files:
path = Path(filename)
if path.name in {"probe.pt", "probe.json", "weights.safetensors"}:
by_dir.setdefault(str(path.parent), set()).add(path.name)
elif filename.endswith(".pt"):
standalone.append(filename)
entries = list(standalone)
for directory, names in by_dir.items():
selected = (
"probe.json"
if "probe.json" in names
else "probe.pt"
if "probe.pt" in names
else "weights.safetensors"
)
entries.append(str(Path(directory) / selected))
return sorted(entries, key=_probe_sort_key)
|