Jgray21 committed (verified)
Commit 062f3e4 · Parent(s): bc94a3e

Update src/streamlit_app.py

Files changed (1): src/streamlit_app.py (+6, -41)
src/streamlit_app.py CHANGED

@@ -214,11 +214,6 @@ def build_threshold_graph(H: np.ndarray, top_pct: float = 0.05, use_cosine: bool
     return G
 
 def percolation_stats(G: nx.Graph) -> Dict[str, float]:
-    """
-    Compute percolation observables (φ, #clusters, χ) as in your notebook.
-    φ : fraction of nodes in the Giant Connected Component (GCC)
-    χ : mean size of components excluding GCC
-    """
     n = G.number_of_nodes()
     if n == 0:
         return dict(phi=0.0, num_clusters=0, chi=0.0, largest_component_size=0, component_sizes=[])
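The removed docstring defined φ as the fraction of nodes in the giant connected component (GCC) and χ as the mean size of the remaining components. For reference, a minimal sketch of that computation, assuming standard networkx; the helper name `percolation_stats_sketch` is invented here:

```python
# Sketch of the percolation observables described in the removed docstring.
import networkx as nx
import numpy as np

def percolation_stats_sketch(G: nx.Graph) -> dict:
    n = G.number_of_nodes()
    if n == 0:
        return dict(phi=0.0, num_clusters=0, chi=0.0)
    sizes = sorted((len(c) for c in nx.connected_components(G)), reverse=True)
    rest = sizes[1:]  # every component except the GCC
    return dict(
        phi=sizes[0] / n,                           # fraction of nodes in the GCC
        num_clusters=len(sizes),                    # total number of components
        chi=float(np.mean(rest)) if rest else 0.0,  # mean size excluding the GCC
    )
```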
@@ -283,19 +278,12 @@ def fit_global_anchors(pool: np.ndarray, K: int) -> np.ndarray:
 # ====== Model I/O (hidden states) =============================================================
 @dataclass
 class HiddenStatesBundle:
-    """
-    Encapsulates a single input's hidden states and metadata.
-    hidden_layers: list of np.ndarray of shape (T, D), length = num_layers+1 (incl. embedding)
-    tokens : list of token strings of length T
-    """
     hidden_layers: List[np.ndarray]
     tokens: List[str]
 
 
 def load_qwen(model_name: str, device: str, dtype: torch.dtype):
-    """
-    Load Qwen with output_hidden_states=True. We use AutoTokenizer for broader compatibility.
-    """
     print(f"[Load] {model_name} on {device} ({dtype})")
     config = AutoConfig.from_pretrained(model_name, output_hidden_states=True)
     tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
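The removed load_qwen docstring noted that the model is loaded with output_hidden_states=True and an AutoTokenizer for broader compatibility. A minimal sketch of that load path, assuming AutoModelForCausalLM; the default model name below is a placeholder, not taken from the app:

```python
# Sketch of the loader pattern the removed docstring describes.
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

def load_qwen_sketch(model_name: str = "Qwen/Qwen2-0.5B",  # placeholder name
                     device: str = "cpu",
                     dtype: torch.dtype = torch.float32):
    # output_hidden_states=True makes every layer's states available later
    config = AutoConfig.from_pretrained(model_name, output_hidden_states=True)
    tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, config=config, torch_dtype=dtype
    ).to(device).eval()
    return model, tok
```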
@@ -308,10 +296,7 @@ def load_qwen(model_name: str, device: str, dtype: torch.dtype):
 
 @torch.no_grad()
 def extract_hidden_states(model, tokenizer, text: str, max_length: int, device: str) -> HiddenStatesBundle:
-    """
-    Run a single forward pass to collect all hidden states (incl. embedding layer).
-    Returns CPU numpy arrays to keep GPU memory low.
-    """
     inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
     out = model(**inputs)
     # Tuple length = num_layers + 1 (embedding)
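The removed docstring described one forward pass that collects every hidden state (embedding layer included) and returns CPU numpy arrays to keep GPU memory low. A sketch of that pattern under the same assumptions; `extract_hidden_states_sketch` is an illustrative name:

```python
# Sketch of the extraction step: one forward pass, states moved to CPU numpy.
import numpy as np
import torch

@torch.no_grad()
def extract_hidden_states_sketch(model, tokenizer, text, max_length, device):
    inputs = tokenizer(text, return_tensors="pt", truncation=True,
                       max_length=max_length).to(device)
    out = model(**inputs)  # requires output_hidden_states=True at load time
    # out.hidden_states has length num_layers + 1 (embedding layer first)
    hidden = [h[0].float().cpu().numpy() for h in out.hidden_states]  # each (T, D)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    return hidden, tokens
```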
@@ -322,25 +307,14 @@ def extract_hidden_states(model, tokenizer, text: str, max_length: int, device:
 
 # ====== LoT-style anchors & features ==========================================================
 def fit_global_anchors(all_states_sampled: np.ndarray, K: int, random_state: int = 42) -> np.ndarray:
-    """
-    Fit KMeans cluster centroids on a pooled set of states (from many layers/texts).
-    These centroids are "anchors" (LoT-like choices) to build low-dim features:
-    f(state) = [dist(state, anchor_j)]_{j=1..K}
-    """
     print(f"[Anchors] Fitting {K} global centroids on {len(all_states_sampled)} states ...")
     kmeans = KMeans(n_clusters=K, n_init="auto", random_state=random_state)
     kmeans.fit(all_states_sampled)
     return kmeans.cluster_centers_  # (K, D)
 
 def anchor_features(H: np.ndarray, anchors: np.ndarray, temperature: float = 1.0) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
-    """
-    For states H (N,D) and anchors A (K,D):
-    - Compute Euclidean distances to each anchor → Dists (N,K)
-    - Convert to soft probabilities with exp(-Dist/T), normalize row-wise → P (N,K)
-    - Uncertainty = entropy(P) (cf. LoT Eq. (6))
-    - Top-anchor argmin distance for "consistency"-style comparisons (cf. Eq. (5))
-    Returns (Dists, P, entropy)
-    """
     # Distances (N, K)
     dists = pairwise_distances(H, anchors, metric="euclidean")  # (N,K)
     # Soft assignments
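The removed anchor_features docstring spelled out the feature construction: Euclidean distances to the K anchors, soft probabilities exp(-Dist/T) normalized row-wise, and the entropy of those probabilities as the uncertainty (cf. LoT Eqs. (5)-(6)). A minimal sketch of that math, with a max-subtraction added here purely for numerical stability:

```python
# Sketch of the soft-assignment features the removed docstring describes.
import numpy as np
from sklearn.metrics import pairwise_distances

def anchor_features_sketch(H: np.ndarray, anchors: np.ndarray, temperature: float = 1.0):
    dists = pairwise_distances(H, anchors, metric="euclidean")  # (N, K)
    logits = -dists / max(temperature, 1e-8)
    logits -= logits.max(axis=1, keepdims=True)      # stability; leaves P unchanged
    P = np.exp(logits)
    P /= P.sum(axis=1, keepdims=True)                # row-wise normalize -> (N, K)
    entropy = -(P * np.log(P + 1e-12)).sum(axis=1)   # per-state uncertainty
    return dists, P, entropy
```

The per-row argmin over `dists` recovers the top anchor used for the consistency-style comparisons the docstring mentioned.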
@@ -361,10 +335,7 @@ def fit_umap_2d(pool: np.ndarray,
                 min_dist: float = 0.05,
                 metric: str = "cosine",
                 random_state: int = 42) -> umap.UMAP:
-    """
-    Fit UMAP once on a diverse pool across layers to preserve orientation.
-    Later layers call .transform() to embed into the SAME 2D space → "MRI stack".
-    """
 
     reducer = umap.UMAP(n_components=2, n_neighbors=n_neighbors, min_dist=min_dist,
                         metric=metric, random_state=random_state)
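The removed docstring explained why UMAP is fit once on a pooled sample and then reused: each layer is embedded via .transform() into the same 2D space, so the per-layer slices stack coherently like an MRI volume. A sketch of that fit-once/transform-later pattern with random stand-in data:

```python
# Sketch of the shared-2D-space pattern; the arrays are stand-in data.
import numpy as np
import umap

pool = np.random.rand(2000, 64).astype(np.float32)         # pooled states, many layers
reducer = umap.UMAP(n_components=2, n_neighbors=15, min_dist=0.05,
                    metric="cosine", random_state=42).fit(pool)

layer_states = np.random.rand(120, 64).astype(np.float32)  # one layer's states (T, D)
xy = reducer.transform(layer_states)                        # (T, 2), same orientation
```

The 3D variant in the next hunk instead calls fit_transform once on all states, trading the fixed per-slice orientation for a single global embedding.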
@@ -377,15 +348,11 @@ def fit_umap_3d(all_states: np.ndarray,
                 min_dist: float = 0.05,
                 metric: str = "cosine",
                 random_state: int = 42) -> np.ndarray:
-    """
-    Fit a global 3D UMAP embedding for all states at once (alternative to slice stack).
-    Returns coords_3d (N,3) for the concatenated states passed in.
-    """
     reducer = umap.UMAP(n_components=3, n_neighbors=n_neighbors, min_dist=min_dist,
                         metric=metric, random_state=random_state)
     return reducer.fit_transform(all_states)
 
-"""## Define Visualization Function"""
 
 # ====== Visualization ========================================================================
 def plotly_3d_layers(xy_layers: List[np.ndarray],
@@ -531,7 +498,6 @@ def plotly_3d_layers(xy_layers: List[np.ndarray],
     )
     return fig
 
-"""## Building the pipeline"""
 
 def run_pipeline(cfg: Config, model, tok, device, main_text: str, save_artifacts: bool = False):
     seed_everything(42)
@@ -630,7 +596,6 @@ def run_pipeline(cfg: Config, model, tok, device, main_text: str, save_artifacts
 
     return fig, {"percolation": percolation, "tokens": tokens}
 
-"""## This section is for the Web App UI"""
 
 @st.cache_resource(show_spinner=False)
 def get_model_and_tok(model_name: str):
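get_model_and_tok is wrapped in st.cache_resource so the expensive model load happens once per process and survives Streamlit reruns. A self-contained sketch of that caching pattern; the body shown here is an assumption for illustration, not the app's actual loader:

```python
# Sketch of the Streamlit resource-caching pattern used by the app's UI.
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

@st.cache_resource(show_spinner=False)
def get_model_and_tok_sketch(model_name: str):
    # Cached: subsequent reruns reuse the same model and tokenizer objects.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device).eval()
    return model, tok
```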
 