Update src/streamlit_app.py
src/streamlit_app.py (CHANGED, +6 −41)
@@ -214,11 +214,6 @@ def build_threshold_graph(H: np.ndarray, top_pct: float = 0.05, use_cosine: bool
     return G
 
 def percolation_stats(G: nx.Graph) -> Dict[str, float]:
-    """
-    Compute percolation observables (φ, #clusters, χ) as in your notebook.
-      φ : fraction of nodes in the Giant Connected Component (GCC)
-      χ : mean size of components excluding GCC
-    """
     n = G.number_of_nodes()
     if n == 0:
         return dict(phi=0.0, num_clusters=0, chi=0.0, largest_component_size=0, component_sizes=[])
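The deleted docstring was the only place the percolation observables were spelled out. For reference, a minimal sketch of what percolation_stats computes, assuming standard networkx connected components; the body past the n == 0 guard sits outside this hunk, so the aggregation below is an illustration, not the file's exact code.

import networkx as nx
import numpy as np
from typing import Dict

def percolation_stats_sketch(G: nx.Graph) -> Dict[str, float]:
    # phi: fraction of nodes in the Giant Connected Component (GCC)
    # chi: mean size of the components excluding the GCC
    n = G.number_of_nodes()
    if n == 0:
        return dict(phi=0.0, num_clusters=0, chi=0.0,
                    largest_component_size=0, component_sizes=[])
    sizes = sorted((len(c) for c in nx.connected_components(G)), reverse=True)
    gcc = sizes[0]
    rest = sizes[1:]
    chi = float(np.mean(rest)) if rest else 0.0
    return dict(phi=gcc / n, num_clusters=len(sizes), chi=chi,
                largest_component_size=gcc, component_sizes=sizes)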
@@ -283,19 +278,12 @@ def fit_global_anchors(pool: np.ndarray, K: int) -> np.ndarray:
 # ====== Model I/O (hidden states) =============================================================
 @dataclass
 class HiddenStatesBundle:
-    """
-    Encapsulates a single input's hidden states and metadata.
-    hidden_layers: list of np.ndarray of shape (T, D), length = num_layers+1 (incl. embedding)
-    tokens       : list of token strings of length T
-    """
     hidden_layers: List[np.ndarray]
     tokens: List[str]
 
 
 def load_qwen(model_name: str, device: str, dtype: torch.dtype):
-    """
-    Load Qwen with output_hidden_states=True. We use AutoTokenizer for broader compatibility.
-    """
+
     print(f"[Load] {model_name} on {device} ({dtype})")
     config = AutoConfig.from_pretrained(model_name, output_hidden_states=True)
     tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
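The rest of load_qwen lies outside this hunk. A minimal sketch of the load pattern the visible lines imply; the model class is an assumption (the diff only shows the config and tokenizer lines), not something this commit confirms.

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

def load_qwen_sketch(model_name: str, device: str, dtype: torch.dtype):
    config = AutoConfig.from_pretrained(model_name, output_hidden_states=True)
    tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    # ASSUMPTION: AutoModelForCausalLM; the actual class is not shown in the diff.
    model = AutoModelForCausalLM.from_pretrained(model_name, config=config, torch_dtype=dtype)
    model.to(device)
    model.eval()
    return model, tok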
@@ -308,10 +296,7 @@ def load_qwen(model_name: str, device: str, dtype: torch.dtype):
 
 @torch.no_grad()
 def extract_hidden_states(model, tokenizer, text: str, max_length: int, device: str) -> HiddenStatesBundle:
-    """
-    Run a single forward pass to collect all hidden states (incl. embedding layer).
-    Returns CPU numpy arrays to keep GPU memory low.
-    """
+
     inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
     out = model(**inputs)
     # Tuple length = num_layers + 1 (embedding)
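For context, a self-contained sketch of what the deleted docstring describes: one forward pass, then every layer's states (embedding included) moved to CPU numpy. The token extraction and the plain-tuple return below are plausible completions, not the hunk's hidden body.

import numpy as np
import torch

@torch.no_grad()
def extract_hidden_states_sketch(model, tokenizer, text: str, max_length: int, device: str):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
    out = model(**inputs)
    # out.hidden_states: tuple of (1, T, D) tensors, length = num_layers + 1 (embedding first)
    hidden_layers = [h[0].float().cpu().numpy() for h in out.hidden_states]
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    return hidden_layers, tokens  # stand-ins for HiddenStatesBundle's two fields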
@@ -322,25 +307,14 @@ def extract_hidden_states(model, tokenizer, text: str, max_length: int, device:
 
 # ====== LoT-style anchors & features ==========================================================
 def fit_global_anchors(all_states_sampled: np.ndarray, K: int, random_state: int = 42) -> np.ndarray:
-    """
-    Fit KMeans cluster centroids on a pooled set of states (from many layers/texts).
-    These centroids are "anchors" (LoT-like choices) to build low-dim features:
-        f(state) = [dist(state, anchor_j)]_{j=1..K}
-    """
+
     print(f"[Anchors] Fitting {K} global centroids on {len(all_states_sampled)} states ...")
     kmeans = KMeans(n_clusters=K, n_init="auto", random_state=random_state)
     kmeans.fit(all_states_sampled)
     return kmeans.cluster_centers_  # (K, D)
 
 def anchor_features(H: np.ndarray, anchors: np.ndarray, temperature: float = 1.0) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
-    """
-    For states H (N,D) and anchors A (K,D):
-      - Compute Euclidean distances to each anchor → Dists (N,K)
-      - Convert to soft probabilities with exp(-Dist/T), normalize row-wise → P (N,K)
-      - Uncertainty = entropy(P)  (cf. LoT Eq. (6))
-      - Top-anchor argmin distance for "consistency"-style comparisons (cf. Eq. (5))
-    Returns (Dists, P, entropy)
-    """
+
     # Distances (N, K)
     dists = pairwise_distances(H, anchors, metric="euclidean")  # (N,K)
     # Soft assignments
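The deleted anchor_features docstring carried the feature math; only the distance line survives in the hunk. A minimal sketch of that math, with a max-shift added for numerical stability (it cancels under row normalization, so the result equals the docstring's exp(-Dist/T) normalized row-wise):

import numpy as np
from sklearn.metrics import pairwise_distances

def anchor_features_sketch(H: np.ndarray, anchors: np.ndarray, temperature: float = 1.0):
    dists = pairwise_distances(H, anchors, metric="euclidean")  # (N, K)
    logits = -dists / temperature
    logits -= logits.max(axis=1, keepdims=True)                 # stabilizes exp(); cancels on normalization
    P = np.exp(logits)
    P /= P.sum(axis=1, keepdims=True)                           # soft assignments, rows sum to 1
    entropy = -(P * np.log(P + 1e-12)).sum(axis=1)              # per-state uncertainty (cf. LoT Eq. (6))
    # The "consistency"-style top anchor (cf. Eq. (5)) is dists.argmin(axis=1).
    return dists, P, entropy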
@@ -361,10 +335,7 @@ def fit_umap_2d(pool: np.ndarray,
                 min_dist: float = 0.05,
                 metric: str = "cosine",
                 random_state: int = 42) -> umap.UMAP:
-    """
-    Fit UMAP once on a diverse pool across layers to preserve orientation.
-    Later layers call .transform() to embed into the SAME 2D space → "MRI stack".
-    """
+
 
     reducer = umap.UMAP(n_components=2, n_neighbors=n_neighbors, min_dist=min_dist,
                         metric=metric, random_state=random_state)
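The fit-once/transform-later contract deleted above is what keeps the per-layer slices aligned in one 2D space. A usage sketch of that pattern, with placeholder shapes (pool size, layer count, and dimensions here are made up):

import numpy as np
import umap

rng = np.random.default_rng(0)
pool = rng.standard_normal((2000, 64)).astype(np.float32)               # pooled states across layers
layers = [rng.standard_normal((128, 64)).astype(np.float32) for _ in range(5)]

reducer = umap.UMAP(n_components=2, n_neighbors=15, min_dist=0.05,
                    metric="cosine", random_state=42)
reducer.fit(pool)                                   # fit ONCE on the diverse pool
xy_layers = [reducer.transform(H) for H in layers]  # every layer lands in the SAME 2D space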
@@ -377,15 +348,11 @@ def fit_umap_3d(all_states: np.ndarray,
                 min_dist: float = 0.05,
                 metric: str = "cosine",
                 random_state: int = 42) -> np.ndarray:
-    """
-    Fit a global 3D UMAP embedding for all states at once (alternative to slice stack).
-    Returns coords_3d (N,3) for the concatenated states passed in.
-    """
+
     reducer = umap.UMAP(n_components=3, n_neighbors=n_neighbors, min_dist=min_dist,
                         metric=metric, random_state=random_state)
     return reducer.fit_transform(all_states)
 
-"""## Define Visualization Function"""
 
 # ====== Visualization ========================================================================
 def plotly_3d_layers(xy_layers: List[np.ndarray],
@@ -531,7 +498,6 @@ def plotly_3d_layers(xy_layers: List[np.ndarray],
     )
     return fig
 
-"""## Building the pipeline"""
 
 def run_pipeline(cfg: Config, model, tok, device, main_text: str, save_artifacts: bool = False):
     seed_everything(42)
@@ -630,7 +596,6 @@ def run_pipeline(cfg: Config, model, tok, device, main_text: str, save_artifacts
 
     return fig, {"percolation": percolation, "tokens": tokens}
 
-"""## This section is for the Web App UI"""
 
 @st.cache_resource(show_spinner=False)
 def get_model_and_tok(model_name: str):