Jgray21 committed on
Commit bc94a3e · verified · 1 Parent(s): 791dffa

Update src/streamlit_app.py

Files changed (1):
  1. src/streamlit_app.py +353 -452
src/streamlit_app.py CHANGED
@@ -4,33 +4,27 @@ import json
  import warnings
  from dataclasses import dataclass, asdict
  from typing import Dict, List, Tuple, Optional

  import numpy as np
  import pandas as pd
- 
  import torch
  from torch import nn
- 
  import networkx as nx
  import streamlit as st

- # Transformers: Qwen tokenizer can be AutoTokenizer if Qwen2Tokenizer not present
  from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
- 
- # Dimensionality reduction
  import umap
- from umap import UMAP
- 
- # Neighbors & clustering
  from sklearn.neighbors import NearestNeighbors, KernelDensity
  from sklearn.cluster import KMeans, DBSCAN
- from sklearn.decomposition import PCA
  from sklearn.metrics import pairwise_distances
- 
- # Plotly for interactive 3D
  import plotly.graph_objects as go

- import hashlib

  # Optional libs (use if present)
  try:
@@ -51,28 +45,20 @@ try:
  HAS_PYVISTA = True
  except Exception:
  HAS_PYVISTA = False
- 
- from scipy.linalg import orthogonal_procrustes # For optional per-layer orientation alignment
- 
- # ====== 1. Configuration =========================================================================
  @dataclass
  class Config:
  # Model
  model_name: str = "Qwen/Qwen1.5-1.8B"
- ### device: str = "cuda" if torch.cuda.is_available() else "cpu"
- ### dtype: torch.dtype = torch.float16 if torch.cuda.is_available() else torch.float32
- 
- # Tokenization / generation
- max_length: int = 64 # truncate inputs for speed/memory

  # Data
- corpus: List[str] = None # set below
- # If None, uses DEFAULT_CORPUS defined below

- # Graph building
- graph_mode: str = "threshold" # {"knn", "threshold"}
- knn_k: int = 8 # neighbors per token (used if graph_mode="knn")
- sim_threshold: float = 0.60 # used if graph_mode="threshold"
  use_cosine: bool = True

  # Anchors / LoT-style features (global)
@@ -84,104 +70,148 @@ class Config:
  n_clusters_kmeans: int = 6 # fallback for kmeans
  hdbscan_min_cluster_size: int = 4

- # DR / embeddings
  umap_n_neighbors: int = 30
  umap_min_dist: float = 0.05
- umap_metric: str = "cosine" # hidden states are directional → cosine works well
- use_global_3d_umap: bool = False # if True, compute a single 3D manifold for all states
- 
- # Pooling for UMAP fit
  fit_pool_per_layer: int = 512 # number of states sampled per layer to fit UMAP

- # Volume grid (MRI view)
- grid_res: int = 128 # voxel resolution in x/y; z = num_layers
- kde_bandwidth: float = 0.15 # KDE bandwidth in manifold space (if using KDE)
- use_hist2d: bool = True # if True, use histogram2d instead of KDE for speed

  # Output
  out_dir: str = "qwen_mri3d_outputs"
  plotly_html: str = "qwen_layers_3d.html"
- volume_npz: str = "qwen_density_volume.npz" # saved if PyVista isn't available
- volume_screenshot: str = "qwen_volume.png" # if PyVista is available
- 
- def validate(self):
- if self.graph_mode not in {"knn", "threshold"}:
- raise ValueError("graph_mode must be 'knn' or 'threshold'")
- if self.knn_k < 2:
- raise ValueError("knn_k must be >= 2")
- if self.anchor_k < 2:
- raise ValueError("anchor_k must be >= 2")
- if self.anchor_temp <= 0:
- raise ValueError("anchor_temp must be > 0")

  # Default corpus (small and diverse; adjust freely)
  DEFAULT_CORPUS = [
- "The cat sat on the mat and watched.",
- "Machine learning models process data using neural networks.",
- "Climate change affects ecosystems around the world.",
- "Quantum computers use superposition for parallel computation.",
- "The universe contains billions of galaxies.",
- "Artificial intelligence transforms how we work.",
- "DNA stores genetic information in cells.",
- "Ocean currents regulate Earth's climate system.",
- "Photosynthesis converts sunlight into chemical energy.",
- "Blockchain technology enables decentralized systems."
  ]

- # ====== 2. Utilities =============================================================================
  def seed_everything(seed: int = 42):
- """Determinism for reproducibility in layouts/UMAP/kmeans."""
  np.random.seed(seed)
  torch.manual_seed(seed)

- 
  def cosine_similarity_matrix(X: np.ndarray) -> np.ndarray:
- """Compute pairwise cosine similarity for rows of X."""
- # X: (N, D)
  norms = np.linalg.norm(X, axis=1, keepdims=True) + 1e-8
  Xn = X / norms
  return Xn @ Xn.T
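A quick sanity check of this helper (a sketch, not part of the commit; values chosen so the expected similarities are easy to verify by hand):

import numpy as np

def cosine_similarity_matrix(X: np.ndarray) -> np.ndarray:
    norms = np.linalg.norm(X, axis=1, keepdims=True) + 1e-8
    Xn = X / norms
    return Xn @ Xn.T

X = np.array([[1.0, 0.0], [0.0, 2.0], [3.0, 3.0]])
S = cosine_similarity_matrix(X)
assert np.allclose(np.diag(S), 1.0, atol=1e-6)  # each row has similarity ~1 with itself
assert abs(S[0, 1]) < 1e-6                      # orthogonal rows -> similarity ~0
assert np.isclose(S[0, 2], np.cos(np.pi / 4))   # rows 45 degrees apart -> cos(45°) ≈ 0.707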
 
- def build_knn_graph(coords: np.ndarray, k: int, metric: str = "cosine") -> nx.Graph:
  """
- Build an undirected kNN graph for the points in coords.
- coords: (N, D)
  """
- nbrs = NearestNeighbors(n_neighbors=min(k+1, len(coords)), metric=metric) # +1 to include self
  nbrs.fit(coords)
  distances, indices = nbrs.kneighbors(coords)
- 
  G = nx.Graph()
  G.add_nodes_from(range(len(coords)))
- # Connect i to its top-k neighbors (skip index 0 which is itself)
  for i in range(len(coords)):
- for j in indices[i, 1:]: # skip self
  G.add_edge(int(i), int(j))
  return G

- 
- def build_threshold_graph(H: np.ndarray, threshold: float, use_cosine: bool = True) -> nx.Graph:
- """
- Build graph by thresholding pairwise similarities in the original hidden-state space.
- H: (N, D) hidden states for a single layer
- """
  if use_cosine:
  S = cosine_similarity_matrix(H)
  else:
- S = H @ H.T # dot product

  N = S.shape[0]
  G = nx.Graph()
  G.add_nodes_from(range(N))
- for i in range(N):
- for j in range(i + 1, N):
- if S[i, j] > threshold:
- G.add_edge(i, j, weight=float(S[i, j]))
- return G
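The two graph modes behave differently: kNN guarantees every token at least k edges, while a fixed similarity threshold lets edge counts vary per layer. A minimal side-by-side sketch (illustrative only, not part of the commit):

import numpy as np
import networkx as nx
from sklearn.neighbors import NearestNeighbors

pts = np.random.RandomState(0).randn(20, 8)

# kNN mode: degree is bounded below by k regardless of absolute similarity
nbrs = NearestNeighbors(n_neighbors=5, metric="cosine").fit(pts)
_, idx = nbrs.kneighbors(pts)
G_knn = nx.Graph([(int(i), int(j)) for i in range(len(pts)) for j in idx[i, 1:]])

# Threshold mode: edge count depends on how similar the points actually are
Xn = pts / (np.linalg.norm(pts, axis=1, keepdims=True) + 1e-8)
S = Xn @ Xn.T
G_thr = nx.Graph([(i, j) for i in range(20) for j in range(i + 1, 20) if S[i, j] > 0.6])

print(G_knn.number_of_edges(), G_thr.number_of_edges())  # kNN stays dense; threshold may go sparse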
 
  def percolation_stats(G: nx.Graph) -> Dict[str, float]:
  """
@@ -210,94 +240,47 @@ def percolation_stats(G: nx.Graph) -> Dict[str, float]:
  largest_component_size=largest,
  component_sizes=sorted(sizes, reverse=True))

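The body of percolation_stats is elided from this hunk; the following is a plausible reconstruction consistent with the fields visible in the return statement above (an assumption — the phi/chi naming is hypothetical, not confirmed by the diff):

import networkx as nx
import numpy as np

def percolation_stats_sketch(G: nx.Graph) -> dict:
    sizes = sorted((len(c) for c in nx.connected_components(G)), reverse=True)
    N = max(G.number_of_nodes(), 1)
    largest = sizes[0] if sizes else 0
    phi = largest / N                              # order parameter: giant-component fraction
    chi = float(np.sum(np.square(sizes[1:]))) / N  # susceptibility-style moment, giant excluded
    return dict(phi=phi, chi=chi, n_components=len(sizes),
                largest_component_size=largest,
                component_sizes=sizes)             # last two fields match the visible lines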
 
 
  def leiden_communities(G: nx.Graph) -> np.ndarray:
- """
- Community detection using Leiden (igraph), if available.
- Returns an array of cluster ids for nodes 0..N-1.
- """
- if not HAS_IGRAPH_LEIDEN:
- raise RuntimeError("igraph+leidenalg not available")
- 
- # Convert nx → igraph
  mapping = {n: i for i, n in enumerate(G.nodes())}
  edges = [(mapping[u], mapping[v]) for u, v in G.edges()]
  ig_g = ig.Graph(n=len(mapping), edges=edges, directed=False)
- part = la.find_partition(ig_g, la.RBConfigurationVertexPartition) # robust default
  labels = np.zeros(len(mapping), dtype=int)
  for cid, comm in enumerate(part):
- for node in comm:
- labels[node] = cid
  return labels

- def cluster_layer(features: np.ndarray,
- G: Optional[nx.Graph],
- method: str,
- n_clusters_kmeans: int = 6,
- hdbscan_min_cluster_size: int = 4) -> np.ndarray:
- """
- Cluster layer states to get cluster labels.
- - If Leiden: requires G (graph) and igraph/leidenalg
- - If HDBSCAN: density-based clustering in feature space
- - If DBSCAN: fallback density-based (scikit-learn)
- - If KMeans: fallback centroid clustering
- """
- method = method.lower()
- N = len(features)
- 
- if method == "auto":
- # Prefer Leiden (graph) → HDBSCAN → KMeans
- if HAS_IGRAPH_LEIDEN and G is not None and G.number_of_edges() > 0:
- return leiden_communities(G)
- elif HAS_HDBSCAN and N >= 5:
- clusterer = hdbscan.HDBSCAN(min_cluster_size=hdbscan_min_cluster_size,
- metric='euclidean')
- labels = clusterer.fit_predict(features)
- # HDBSCAN: -1 = noise. Keep as its own "noise" cluster id or remap
- return labels
- else:
- km = KMeans(n_clusters=min(n_clusters_kmeans, max(2, N // 3)),
- n_init="auto", random_state=42)
- return km.fit_predict(features)
- 
- if method == "leiden":
- if G is None or not HAS_IGRAPH_LEIDEN:
- raise RuntimeError("Leiden requires a graph and igraph+leidenalg.")
- return leiden_communities(G)
- 
- if method == "hdbscan":
- if not HAS_HDBSCAN:
- raise RuntimeError("hdbscan not installed")
- clusterer = hdbscan.HDBSCAN(min_cluster_size=hdbscan_min_cluster_size, metric='euclidean')
- return clusterer.fit_predict(features)
- 
- if method == "dbscan":
- db = DBSCAN(eps=0.5, min_samples=4, metric='euclidean')
- return db.fit_predict(features)
- 
- if method == "kmeans":
- km = KMeans(n_clusters=min(n_clusters_kmeans, max(2, N // 3)),
- n_init="auto", random_state=42)
- return km.fit_predict(features)
- 
- raise ValueError(f"Unknown cluster method: {method}")
- 
- 
- def orthogonal_align(A_ref: np.ndarray, B: np.ndarray) -> np.ndarray:
- """
- Align B to A_ref by an orthogonal rotation (Procrustes),
- preserving geometry but removing arbitrary orientation flips.
- """
- R, _ = orthogonal_procrustes(B - B.mean(0), A_ref - A_ref.mean(0))
- return (B - B.mean(0)) @ R + A_ref.mean(0)
- 
- def entropy_from_probs(p: np.ndarray, eps: float = 1e-12) -> np.ndarray:
- """Shannon entropy for each row; p is (N, K) with rows summing ~1."""
- return -np.sum(p * np.log(p + eps), axis=1)
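Worked example for entropy_from_probs (a sketch, not part of the commit): a uniform row hits the maximum ln K, a confident row stays near zero.

import numpy as np

p = np.array([[0.25, 0.25, 0.25, 0.25],   # maximally uncertain over 4 anchors
              [0.97, 0.01, 0.01, 0.01]])  # confident in anchor 0
H = -np.sum(p * np.log(p + 1e-12), axis=1)
# H[0] = ln 4 ≈ 1.386 (the maximum for K=4); H[1] ≈ 0.168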

- # ====== 3. Model I/O (hidden states) =============================================================
  @dataclass
  class HiddenStatesBundle:
  """
@@ -336,7 +319,8 @@ def extract_hidden_states(model, tokenizer, text: str, max_length: int, device:
  tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
  return HiddenStatesBundle(hidden_layers=hs, tokens=tokens)

- # ====== 4. LoT-style anchors & features ==========================================================
  def fit_global_anchors(all_states_sampled: np.ndarray, K: int, random_state: int = 42) -> np.ndarray:
  """
  Fit KMeans cluster centroids on a pooled set of states (from many layers/texts).
@@ -348,7 +332,6 @@ def fit_global_anchors(all_states_sampled: np.ndarray, K: int, random_state: int
  kmeans.fit(all_states_sampled)
  return kmeans.cluster_centers_ # (K, D)

- 
  def anchor_features(H: np.ndarray, anchors: np.ndarray, temperature: float = 1.0) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
  """
  For states H (N,D) and anchors A (K,D):
@@ -367,10 +350,12 @@ def anchor_features(H: np.ndarray, anchors: np.ndarray, temperature: float = 1.0
  P = np.exp(logits)
  P /= P.sum(axis=1, keepdims=True) + 1e-12
  # Uncertainty (entropy)
- H_unc = entropy_from_probs(P)
  return dists, P, H_unc

- # ====== 5. Dimensionality reduction / embeddings ================================================
  def fit_umap_2d(pool: np.ndarray,
  n_neighbors: int = 30,
  min_dist: float = 0.05,
@@ -386,63 +371,6 @@ def fit_umap_2d(pool: np.ndarray,
  reducer.fit(pool)
  return reducer

- def _corpus_fingerprint(texts, max_items=5, max_chars=4000) -> str:
- """Stable key so cache invalidates if DEFAULT_CORPUS changes."""
- joined = "\n".join(texts[:max_items])
- joined = joined[:max_chars]
- return hashlib.sha256(joined.encode("utf-8")).hexdigest()
- 
- @st.cache_data(show_spinner=False)
- def get_pool_artifacts(
- model_name: str,
- max_length: int,
- anchor_k: int,
- anchor_temp: float, # not strictly needed for fitting anchors, but included if you want cache keys aligned
- umap_n_neighbors: int,
- umap_min_dist: float,
- umap_metric: str,
- fit_pool_per_layer: int,
- corpus_hash: str,
- ):
- """
- Cached: build pooled hidden states on DEFAULT_CORPUS, fit anchors and a UMAP reducer once.
- Returns:
- anchors: (K, D) np.ndarray
- reducer2d: fitted UMAP reducer object (must be pickleable; umap-learn's UMAP is)
- """
- # Use cached model loader (resource cache)
- model, tok, device, dtype = get_model_and_tok(model_name)
- 
- texts = DEFAULT_CORPUS # pooled set for stability
- 
- pool_states = []
- for t in texts[: min(5, len(texts))]:
- b = extract_hidden_states(model, tok, t, max_length, device)
- for H in b.hidden_layers:
- T = len(H)
- take = min(fit_pool_per_layer, T)
- if take <= 0:
- continue
- idx = np.random.choice(T, size=take, replace=False)
- pool_states.append(H[idx])
- 
- if not pool_states:
- # fallback: this should rarely happen
- raise RuntimeError("Pool construction produced no states.")
- 
- pool_states = np.vstack(pool_states)
- 
- anchors = fit_global_anchors(pool_states, anchor_k)
- 
- reducer2d = fit_umap_2d(
- pool_states,
- n_neighbors=umap_n_neighbors,
- min_dist=umap_min_dist,
- metric=umap_metric,
- )
- 
- return anchors, reducer2d
- 

  def fit_umap_3d(all_states: np.ndarray,
  n_neighbors: int = 30,
@@ -457,313 +385,283 @@ def fit_umap_3d(all_states: np.ndarray,
  metric=metric, random_state=random_state)
  return reducer.fit_transform(all_states)

- # ====== 6. Volume construction (MRI) ============================================================
- def stack_density_volume(xy_by_layer: List[np.ndarray],
- grid_res: int,
- use_hist2d: bool = True,
- kde_bandwidth: float = 0.15) -> np.ndarray:
- """
- Construct a 3D volume by estimating 2D density on the (x,y) manifold per layer (slice).
- - If use_hist2d: fast uniform binning into grid_res x grid_res
- - Else: KDE (slower but smoother)
- Returns volume of shape (grid_res, grid_res, L) where L = #layers.
- """
- L = len(xy_by_layer)
- vol = np.zeros((grid_res, grid_res, L), dtype=np.float32)
- 
- # Determine global bounds across layers to keep axes consistent
- all_xy = np.vstack([xy for xy in xy_by_layer if len(xy) > 0]) if L > 0 else np.zeros((0, 2))
- if len(all_xy) == 0:
- return vol
- x_min, y_min = all_xy.min(axis=0)
- x_max, y_max = all_xy.max(axis=0)
- # Slight padding
- pad = 1e-6
- x_edges = np.linspace(x_min - pad, x_max + pad, grid_res + 1)
- y_edges = np.linspace(y_min - pad, y_max + pad, grid_res + 1)
- 
- for l, XY in enumerate(xy_by_layer):
- if len(XY) == 0:
- continue
- if use_hist2d:
- H, _, _ = np.histogram2d(XY[:, 0], XY[:, 1], bins=[x_edges, y_edges], density=False)
- vol[:, :, l] = H.T # histogram2d returns [x_bins, y_bins] → transpose to align
- else:
- kde = KernelDensity(bandwidth=kde_bandwidth, kernel="gaussian")
- kde.fit(XY)
- # Evaluate KDE on grid centers
- xs = 0.5 * (x_edges[:-1] + x_edges[1:])
- ys = 0.5 * (y_edges[:-1] + y_edges[1:])
- xx, yy = np.meshgrid(xs, ys, indexing='xy')
- grid_points = np.column_stack([xx.ravel(), yy.ravel()])
- log_dens = kde.score_samples(grid_points)
- dens = np.exp(log_dens).reshape(grid_res, grid_res)
- vol[:, :, l] = dens
- 
- # Normalize volume to [0,1] for rendering convenience
- if vol.max() > 0:
- vol = vol / vol.max()
- return vol
- 
- 
- def render_volume_with_pyvista(volume: np.ndarray,
- out_png: str,
- opacity="sigmoid") -> None:
- """
- Visualize the 3D volume using PyVista/VTK (if installed); save a screenshot.
- """
- if not HAS_PYVISTA:
- raise RuntimeError("PyVista is not installed; cannot render volume.")
- pl = pv.Plotter()
- # Wrap NumPy array as a VTK image data; PyVista expects z as the 3rd axis
- vol_vtk = pv.wrap(volume)
- pl.add_volume(vol_vtk, opacity=opacity, shade=True)
- pl.show(screenshot=out_png) # headless environments will still save a screenshot (if offscreen support)
- 
- # ====== 7. 3D Plotly visualization ==============================================================
  def plotly_3d_layers(xy_layers: List[np.ndarray],
  layer_tokens: List[List[str]],
  layer_cluster_labels: List[np.ndarray],
  layer_uncertainty: List[np.ndarray],
  layer_graphs: List[nx.Graph],
- connect_token_trajectories: bool = True,
- title: str = "Qwen: 3D Cluster Formation (UMAP2D + Layer as Z)") -> go.Figure:
- """
- Build an interactive 3D Plotly figure:
- - Nodes per layer at (x, y, z=layer)
- - Edge segments (kNN or threshold graph) per layer
- - Trajectory lines: connect same token index across consecutive layers (optional)
- - Color nodes by cluster label; hover shows token & uncertainty
- """
  fig_data = []

- # Build a color per layer node trace
- for l, (xy, tokens, labels, unc, G) in enumerate(zip(xy_layers, layer_tokens, layer_cluster_labels, layer_uncertainty, layer_graphs)):
- if len(xy) == 0:
- continue
  x, y = xy[:, 0], xy[:, 1]
  z = np.full_like(x, l, dtype=float)

- # --- Nodes
- node_text = [f"layer={l} | idx={i}<br>token={tokens[i]}<br>cluster={int(labels[i])}<br>uncertainty={unc[i]:.3f}"
- for i in range(len(tokens))]
  node_trace = go.Scatter3d(
  x=x, y=y, z=z,
  mode='markers',
  name=f"Layer {l}",
  marker=dict(
- size=4,
- opacity=0.7,
- color=labels, # cluster ID → color scale
- colorscale='Viridis',
- showscale=(l == 0) # show scale once
  ),
  text=node_text,
  hovertemplate="%{text}<extra></extra>"
  )
  fig_data.append(node_trace)

- # --- Intra-layer edges (kNN or threshold)
  if G is not None and G.number_of_edges() > 0:
  edge_x, edge_y, edge_z = [], [], []
  for u, v in G.edges():
  edge_x += [x[u], x[v], None]
  edge_y += [y[u], y[v], None]
  edge_z += [z[u], z[v], None]
  edge_trace = go.Scatter3d(
  x=edge_x, y=edge_y, z=edge_z,
  mode='lines',
- line=dict(width=1),
- opacity=0.30,
- name=f"Edges L{l}"
  )
  fig_data.append(edge_trace)

- # --- Trajectories: connect same token index across layers
- if connect_token_trajectories:
- # Only meaningful if tokenization length T is constant across layers (it is)
- # We'll draw faint polylines for each position i across l=0..L-1
- L = len(xy_layers)
- if L > 1:
- T = min(len(xy_layers[l]) for l in range(L))
- for i in range(T):
- xs = [xy_layers[l][i, 0] for l in range(L)]
- ys = [xy_layers[l][i, 1] for l in range(L)]
- zs = list(range(L))
- traj = go.Scatter3d(
- x=xs, y=ys, z=zs,
- mode='lines',
- line=dict(width=1),
- opacity=0.15,
- name=f"traj_{i}",
- hoverinfo='skip'
  )
- fig_data.append(traj)

  fig = go.Figure(data=fig_data)
  fig.update_layout(
- title=title,
  scene=dict(
  xaxis_title="UMAP X",
  yaxis_title="UMAP Y",
- zaxis_title="Layer (depth)"
  ),
  height=900,
- showlegend=False
  )
  return fig

- # ====== 8. Orchestration ========================================================================
- def run_pipeline(cfg: Config, model, tok, device, main_text: str, save_artifacts: bool = False):
  seed_everything(42)

- # 8.2 Collect hidden states for one representative text (detailed viz) + for pool
- # You can extend to many texts; we keep a single text for clarity & speed.
- texts = cfg.corpus or DEFAULT_CORPUS
- #print(f"[Input] Example text: {main_text!r}")

- # Hidden states for main text
  main_bundle = extract_hidden_states(model, tok, main_text, cfg.max_length, device)
- layers_np: List[np.ndarray] = main_bundle.hidden_layers # list of (T,D), length L_all = num_layers+1
- tokens = main_bundle.tokens # list of length T
- 
- # Cached pool artifacts (anchors + fitted UMAP reducer)
- corpus_hash = _corpus_fingerprint(texts) # texts is cfg.corpus or DEFAULT_CORPUS
- 
- anchors, reducer2d = get_pool_artifacts(
- model_name=cfg.model_name,
- max_length=cfg.max_length,
- anchor_k=cfg.anchor_k,
- anchor_temp=cfg.anchor_temp,
- umap_n_neighbors=cfg.umap_n_neighbors,
- umap_min_dist=cfg.umap_min_dist,
- umap_metric=cfg.umap_metric,
- fit_pool_per_layer=cfg.fit_pool_per_layer,
- corpus_hash=corpus_hash,
- )
- 
  L_all = len(layers_np)
- #print(f"[Hidden] Layers (incl. embedding): {L_all}, Tokens: {len(tokens)}")

- """
- # 8.3 Build a pool of states (across a few texts & layers) to fit anchors + UMAP
- pool_states = []
- # Sample across first few texts to improve diversity (lightweight)
- for t in texts[: min(5, len(texts))]:
- b = extract_hidden_states(model, tok, t, cfg.max_length, device)
- # Take a subset from each layer to limit pool size
- for H in b.hidden_layers:
- T = len(H)
- take = min(cfg.fit_pool_per_layer, T)
- idx = np.random.choice(T, size=take, replace=False)
- pool_states.append(H[idx])
- pool_states = np.vstack(pool_states) if len(pool_states) else layers_np[-1]
- #print(f"[Pool] Pooled states for anchors/UMAP: {pool_states.shape}")
- 
- # 8.4 Fit global anchors (LoT-style features)
- anchors = fit_global_anchors(pool_states, cfg.anchor_k)
- # Save anchors for reproducibility
- """
- 
- # 8.5 Build per-layer features for main text (LoT-style distances & uncertainty)
- layer_features = [] # list of (T,K)
- layer_uncertainties = [] # list of (T,)
- layer_top_anchor = [] # list of (T,) argmin-id
- 
- for l, H in enumerate(layers_np):
- dists, P, H_unc = anchor_features(H, anchors, cfg.anchor_temp)
- layer_features.append(dists) # N x K distances (lower = closer)
- layer_uncertainties.append(H_unc) # N
- layer_top_anchor.append(np.argmin(dists, axis=1)) # closest anchor id per token

- # 8.6 Consistency metric (LoT Eq. (5)): does layer's top anchor match final layer's?
- final_top = layer_top_anchor[-1]
- layer_consistency = []
- for l in range(L_all):
- cons = (layer_top_anchor[l] == final_top).astype(np.int32) # 1 if matches, 0 otherwise
- layer_consistency.append(cons)

- # 8.7 Build per-layer graphs (kNN by default) on FEATURE space for stability
  layer_graphs = []
  for l in range(L_all):
- feats = layer_features[l]
  if cfg.graph_mode == "knn":
- G = build_knn_graph(feats, cfg.knn_k, metric="euclidean") # kNN in feature space
  else:
- # Threshold graph in original hidden space (as in your notebook)
- G = build_threshold_graph(layers_np[l], cfg.sim_threshold, use_cosine=cfg.use_cosine)
  layer_graphs.append(G)

- # 8.8 Cluster per layer
- layer_cluster_labels = []
- for l in range(L_all):
- feats = layer_features[l]
- labels = cluster_layer(
- feats,
- layer_graphs[l],
- method=cfg.cluster_method,
- n_clusters_kmeans=cfg.n_clusters_kmeans,
- hdbscan_min_cluster_size=cfg.hdbscan_min_cluster_size
- )
  layer_cluster_labels.append(labels)

- # 8.9 Percolation statistics (φ, #clusters, χ) per layer (as in your notebook)
- percolation = []
- for l in range(L_all):
- stats = percolation_stats(layer_graphs[l])
- percolation.append(stats)

- # 8.10 Common 2D manifold via UMAP (fit-once on the pool), then transform each layer
- """reducer2d = fit_umap_2d(pool_states,
- n_neighbors=cfg.umap_n_neighbors,
- min_dist=cfg.umap_min_dist,
- metric=cfg.umap_metric)"""
- xy_by_layer = [reducer2d.transform(layers_np[l]) for l in range(L_all)]

- # OPTIONAL: orthogonal alignment across layers (helps if UMAP.transform still drifts)
- # for l in range(1, L_all):
- # xy_by_layer[l] = orthogonal_align(xy_by_layer[l-1], xy_by_layer[l])

- # 8.11 Plotly 3D point+graph view: X,Y from UMAP; Z = layer index
  fig = plotly_3d_layers(
  xy_layers=xy_by_layer,
- layer_tokens=[tokens for _ in range(L_all)],
  layer_cluster_labels=layer_cluster_labels,
  layer_uncertainty=layer_uncertainties,
  layer_graphs=layer_graphs,
- connect_token_trajectories=True,
- title="Qwen: 3D Cluster Formation (UMAP2D + Layer as Z, LoT metrics on hover)"
  )

  if save_artifacts:
- os.makedirs(cfg.out_dir, exist_ok=True)
- html_path = os.path.join(cfg.out_dir, cfg.plotly_html)
- fig.write_html(html_path)
- # Save percolation series
- with open(os.path.join(cfg.out_dir, "percolation_stats.json"), "w") as f:
- json.dump(percolation, f, indent=2)
- np.save(os.path.join(cfg.out_dir, "anchors.npy"), anchors)
- #print(f"[Percolation] Saved per-layer stats → percolation_stats.json")
- #print(f"[Plotly] 3D HTML saved → {html_path}")

  return fig, {"percolation": percolation, "tokens": tokens}

  @st.cache_resource(show_spinner=False)
  def get_model_and_tok(model_name: str):
  device = "cuda" if torch.cuda.is_available() else "cpu"
  dtype = torch.float16 if device == "cuda" else torch.float32
- model, tok = load_qwen(model_name, device, dtype)
  return model, tok, device, dtype

  def main():
- st.set_page_config(page_title="Layer Explorer", layout="wide")
- st.title("3D Token Embedding Explorer (Live Hidden States)")

  with st.sidebar:
  st.header("Model / Input")
- model_name = st.selectbox("Model", ["Qwen/Qwen1.5-0.5B", "Qwen/Qwen1.5-1.8B", "Qwen/Qwen1.5-4B"], index=1)
  max_length = st.slider("Max tokens", 16, 256, 64, step=16)

  st.header("Graph")
@@ -787,17 +685,20 @@ def main():
  st.header("Outputs")
  save_artifacts = st.checkbox("Save artifacts to disk (HTML/CSV/NPZ)", value=False)

- prompt_col, run_col = st.columns([4, 1])
- with prompt_col:
- main_text = st.text_area(
- "Text to visualize (hidden states computed on this text)",
- value="Explain in one sentence what a transformer attention layer does.",
- height=140
- )
- with run_col:
- st.write("")
- st.write("")
- run_btn = st.button("Run", type="primary")

  cfg = Config(
  model_name=model_name,
 
  import warnings
  from dataclasses import dataclass, asdict
  from typing import Dict, List, Tuple, Optional
+ from tabulate import tabulate

  import numpy as np
  import pandas as pd
  import torch
  from torch import nn
  import networkx as nx
  import streamlit as st
+ import spacy
+ import spacy.cli
+ spacy.cli.download("en_core_web_sm")

  from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
  import umap
  from sklearn.neighbors import NearestNeighbors, KernelDensity
  from sklearn.cluster import KMeans, DBSCAN
  from sklearn.metrics import pairwise_distances
+ from scipy.spatial import procrustes
+ from scipy.linalg import orthogonal_procrustes
  import plotly.graph_objects as go

  # Optional libs (use if present)
  try:
 
  HAS_PYVISTA = True
  except Exception:
  HAS_PYVISTA = False
+ # ====== Configuration =========================================================================
  @dataclass
  class Config:
  # Model
  model_name: str = "Qwen/Qwen1.5-1.8B"
+ max_length: int = 64

  # Data
+ corpus: List[str] = None

+ # Graph & Clustering
+ graph_mode: str = "threshold"
+ knn_k: int = 8
+ sim_threshold: float = 0.05 # fraction of highest-similarity pairs kept as edges; 0.05 = top 5%
  use_cosine: bool = True

  # Anchors / LoT-style features (global)
  n_clusters_kmeans: int = 6 # fallback for kmeans
  hdbscan_min_cluster_size: int = 4

+ # UMAP & alignment
  umap_n_neighbors: int = 30
  umap_min_dist: float = 0.05
+ umap_metric: str = "cosine"
  fit_pool_per_layer: int = 512 # number of states sampled per layer to fit UMAP
+ align_layers: bool = True # align successive layers with orthogonal Procrustes

+ # Visualization
+ color_by: str = "pos" # "cluster" or "pos" (Part of Speech)

  # Output
  out_dir: str = "qwen_mri3d_outputs"
  plotly_html: str = "qwen_layers_3d.html"
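Putting the new fields together, a hypothetical instantiation (sketch, not part of the commit; values illustrative):

cfg = Config(
    model_name="openai-community/gpt2",
    graph_mode="threshold",
    sim_threshold=0.05,   # keep the top 5% most-similar token pairs as edges
    align_layers=True,    # Procrustes-align each layer's 2D map to the previous layer
    color_by="pos",       # color nodes by part of speech instead of cluster id
)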
 
  # Default corpus (small and diverse; adjust freely)
  DEFAULT_CORPUS = [
+ "Is a Universal Basic Income (UBI) a viable solution to poverty, or does it simply discourage people from working?",
+ "Explain the arguments for and against the independence of Taiwan from the perspective of both the US and China.",
+ "What are the ethical arguments surrounding the use of CRISPR technology to edit human embryos for non-medical enhancements?",
+ "Analyze the effectiveness of strict lockdowns versus herd immunity strategies during the COVID-19 pandemic.",
+ "Why is nuclear energy controversial despite being a low-carbon power source? Present both the safety concerns and the environmental benefits.",
+ "Does the existence of evil in the world disprove the existence of a benevolent God? Summarize the philosophical debate.",
+ "Summarize the main arguments used by gun rights advocates against stricter background checks in the United States.",
+ "Should autonomous weapons systems (killer robots) be banned internationally, even if they could reduce soldier casualties?",
+ "Was the dropping of the atomic bombs on Hiroshima and Nagasaki militarily necessary to end World War II?",
+ "What are the competing arguments regarding transgender women participating in biological women's sports categories?"
  ]

+ # Select from 4 different models
+ MODELS = ["Qwen/Qwen1.5-0.5B", "deepseek-ai/deepseek-coder-1.3b-instruct", "openai-community/gpt2", "prem-research/MiniGuard-v0.1"]
+
+ """## Defining Utility Functions"""
+
+ # ====== Utilities =========================================================================
  def seed_everything(seed: int = 42):
  np.random.seed(seed)
  torch.manual_seed(seed)

  def cosine_similarity_matrix(X: np.ndarray) -> np.ndarray:
  norms = np.linalg.norm(X, axis=1, keepdims=True) + 1e-8
  Xn = X / norms
  return Xn @ Xn.T

+ def orthogonal_align(A_ref: np.ndarray, B: np.ndarray) -> np.ndarray:
+ """
+ Align B to A_ref using Procrustes analysis (rotation/reflection only).
+ Preserves the local geometry of B, but aligns its global orientation to A_ref.
+ """
+ # Center both
+ mu_a = A_ref.mean(0)
+ mu_b = B.mean(0)
+ A0 = A_ref - mu_a
+ B0 = B - mu_b
+
+ # Solve for the rotation R that minimizes ||B0 @ R - A0||
+ # (internally: M = B0.T @ A0; U, S, Vt = svd(M); R = U @ Vt)
+ R, _ = orthogonal_procrustes(B0, A0)
+
+ # Rotate B to match A_ref's orientation, then shift to A_ref's center
+ return B0 @ R + mu_a
+
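A quick check that orthogonal_align recovers a known rotation and shift (sketch, not part of the commit):

import numpy as np
from scipy.linalg import orthogonal_procrustes

rng = np.random.default_rng(0)
A = rng.normal(size=(50, 2))
theta = 0.8
R_true = np.array([[np.cos(theta), -np.sin(theta)],
                   [np.sin(theta),  np.cos(theta)]])
B = A @ R_true + 3.0                                   # same cloud, rotated and shifted

B0 = B - B.mean(0)
R, _ = orthogonal_procrustes(B0, A - A.mean(0))
assert np.allclose(B0 @ R + A.mean(0), A, atol=1e-6)   # alignment undoes the transform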
+ def get_pos_tags(text: str, tokenizer, tokens: List[str]) -> List[str]:
  """
+ Map LLM tokens to spaCy POS tags.
+ Heuristic: reconstruct the text, run spaCy, align based on character overlap.
  """
+ try:
+ nlp = spacy.load("en_core_web_sm")
+ except Exception:
+ # Fallback if the model is not downloaded
+ return ["UNK"] * len(tokens)
+
+ doc = nlp(text)
+
+ # This is a simplified mapping. Real alignment is complex due to subwords.
+ # We approximate: find which word each subword belongs to.
+ pos_tags = []
+
+ # Ideally we would rebuild offsets via tokenizer(return_offsets_mapping=True);
+ # here we just iterate and approximate for the demo.
+
+ # Fast approximation: tag the token string itself
+ # (not perfect for subwords like "ing", but visually useful)
+ for t_str in tokens:
+ clean_t = t_str.replace("Ġ", "").replace("▁", "").strip()
+ if not clean_t:
+ pos_tags.append("SYM") # likely a special character
+ continue
+
+ # Tag the single token fragment
+ sub_doc = nlp(clean_t)
+ if len(sub_doc) > 0:
+ pos_tags.append(sub_doc[0].pos_)
+ else:
+ pos_tags.append("UNK")
+
+ return pos_tags
+
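The comment above points at the exact alternative: offset-based alignment. A sketch of that approach (assumes a fast tokenizer and the spaCy model already present; pos_tags_via_offsets is a hypothetical helper, not in the commit):

import spacy

def pos_tags_via_offsets(text: str, tokenizer) -> list:
    nlp = spacy.load("en_core_web_sm")
    doc = nlp(text)
    enc = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
    tags = []
    for start, end in enc["offset_mapping"]:
        # find the spaCy token whose character span covers this subword's start
        word = next((t for t in doc if t.idx <= start < t.idx + len(t.text)), None)
        tags.append(word.pos_ if word is not None else "UNK")
    return tags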
+ def build_knn_graph(coords: np.ndarray, k: int, metric: str = "cosine") -> nx.Graph:
+ nbrs = NearestNeighbors(n_neighbors=min(k+1, len(coords)), metric=metric)
  nbrs.fit(coords)
  distances, indices = nbrs.kneighbors(coords)
  G = nx.Graph()
  G.add_nodes_from(range(len(coords)))
  for i in range(len(coords)):
+ for j in indices[i, 1:]:
  G.add_edge(int(i), int(j))
  return G
 
+ def build_threshold_graph(H: np.ndarray, top_pct: float = 0.05, use_cosine: bool = True, include_ties: bool = True) -> nx.Graph:
  if use_cosine:
  S = cosine_similarity_matrix(H)
  else:
+ S = H @ H.T

  N = S.shape[0]
+ iu = np.triu_indices(N, k=1)
+ vals = S[iu]
+
+ # threshold at the (1 - top_pct) quantile
+ q = 1.0 - top_pct
+ thr = float(np.quantile(vals, q))
  G = nx.Graph()
  G.add_nodes_from(range(N))

+ if include_ties:
+ mask = vals >= thr
+ else:
+ # strictly greater than the threshold reduces tie-inflation
+ mask = vals > thr
+
+ rows = iu[0][mask]
+ cols = iu[1][mask]
+ wts = vals[mask]
+
+ for r, c, w in zip(rows, cols, wts):
+ G.add_edge(int(r), int(c), weight=float(w))
+ return G
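What the quantile cut keeps, on toy numbers (sketch, not part of the commit):

import numpy as np

vals = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])
top_pct = 0.2
thr = np.quantile(vals, 1.0 - top_pct)  # 0.82 under linear interpolation
print(vals[vals >= thr])                # [0.9 1.0] -> the top ~20% of pairs become edges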
 
  def percolation_stats(G: nx.Graph) -> Dict[str, float]:
  """

  largest_component_size=largest,
  component_sizes=sorted(sizes, reverse=True))

+ def cluster_layer(features: np.ndarray, G: Optional[nx.Graph], method: str,
+ n_clusters_kmeans: int=6, hdbscan_min_cluster_size: int=4) -> np.ndarray:
+ # (Same as original)
+ method = method.lower()
+ N = len(features)
+ if method == "auto":
+ if HAS_IGRAPH_LEIDEN and G and G.number_of_edges() > 0: return leiden_communities(G)
+ elif HAS_HDBSCAN: return hdbscan.HDBSCAN(min_cluster_size=hdbscan_min_cluster_size).fit_predict(features)
+ else: return KMeans(n_clusters=min(n_clusters_kmeans, N), n_init="auto").fit_predict(features)
+ # ... (rest of the method dispatch unchanged)
+ return KMeans(n_clusters=min(n_clusters_kmeans, N), n_init="auto").fit_predict(features)

+ # Helper for Leiden (from original)
  def leiden_communities(G: nx.Graph) -> np.ndarray:
+ if not HAS_IGRAPH_LEIDEN: raise RuntimeError("Missing igraph")
  mapping = {n: i for i, n in enumerate(G.nodes())}
  edges = [(mapping[u], mapping[v]) for u, v in G.edges()]
  ig_g = ig.Graph(n=len(mapping), edges=edges, directed=False)
+ part = la.find_partition(ig_g, la.RBConfigurationVertexPartition)
  labels = np.zeros(len(mapping), dtype=int)
  for cid, comm in enumerate(part):
+ for node in comm: labels[node] = cid
  return labels

+ def anchor_features(H: np.ndarray, anchors: np.ndarray, temperature: float = 1.0):
+ dists = pairwise_distances(H, anchors, metric="euclidean")
+ logits = -dists / max(temperature, 1e-6)
+ logits = logits - logits.max(axis=1, keepdims=True)
+ P = np.exp(logits)
+ P /= P.sum(axis=1, keepdims=True) + 1e-12
+ # Entropy calculation
+ H_unc = -np.sum(P * np.log(P + 1e-12), axis=1)
+ return dists, P, H_unc
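One token against three anchors, worked through (sketch, not part of the commit): distances become a softmax, and the entropy of that softmax is the uncertainty.

import numpy as np

dists = np.array([[0.5, 1.5, 2.5]])          # one state, three anchors
logits = -dists / 1.0                        # temperature = 1
logits -= logits.max(axis=1, keepdims=True)  # [0, -1, -2] after stabilization
P = np.exp(logits)
P /= P.sum(axis=1, keepdims=True)            # ≈ [[0.665, 0.245, 0.090]]
H_unc = -np.sum(P * np.log(P + 1e-12), axis=1)
# H_unc ≈ 0.83 nats, between 0 (certain) and ln 3 ≈ 1.10 (uniform)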
 
+ def fit_global_anchors(pool: np.ndarray, K: int) -> np.ndarray:
+ km = KMeans(n_clusters=K, n_init="auto", random_state=42)
+ km.fit(pool)
+ return km.cluster_centers_

+ # ====== Model I/O (hidden states) =============================================================
  @dataclass
  class HiddenStatesBundle:
  """
 
  tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
  return HiddenStatesBundle(hidden_layers=hs, tokens=tokens)

+ # ====== LoT-style anchors & features ==========================================================
  def fit_global_anchors(all_states_sampled: np.ndarray, K: int, random_state: int = 42) -> np.ndarray:
  """
  Fit KMeans cluster centroids on a pooled set of states (from many layers/texts).

  kmeans.fit(all_states_sampled)
  return kmeans.cluster_centers_ # (K, D)

  def anchor_features(H: np.ndarray, anchors: np.ndarray, temperature: float = 1.0) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
  """
  For states H (N,D) and anchors A (K,D):

  P = np.exp(logits)
  P /= P.sum(axis=1, keepdims=True) + 1e-12
  # Uncertainty (entropy)
+ H_unc = -np.sum(P * np.log(P + 1e-12), axis=1)
+
  return dists, P, H_unc

+ # ====== Dimensionality reduction / embeddings ================================================
  def fit_umap_2d(pool: np.ndarray,
  n_neighbors: int = 30,
  min_dist: float = 0.05,

  reducer.fit(pool)
  return reducer

  def fit_umap_3d(all_states: np.ndarray,
  n_neighbors: int = 30,

  metric=metric, random_state=random_state)
  return reducer.fit_transform(all_states)

+ """## Define Visualization Function"""

+ # ====== Visualization ========================================================================
  def plotly_3d_layers(xy_layers: List[np.ndarray],
  layer_tokens: List[List[str]],
  layer_cluster_labels: List[np.ndarray],
+ layer_pos_tags: List[List[str]],
  layer_uncertainty: List[np.ndarray],
  layer_graphs: List[nx.Graph],
+ color_by: str = "cluster",
+ title: str = "3D Cluster Formation",
+ prompt: str = None) -> go.Figure:
+
  fig_data = []

+ # Define categorical colormap for POS
+ pos_map = {
+ "NOUN": "#1f77b4", "VERB": "#d62728", "ADJ": "#2ca02c",
+ "ADV": "#ff7f0e", "PRON": "#9467bd", "DET": "#8c564b",
+ "ADP": "#e377c2", "NUM": "#7f7f7f", "PUNCT": "#bcbd22",
+ "SYM": "#17becf", "UNK": "#bababa"
+ }
+
+ L = len(xy_layers)
+ for l, (xy, tokens, labels, pos, unc, G) in enumerate(zip(xy_layers, layer_tokens, layer_cluster_labels, layer_pos_tags, layer_uncertainty, layer_graphs)):
+ if len(xy) == 0: continue
  x, y = xy[:, 0], xy[:, 1]
  z = np.full_like(x, l, dtype=float)

+ # Color logic
+ if color_by == "pos":
+ # Map POS strings to colors
+ node_colors = [pos_map.get(p, "#333333") for p in pos]
+ show_scale = False
+ colorscale = None
+ else:
+ # Cluster ID
+ node_colors = labels
+ show_scale = (l == 0)
+ colorscale = 'Viridis'
+
+ # Hover text
+ node_text = [
+ f"L{l} | {tok}<br>POS: {p}<br>Cluster: {c}<br>Unc: {u:.2f}"
+ for tok, p, c, u in zip(tokens, pos, labels, unc)
+ ]
+
  node_trace = go.Scatter3d(
  x=x, y=y, z=z,
  mode='markers',
  name=f"Layer {l}",
+ showlegend=False,
  marker=dict(
+ size=3,
+ opacity=1,
+ color=node_colors,
+ colorscale=colorscale,
+ showscale=show_scale,
+ colorbar=dict(title="Cluster ID") if show_scale else None
  ),
  text=node_text,
  hovertemplate="%{text}<extra></extra>"
  )
  fig_data.append(node_trace)

+ # Edges
  if G is not None and G.number_of_edges() > 0:
  edge_x, edge_y, edge_z = [], [], []
  for u, v in G.edges():
  edge_x += [x[u], x[v], None]
  edge_y += [y[u], y[v], None]
  edge_z += [z[u], z[v], None]
+
  edge_trace = go.Scatter3d(
  x=edge_x, y=edge_y, z=edge_z,
  mode='lines',
+ line=dict(width=2, color='red'),
+ opacity=0.6,
+ hoverinfo='skip',
+ showlegend=False
  )
  fig_data.append(edge_trace)

+ # Trajectories (connect same token across layers)
+ if L > 1:
+ T = len(xy_layers[0])
+ # Sample trajectories to avoid lag if T is huge
+ step = max(1, T // 100)
+ for i in range(0, T, step):
+ xs = [xy_layers[l][i, 0] for l in range(L)]
+ ys = [xy_layers[l][i, 1] for l in range(L)]
+ zs = list(range(L))
+ traj = go.Scatter3d(
+ x=xs, y=ys, z=zs,
+ mode='lines',
+ line=dict(width=3, color='rgba(50,50,50,0.5)'),
+ hoverinfo='skip',
+ showlegend=False
+ )
+ fig_data.append(traj)
+ if color_by == "pos":
+ # Add legend-only traces for POS categories actually present
+ present_pos = sorted({p for layer in layer_pos_tags for p in layer})
+
+ for p in present_pos:
+ fig_data.append(
+ go.Scatter3d(
+ x=[None], y=[None], z=[None], # legend-only
+ mode="markers",
+ name=p,
+ marker=dict(size=8, color=pos_map.get(p, "#333333")),
+ showlegend=True,
+ hoverinfo="skip"
  )
+ )

  fig = go.Figure(data=fig_data)
  fig.update_layout(
+ title=dict(
+ text=title,
+ x=0.5,
+ xanchor="center",
+ ),
+ annotations=[
+ dict(
+ text=f"<b>Prompt:</b> {prompt}",
+ x=0.5,
+ y=1.02,
+ xref="paper",
+ yref="paper",
+ showarrow=False,
+ font=dict(size=13),
+ align="center"
+ )
+ ] if prompt else [],
  scene=dict(
  xaxis_title="UMAP X",
  yaxis_title="UMAP Y",
+ zaxis_title="Layer Depth",
+ aspectratio=dict(x=1, y=1, z=1.5)
  ),
  height=900,
+ margin=dict(l=0, r=0, b=0, t=40)
  )
  return fig

+ """## Building the pipeline"""
+
+ def run_pipeline(cfg: Config, model, tok, device, main_text: str, save_artifacts: bool = False):
  seed_everything(42)

+ # 1. Extract hidden states
+ from transformers import logging
+ logging.set_verbosity_error()
+
  main_bundle = extract_hidden_states(model, tok, main_text, cfg.max_length, device)
+ layers_np = main_bundle.hidden_layers
+ tokens = main_bundle.tokens
  L_all = len(layers_np)

+ # 2. Get POS tags
+ pos_tags = get_pos_tags(main_text, tok, tokens)

+ # 3. Pooling & anchors (LoT)
+ # (Simplified: pool from the main text only, for speed in the demo)
+ pool_states = np.vstack([layers_np[l] for l in range(0, L_all, 2)])
+ idx = np.random.choice(len(pool_states), min(len(pool_states), 2000), replace=False)
+ anchors = fit_global_anchors(pool_states[idx], cfg.anchor_k)

+ # 4. Process layers
+ layer_features = []
+ layer_uncertainties = []
  layer_graphs = []
+ layer_cluster_labels = []
+ percolation = []
+
  for l in range(L_all):
+ H = layers_np[l]
+
+ # Features & uncertainty
+ dists, P, H_unc = anchor_features(H, anchors, cfg.anchor_temp)
+ layer_features.append(dists)
+ layer_uncertainties.append(H_unc)
+
+ # Graphs
  if cfg.graph_mode == "knn":
+ G = build_knn_graph(dists, cfg.knn_k, metric="euclidean")
  else:
+ G = build_threshold_graph(H, cfg.sim_threshold, use_cosine=cfg.use_cosine)
  layer_graphs.append(G)

+ # Clusters
+ labels = cluster_layer(dists, G, cfg.cluster_method,
+ cfg.n_clusters_kmeans, cfg.hdbscan_min_cluster_size)
  layer_cluster_labels.append(labels)

+ # Percolation
+ percolation.append(percolation_stats(G))
+
+ # 5. UMAP & alignment
+ # Fit UMAP on the pool to establish a shared coordinate system
+ reducer = umap.UMAP(n_components=2, n_neighbors=cfg.umap_n_neighbors,
+ min_dist=cfg.umap_min_dist, metric=cfg.umap_metric, random_state=42)
+ reducer.fit(pool_states[idx])

+ xy_by_layer = []
+ for l in range(L_all):
+ # Transform into 2D
+ xy = reducer.transform(layers_np[l])

+ # Procrustes alignment: align layer l to layer l-1
+ if cfg.align_layers and l > 0:
+ xy = orthogonal_align(xy_by_layer[l-1], xy)

+ xy_by_layer.append(xy)

+ # 6. Plot
  fig = plotly_3d_layers(
  xy_layers=xy_by_layer,
+ layer_tokens=[tokens] * L_all,
  layer_cluster_labels=layer_cluster_labels,
+ layer_pos_tags=[pos_tags] * L_all,
  layer_uncertainty=layer_uncertainties,
  layer_graphs=layer_graphs,
+ color_by=cfg.color_by,
+ title=f"{cfg.model_name.rsplit('/', 1)[-1]} 3D MRI | Color: {cfg.color_by.upper()} | Aligned: {cfg.align_layers}",
+ prompt=main_text
  )

+ # 7. Save artifacts
  if save_artifacts:
+ import os
+ # Create the directory if it doesn't exist
+ os.makedirs(cfg.out_dir, exist_ok=True)
+
+ # Construct the full path
+ out_path = os.path.join(cfg.out_dir, cfg.plotly_html)
+
+ # Write the HTML file
+ fig.write_html(out_path)
+ print(f"Successfully saved 3D plot to: {out_path}")

  return fig, {"percolation": percolation, "tokens": tokens}

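The fit-once / transform-per-layer idiom in step 5 is what keeps the slices comparable: fitting UMAP separately per layer would give each slice its own arbitrary frame, and Procrustes then removes residual rotation drift between slices. A condensed sketch of the idiom (not part of the commit; orthogonal_align as defined earlier in this file):

import numpy as np
import umap

layers = [np.random.rand(40, 16) for _ in range(4)]          # stand-in hidden states
reducer = umap.UMAP(n_components=2, random_state=42).fit(np.vstack(layers))
xy = [reducer.transform(H) for H in layers]                  # one shared coordinate frame
for l in range(1, len(xy)):
    xy[l] = orthogonal_align(xy[l - 1], xy[l])               # chain-align successive layers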
 
+ """## This section is for the Web App UI"""
+
  @st.cache_resource(show_spinner=False)
  def get_model_and_tok(model_name: str):
  device = "cuda" if torch.cuda.is_available() else "cpu"
  dtype = torch.float16 if device == "cuda" else torch.float32
+ config = AutoConfig.from_pretrained(model_name, output_hidden_states=True, trust_remote_code=True)
+ tok = AutoTokenizer.from_pretrained(model_name, use_fast=True, trust_remote_code=True)
+ if tok.pad_token_id is None:
+ tok.pad_token = tok.eos_token
+
+ model = AutoModelForCausalLM.from_pretrained(
+ model_name,
+ trust_remote_code=True,
+ config=config,
+ torch_dtype=dtype if device == "cuda" else None,
+ device_map="auto" if device == "cuda" else None
+ )
+ model.eval()
+
+ if device != "cuda":
+ model = model.to(device)
+
  return model, tok, device, dtype

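Because of st.cache_resource, each checkpoint is loaded once per session and reused on reruns; usage is simply (sketch, not part of the commit):

model, tok, device, dtype = get_model_and_tok(MODELS[0])  # cached after the first call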
 
  def main():
+ st.set_page_config(page_title="LLM Hidden Layer Explorer", layout="wide")
+ st.title("Token Embedding Explorer (Live Hidden States)")

  with st.sidebar:
  st.header("Model / Input")
+ model_name = st.selectbox("Model", MODELS, index=1)
  max_length = st.slider("Max tokens", 16, 256, 64, step=16)

  st.header("Graph")

  st.header("Outputs")
  save_artifacts = st.checkbox("Save artifacts to disk (HTML/CSV/NPZ)", value=False)

+ prompt_col, run_col = st.columns([4, 1])
+
+ with prompt_col:
+ main_text = st.selectbox(
+ "Prompt to visualize (hidden states computed on this text)",
+ options=DEFAULT_CORPUS,
+ index=0,
+ help="Select a predefined prompt for analysis"
+ )
+
+ with run_col:
+ st.write("")
+ st.write("")
+ run_btn = st.button("Run", type="primary")

  cfg = Config(
  model_name=model_name,