Jgray21 commited on
Commit
71d5f5e
·
verified ·
1 Parent(s): e340310

Upload app.py

Browse files
Files changed (1) hide show
  1. src/app.py +780 -0
src/app.py ADDED
@@ -0,0 +1,780 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import json
4
+ import warnings
5
+ from dataclasses import dataclass, asdict
6
+ from typing import Dict, List, Tuple, Optional
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+
11
+ import torch
12
+ from torch import nn
13
+
14
+ import networkx as nx
15
+ import streamlit as st
16
+
17
+ # Transformers: Qwen tokenizer can be AutoTokenizer if Qwen2Tokenizer not present
18
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
19
+
20
+ # Dimensionality reduction
21
+ import umap
22
+ from umap import UMAP
23
+
24
+ # Neighbors & clustering
25
+ from sklearn.neighbors import NearestNeighbors, KernelDensity
26
+ from sklearn.cluster import KMeans, DBSCAN
27
+ from sklearn.decomposition import PCA
28
+ from sklearn.metrics import pairwise_distances
29
+
30
+ # Plotly for interactive 3D
31
+ import plotly.graph_objects as go
32
+
33
+ # Optional libs (use if present)
34
+ try:
35
+ import hdbscan # Robust density-based clustering
36
+ HAS_HDBSCAN = True
37
+ except Exception:
38
+ HAS_HDBSCAN = False
39
+
40
+ try:
41
+ import igraph as ig
42
+ import leidenalg as la
43
+ HAS_IGRAPH_LEIDEN = True
44
+ except Exception:
45
+ HAS_IGRAPH_LEIDEN = False
46
+
47
+ try:
48
+ import pyvista as pv # Volume & isosurfaces (VTK)
49
+ HAS_PYVISTA = True
50
+ except Exception:
51
+ HAS_PYVISTA = False
52
+
53
+ from scipy.linalg import orthogonal_procrustes # For optional per-layer orientation alignment
54
+
55
+ # ====== 1. Configuration =========================================================================
56
@dataclass
class Config:
    """Pipeline configuration. Call validate() after construction to catch bad values."""

    # Model
    model_name: str = "Qwen/Qwen1.5-1.8B"
    # NOTE: device/dtype are resolved at runtime in get_model_and_tok();
    # keeping them out of the dataclass avoids touching CUDA state at import time.
    # device: str = "cuda" if torch.cuda.is_available() else "cpu"
    # dtype: torch.dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    # Tokenization / generation
    max_length: int = 64  # truncate inputs for speed/memory

    # Data
    # FIX: annotation was `List[str] = None`; None is a valid default, so the
    # field must be Optional.
    corpus: Optional[List[str]] = None  # None → fall back to DEFAULT_CORPUS

    # Graph building
    graph_mode: str = "threshold"  # {"knn", "threshold"}
    knn_k: int = 8                 # neighbors per token (used if graph_mode="knn")
    sim_threshold: float = 0.60    # used if graph_mode="threshold"
    use_cosine: bool = True

    # Anchors / LoT-style features (global)
    anchor_k: int = 16        # number of global prototypes (KMeans on pooled states)
    anchor_temp: float = 0.7  # softmax temperature for converting distances to probs

    # Clustering per layer
    cluster_method: str = "auto"  # {"auto","leiden","hdbscan","dbscan","kmeans"}
    n_clusters_kmeans: int = 6    # fallback for kmeans
    hdbscan_min_cluster_size: int = 4

    # DR / embeddings
    umap_n_neighbors: int = 30
    umap_min_dist: float = 0.05
    umap_metric: str = "cosine"       # hidden states are directional → cosine works well
    use_global_3d_umap: bool = False  # if True, compute a single 3D manifold for all states

    # Pooling for UMAP fit
    fit_pool_per_layer: int = 512  # number of states sampled per layer to fit UMAP

    # Volume grid (MRI view)
    grid_res: int = 128          # voxel resolution in x/y; z = num_layers
    kde_bandwidth: float = 0.15  # KDE bandwidth in manifold space (if using KDE)
    use_hist2d: bool = True      # if True, use histogram2d instead of KDE for speed

    # Output
    out_dir: str = "qwen_mri3d_outputs"
    plotly_html: str = "qwen_layers_3d.html"
    volume_npz: str = "qwen_density_volume.npz"  # saved if PyVista isn't available
    volume_screenshot: str = "qwen_volume.png"   # if PyVista is available

    def validate(self):
        """Raise ValueError for out-of-range settings; return None on success."""
        if self.graph_mode not in {"knn", "threshold"}:
            raise ValueError("graph_mode must be 'knn' or 'threshold'")
        if self.knn_k < 2:
            raise ValueError("knn_k must be >= 2")
        if self.anchor_k < 2:
            raise ValueError("anchor_k must be >= 2")
        if self.anchor_temp <= 0:
            raise ValueError("anchor_temp must be > 0")
114
+
115
+
116
+
117
# Default corpus: ten short sentences from distinct domains (small and diverse;
# adjust freely). Used both as pooling texts for anchor/UMAP fitting in
# run_pipeline and as the fallback when Config.corpus is None.
DEFAULT_CORPUS = [
    "The cat sat on the mat and watched.",
    "Machine learning models process data using neural networks.",
    "Climate change affects ecosystems around the world.",
    "Quantum computers use superposition for parallel computation.",
    "The universe contains billions of galaxies.",
    "Artificial intelligence transforms how we work.",
    "DNA stores genetic information in cells.",
    "Ocean currents regulate Earth's climate system.",
    "Photosynthesis converts sunlight into chemical energy.",
    "Blockchain technology enables decentralized systems."
]
130
+
131
+ # ====== 2. Utilities =============================================================================
132
def seed_everything(seed: int = 42):
    """Seed the NumPy and PyTorch RNGs so UMAP/KMeans/layouts are reproducible."""
    torch.manual_seed(seed)
    np.random.seed(seed)
136
+
137
+
138
def cosine_similarity_matrix(X: np.ndarray) -> np.ndarray:
    """
    Pairwise cosine similarity between the rows of X (N, D) → (N, N).

    A small epsilon on the norms keeps zero rows from dividing by zero.
    """
    row_norms = np.linalg.norm(X, axis=1, keepdims=True) + 1e-8
    unit_rows = X / row_norms
    return unit_rows @ unit_rows.T
144
+
145
+
146
def build_knn_graph(coords: np.ndarray, k: int, metric: str = "cosine") -> nx.Graph:
    """
    Build an undirected kNN graph over the rows of `coords` (N, D).

    Fix vs. the previous version: it assumed the query point is always returned
    as neighbor index 0 and blindly skipped that slot. With duplicate points the
    self-match can appear at any rank, so we skip `j == i` explicitly and cap at
    k genuine neighbors instead.
    """
    n = len(coords)
    # +1 so that after dropping the self-match we still have up to k neighbors.
    nbrs = NearestNeighbors(n_neighbors=min(k + 1, n), metric=metric)
    nbrs.fit(coords)
    _, indices = nbrs.kneighbors(coords)

    G = nx.Graph()
    G.add_nodes_from(range(n))
    for i in range(n):
        linked = 0
        for j in indices[i]:
            j = int(j)
            if j == i:  # skip self wherever it is ranked
                continue
            G.add_edge(i, j)
            linked += 1
            if linked >= k:
                break
    return G
162
+
163
+
164
def build_threshold_graph(H: np.ndarray, threshold: float, use_cosine: bool = True) -> nx.Graph:
    """
    Graph over the N hidden states in H (N, D): connect i—j whenever their
    pairwise similarity strictly exceeds `threshold`, storing the similarity
    as the edge weight.

    Edge extraction is vectorized over the upper triangle — identical edges to
    the former Python double loop, but the O(N^2) comparison runs in NumPy
    instead of the interpreter.
    """
    if use_cosine:
        S = cosine_similarity_matrix(H)
    else:
        S = H @ H.T  # raw dot-product similarity

    N = S.shape[0]
    G = nx.Graph()
    G.add_nodes_from(range(N))

    # Upper-triangle indices (i < j) exclude self-loops and duplicate pairs.
    iu, ju = np.triu_indices(N, k=1)
    keep = S[iu, ju] > threshold
    for i, j in zip(iu[keep], ju[keep]):
        G.add_edge(int(i), int(j), weight=float(S[i, j]))
    return G
182
+
183
+
184
def percolation_stats(G: nx.Graph) -> Dict[str, float]:
    """
    Percolation observables for one layer graph:
      φ (phi) : fraction of nodes inside the Giant Connected Component (GCC)
      χ (chi) : mean size of the remaining components, GCC excluded
    plus the component count and the sorted component-size list.

    Fix: exclude exactly ONE giant component when computing χ. The previous
    filter `s != largest` dropped every component tied at the maximum size,
    biasing χ downward whenever two components tie for largest.
    """
    n = G.number_of_nodes()
    if n == 0:
        return dict(phi=0.0, num_clusters=0, chi=0.0, largest_component_size=0, component_sizes=[])

    comps = list(nx.connected_components(G))
    sizes = sorted((len(c) for c in comps), reverse=True)
    if not sizes:
        return dict(phi=0.0, num_clusters=0, chi=0.0, largest_component_size=0, component_sizes=[])

    largest = sizes[0]
    phi = largest / n

    # Drop one GCC instance only; components tied with it still count toward χ.
    non_gcc_sizes = sizes[1:]
    chi = float(np.mean(non_gcc_sizes)) if non_gcc_sizes else 0.0

    return dict(phi=float(phi),
                num_clusters=len(comps),
                chi=float(chi),
                largest_component_size=largest,
                component_sizes=sizes)
210
+
211
+
212
def leiden_communities(G: nx.Graph) -> np.ndarray:
    """
    Partition G with the Leiden algorithm (python-igraph + leidenalg).

    Returns an int array assigning a community id to each node 0..N-1.
    Raises RuntimeError when the optional igraph/leidenalg stack is missing.
    """
    if not HAS_IGRAPH_LEIDEN:
        raise RuntimeError("igraph+leidenalg not available")

    # Re-index nodes densely so igraph vertex ids line up with array slots.
    index_of = {node: pos for pos, node in enumerate(G.nodes())}
    edge_list = [(index_of[a], index_of[b]) for a, b in G.edges()]
    graph = ig.Graph(n=len(index_of), edges=edge_list, directed=False)

    # RBConfigurationVertexPartition is a robust default objective.
    partition = la.find_partition(graph, la.RBConfigurationVertexPartition)

    labels = np.zeros(len(index_of), dtype=int)
    for community_id, members in enumerate(partition):
        labels[list(members)] = community_id
    return labels
230
+
231
+
232
def _kmeans_labels(features: np.ndarray, n_clusters: int) -> np.ndarray:
    """Shared KMeans fallback: clamp the cluster count to a sane range for N points."""
    N = len(features)
    km = KMeans(n_clusters=min(n_clusters, max(2, N // 3)),
                n_init="auto", random_state=42)
    return km.fit_predict(features)


def cluster_layer(features: np.ndarray,
                  G: Optional[nx.Graph],
                  method: str,
                  n_clusters_kmeans: int = 6,
                  hdbscan_min_cluster_size: int = 4) -> np.ndarray:
    """
    Produce per-state cluster labels for one layer.

    method (case-insensitive):
      - "auto"    : prefer Leiden (needs graph + igraph/leidenalg) → HDBSCAN → KMeans
      - "leiden"  : graph community detection (requires G and igraph/leidenalg)
      - "hdbscan" : density-based clustering in feature space (-1 = noise)
      - "dbscan"  : scikit-learn density-based fallback
      - "kmeans"  : centroid clustering fallback

    Raises ValueError for unknown methods and RuntimeError when a required
    optional dependency is missing. The duplicated KMeans fallback from the
    previous version is factored into _kmeans_labels.
    """
    method = method.lower()
    N = len(features)

    if method == "auto":
        # Preference order: Leiden (graph) → HDBSCAN → KMeans.
        if HAS_IGRAPH_LEIDEN and G is not None and G.number_of_edges() > 0:
            return leiden_communities(G)
        if HAS_HDBSCAN and N >= 5:
            clusterer = hdbscan.HDBSCAN(min_cluster_size=hdbscan_min_cluster_size,
                                        metric='euclidean')
            # HDBSCAN labels noise as -1; keep it as its own pseudo-cluster id.
            return clusterer.fit_predict(features)
        return _kmeans_labels(features, n_clusters_kmeans)

    if method == "leiden":
        if G is None or not HAS_IGRAPH_LEIDEN:
            raise RuntimeError("Leiden requires a graph and igraph+leidenalg.")
        return leiden_communities(G)

    if method == "hdbscan":
        if not HAS_HDBSCAN:
            raise RuntimeError("hdbscan not installed")
        clusterer = hdbscan.HDBSCAN(min_cluster_size=hdbscan_min_cluster_size, metric='euclidean')
        return clusterer.fit_predict(features)

    if method == "dbscan":
        db = DBSCAN(eps=0.5, min_samples=4, metric='euclidean')
        return db.fit_predict(features)

    if method == "kmeans":
        return _kmeans_labels(features, n_clusters_kmeans)

    raise ValueError(f"Unknown cluster method: {method}")
283
+
284
+
285
def orthogonal_align(A_ref: np.ndarray, B: np.ndarray) -> np.ndarray:
    """
    Rotate B onto A_ref via orthogonal Procrustes.

    Both point sets are centered, the best orthogonal map from B to A_ref is
    solved, and the result is shifted back to A_ref's centroid. Geometry is
    preserved; only arbitrary orientation flips/rotations are removed.
    """
    ref_center = A_ref.mean(axis=0)
    b_centered = B - B.mean(axis=0)
    rotation, _ = orthogonal_procrustes(b_centered, A_ref - ref_center)
    return b_centered @ rotation + ref_center
292
+
293
+
294
def entropy_from_probs(p: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Row-wise Shannon entropy of p (N, K); rows are assumed to sum to ~1.

    eps guards log(0) for zero probabilities.
    """
    return -(p * np.log(p + eps)).sum(axis=1)
297
+
298
+ # ====== 3. Model I/O (hidden states) =============================================================
299
@dataclass
class HiddenStatesBundle:
    """
    Hidden states + metadata for a single input text.

    hidden_layers: list of np.ndarray of shape (T, D); length = num_layers + 1
                   (index 0 is the embedding-layer output).
    tokens       : list of T token strings aligned with the first axis above.
    """
    hidden_layers: List[np.ndarray]
    tokens: List[str]
308
+
309
+
310
def load_qwen(model_name: str, device: str, dtype: torch.dtype):
    """
    Load a Qwen causal LM plus tokenizer with output_hidden_states enabled.

    AutoTokenizer is used (rather than a Qwen-specific class) for broader
    transformers-version compatibility. Returns (model, tokenizer); the model
    is in eval mode on `device` and cast to fp16 only for the cuda + float16
    combination.
    """
    print(f"[Load] {model_name} on {device} ({dtype})")
    config = AutoConfig.from_pretrained(model_name, output_hidden_states=True)
    tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
    model.eval().to(device)
    # Half precision only when explicitly requested and running on GPU.
    if device == "cuda" and dtype == torch.float16:
        model = model.half()
    return model, tok
322
+
323
+
324
@torch.no_grad()
def extract_hidden_states(model, tokenizer, text: str, max_length: int, device: str) -> HiddenStatesBundle:
    """
    Run one forward pass and collect every hidden-state tensor (embedding layer
    included) for `text`. Tensors are converted to CPU float32 numpy arrays so
    GPU memory stays flat across repeated calls.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
    outputs = model(**encoded)
    # outputs.hidden_states is a tuple of length num_layers + 1; indexing [0]
    # drops the batch dimension, leaving (T, D) per layer.
    per_layer = []
    for layer_states in outputs.hidden_states:
        per_layer.append(layer_states[0].detach().float().cpu().numpy())
    token_strings = tokenizer.convert_ids_to_tokens(encoded["input_ids"][0])
    return HiddenStatesBundle(hidden_layers=per_layer, tokens=token_strings)
336
+
337
+ # ====== 4. LoT-style anchors & features ==========================================================
338
def fit_global_anchors(all_states_sampled: np.ndarray, K: int, random_state: int = 42) -> np.ndarray:
    """
    Run KMeans over a pooled sample of hidden states (mixed layers/texts) and
    return the (K, D) centroid matrix. The centroids act as global "anchors"
    for the LoT-style features f(state) = [dist(state, anchor_j)]_{j=1..K}.
    """
    print(f"[Anchors] Fitting {K} global centroids on {len(all_states_sampled)} states ...")
    model = KMeans(n_clusters=K, n_init="auto", random_state=random_state)
    return model.fit(all_states_sampled).cluster_centers_
348
+
349
+
350
def anchor_features(H: np.ndarray, anchors: np.ndarray, temperature: float = 1.0) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    LoT-style anchor features for states H (N, D) against anchors A (K, D).

    Returns:
        dists : (N, K) Euclidean distances from each state to each anchor
        P     : (N, K) row-normalized softmax of -dists / temperature
                (soft anchor assignments; cf. LoT Eq. (6) for the entropy)
        H_unc : (N,) Shannon entropy of each row of P — the uncertainty score

    The nearest anchor per row (argmin over dists) supports the
    "consistency"-style comparison of LoT Eq. (5) downstream.
    """
    dists = pairwise_distances(H, anchors, metric="euclidean")

    # Temperature-scaled, numerically stable softmax over negated distances.
    scaled = -dists / max(temperature, 1e-6)
    scaled = scaled - scaled.max(axis=1, keepdims=True)
    P = np.exp(scaled)
    P /= P.sum(axis=1, keepdims=True) + 1e-12

    return dists, P, entropy_from_probs(P)
370
+
371
+ # ====== 5. Dimensionality reduction / embeddings ================================================
372
def fit_umap_2d(pool: np.ndarray,
                n_neighbors: int = 30,
                min_dist: float = 0.05,
                metric: str = "cosine",
                random_state: int = 42) -> umap.UMAP:
    """
    Fit a single 2D UMAP on a cross-layer pool of states and return the fitted
    reducer. Each layer is later projected with .transform() into this shared
    plane, keeping slice orientation stable along depth (the "MRI stack" trick).
    """
    mapper = umap.UMAP(
        n_components=2,
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        metric=metric,
        random_state=random_state,
    )
    mapper.fit(pool)
    return mapper
386
+
387
+
388
def fit_umap_3d(all_states: np.ndarray,
                n_neighbors: int = 30,
                min_dist: float = 0.05,
                metric: str = "cosine",
                random_state: int = 42) -> np.ndarray:
    """
    One-shot global 3D UMAP over every state at once — the alternative to the
    per-layer slice stack. Returns (N, 3) coordinates for the rows passed in.
    """
    mapper = umap.UMAP(
        n_components=3,
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        metric=metric,
        random_state=random_state,
    )
    return mapper.fit_transform(all_states)
400
+
401
+ # ====== 6. Volume construction (MRI) ============================================================
402
def stack_density_volume(xy_by_layer: List[np.ndarray],
                         grid_res: int,
                         use_hist2d: bool = True,
                         kde_bandwidth: float = 0.15) -> np.ndarray:
    """
    Stack per-layer 2D densities on the shared (x, y) manifold into a volume.

    Each layer becomes one z-slice, estimated either with a fast histogram2d
    (use_hist2d=True) or with a Gaussian KDE. Bin edges are computed once from
    the GLOBAL extent of all layers so every slice shares a coordinate frame.

    Returns:
        (grid_res, grid_res, L) float32 volume, max-normalized to [0, 1].
        An all-zero volume is returned when no layer contains any points —
        the previous version crashed there (np.vstack of an empty list raises).
    """
    L = len(xy_by_layer)
    vol = np.zeros((grid_res, grid_res, L), dtype=np.float32)

    # Guard the all-empty case before touching np.vstack.
    nonempty = [xy for xy in xy_by_layer if len(xy) > 0]
    if not nonempty:
        return vol
    all_xy = np.vstack(nonempty)
    x_min, y_min = all_xy.min(axis=0)
    x_max, y_max = all_xy.max(axis=0)
    pad = 1e-6  # slight padding so boundary points fall inside the outer bins
    x_edges = np.linspace(x_min - pad, x_max + pad, grid_res + 1)
    y_edges = np.linspace(y_min - pad, y_max + pad, grid_res + 1)

    for l, XY in enumerate(xy_by_layer):
        if len(XY) == 0:
            continue  # empty slice stays all-zero

        if use_hist2d:
            H, _, _ = np.histogram2d(XY[:, 0], XY[:, 1], bins=[x_edges, y_edges], density=False)
            vol[:, :, l] = H.T  # histogram2d returns [x_bins, y_bins] → transpose to align
        else:
            kde = KernelDensity(bandwidth=kde_bandwidth, kernel="gaussian")
            kde.fit(XY)
            # Evaluate the KDE on the grid-cell centers.
            xs = 0.5 * (x_edges[:-1] + x_edges[1:])
            ys = 0.5 * (y_edges[:-1] + y_edges[1:])
            xx, yy = np.meshgrid(xs, ys, indexing='xy')
            grid_points = np.column_stack([xx.ravel(), yy.ravel()])
            log_dens = kde.score_samples(grid_points)
            vol[:, :, l] = np.exp(log_dens).reshape(grid_res, grid_res)

    # Max-normalize to [0, 1] for rendering convenience.
    if vol.max() > 0:
        vol = vol / vol.max()
    return vol
449
+
450
+
451
def render_volume_with_pyvista(volume: np.ndarray,
                               out_png: str,
                               opacity="sigmoid") -> None:
    """
    Render the (X, Y, L) density volume with PyVista/VTK and save a screenshot.

    Raises RuntimeError when PyVista is not installed. `opacity` is forwarded
    unchanged to PyVista's add_volume (a transfer-function name or custom values).
    """
    if not HAS_PYVISTA:
        raise RuntimeError("PyVista is not installed; cannot render volume.")
    pl = pv.Plotter()
    # Wrap the NumPy array as VTK image data; PyVista treats axis 2 as z.
    vol_vtk = pv.wrap(volume)
    pl.add_volume(vol_vtk, opacity=opacity, shade=True)
    pl.show(screenshot=out_png)  # headless environments will still save a screenshot (if offscreen support)
464
+
465
+ # ====== 7. 3D Plotly visualization ==============================================================
466
def plotly_3d_layers(xy_layers: List[np.ndarray],
                     layer_tokens: List[List[str]],
                     layer_cluster_labels: List[np.ndarray],
                     layer_uncertainty: List[np.ndarray],
                     layer_graphs: List[nx.Graph],
                     connect_token_trajectories: bool = True,
                     title: str = "Qwen: 3D Cluster Formation (UMAP2D + Layer as Z)") -> go.Figure:
    """
    Build an interactive 3D Plotly figure of the layer stack:
      - one marker trace per layer at (x, y, z=layer index), colored by cluster
        label (Viridis), with token / cluster / uncertainty in the hover text
      - one faint line trace per layer for the intra-layer graph edges
      - optionally one faint polyline per token position connecting its
        location across consecutive layers ("trajectories")

    All per-layer inputs are parallel lists indexed by layer; layers with zero
    points are skipped.
    """
    fig_data = []

    # Per-layer node + edge traces.
    for l, (xy, tokens, labels, unc, G) in enumerate(zip(xy_layers, layer_tokens, layer_cluster_labels, layer_uncertainty, layer_graphs)):
        if len(xy) == 0:
            continue
        x, y = xy[:, 0], xy[:, 1]
        z = np.full_like(x, l, dtype=float)  # constant depth for the whole slice

        # --- Nodes: hover shows token, cluster id and anchor-entropy uncertainty.
        node_text = [f"layer={l} | idx={i}<br>token={tokens[i]}<br>cluster={int(labels[i])}<br>uncertainty={unc[i]:.3f}"
                     for i in range(len(tokens))]
        node_trace = go.Scatter3d(
            x=x, y=y, z=z,
            mode='markers',
            name=f"Layer {l}",
            marker=dict(
                size=4,
                opacity=0.7,
                color=labels,          # cluster ID → color scale
                colorscale='Viridis',
                showscale=(l == 0)     # draw the colorbar only once
            ),
            text=node_text,
            hovertemplate="%{text}<extra></extra>"
        )
        fig_data.append(node_trace)

        # --- Intra-layer edges (kNN or threshold graph). Each edge contributes
        # a 3-point segment (u, v, None) so one trace can hold every edge.
        if G is not None and G.number_of_edges() > 0:
            edge_x, edge_y, edge_z = [], [], []
            for u, v in G.edges():
                edge_x += [x[u], x[v], None]
                edge_y += [y[u], y[v], None]
                edge_z += [z[u], z[v], None]
            edge_trace = go.Scatter3d(
                x=edge_x, y=edge_y, z=edge_z,
                mode='lines',
                line=dict(width=1),
                opacity=0.30,
                name=f"Edges L{l}"
            )
            fig_data.append(edge_trace)

    # --- Trajectories: connect the same token index across layers.
    if connect_token_trajectories:
        # T is clamped to the shortest layer so indexing stays in bounds.
        L = len(xy_layers)
        if L > 1:
            T = min(len(xy_layers[l]) for l in range(L))
            for i in range(T):
                xs = [xy_layers[l][i, 0] for l in range(L)]
                ys = [xy_layers[l][i, 1] for l in range(L)]
                zs = list(range(L))
                traj = go.Scatter3d(
                    x=xs, y=ys, z=zs,
                    mode='lines',
                    line=dict(width=1),
                    opacity=0.15,
                    name=f"traj_{i}",
                    hoverinfo='skip'  # keep hover focused on the node traces
                )
                fig_data.append(traj)

    fig = go.Figure(data=fig_data)
    fig.update_layout(
        title=title,
        scene=dict(
            xaxis_title="UMAP X",
            yaxis_title="UMAP Y",
            zaxis_title="Layer (depth)"
        ),
        height=900,
        showlegend=False
    )
    return fig
557
+
558
+ # ====== 8. Orchestration ========================================================================
559
def run_pipeline(cfg: Config, model, tok, device, main_text: str, save_artifacts: bool = False):
    """
    End-to-end analysis for one input text.

    Steps: extract hidden states for `main_text` → pool states from a few
    corpus texts → fit global KMeans anchors → per-layer LoT features
    (anchor distances, soft assignments, entropy) → per-layer graphs,
    clustering and percolation stats → shared 2D UMAP → 3D Plotly figure.

    Returns (fig, {"percolation": [...], "tokens": [...]}). When
    save_artifacts is True, also writes the HTML figure, percolation JSON
    and anchors.npy into cfg.out_dir.
    """
    seed_everything(42)

    # 8.2 Collect hidden states for one representative text (detailed viz) + for pool.
    # Extendable to many texts; a single text keeps the run fast and readable.
    texts = cfg.corpus or DEFAULT_CORPUS

    # Hidden states for the main text: list of (T, D), length = num_layers + 1.
    main_bundle = extract_hidden_states(model, tok, main_text, cfg.max_length, device)
    layers_np: List[np.ndarray] = main_bundle.hidden_layers
    tokens = main_bundle.tokens
    L_all = len(layers_np)

    # 8.3 Build a pool of states (across a few texts & layers) to fit anchors + UMAP.
    pool_states = []
    # Sample across the first few texts to improve diversity (lightweight).
    for t in texts[: min(5, len(texts))]:
        b = extract_hidden_states(model, tok, t, cfg.max_length, device)
        # Take a subset from each layer to limit pool size.
        for H in b.hidden_layers:
            T = len(H)
            take = min(cfg.fit_pool_per_layer, T)
            idx = np.random.choice(T, size=take, replace=False)
            pool_states.append(H[idx])
    pool_states = np.vstack(pool_states) if len(pool_states) else layers_np[-1]

    # 8.4 Fit global anchors (LoT-style features).
    anchors = fit_global_anchors(pool_states, cfg.anchor_k)

    # 8.5 Per-layer LoT features for the main text.
    layer_features = []       # list of (T, K) anchor-distance matrices
    layer_uncertainties = []  # list of (T,) entropies of the soft assignments
    layer_top_anchor = []     # list of (T,) nearest-anchor ids

    for l, H in enumerate(layers_np):
        dists, P, H_unc = anchor_features(H, anchors, cfg.anchor_temp)
        layer_features.append(dists)       # N x K distances (lower = closer)
        layer_uncertainties.append(H_unc)  # N
        layer_top_anchor.append(np.argmin(dists, axis=1))  # closest anchor id per token

    # 8.6 Consistency metric (LoT Eq. (5)): does this layer's top anchor match
    # the final layer's? NOTE(review): layer_consistency is computed here but
    # not surfaced downstream — candidate for a future hover field.
    final_top = layer_top_anchor[-1]
    layer_consistency = []
    for l in range(L_all):
        cons = (layer_top_anchor[l] == final_top).astype(np.int32)  # 1 if matches, 0 otherwise
        layer_consistency.append(cons)

    # 8.7 Per-layer graphs. kNN runs on the anchor-FEATURE space for stability;
    # threshold mode runs on the original hidden space.
    layer_graphs = []
    for l in range(L_all):
        feats = layer_features[l]
        if cfg.graph_mode == "knn":
            G = build_knn_graph(feats, cfg.knn_k, metric="euclidean")
        else:
            G = build_threshold_graph(layers_np[l], cfg.sim_threshold, use_cosine=cfg.use_cosine)
        layer_graphs.append(G)

    # 8.8 Cluster per layer.
    layer_cluster_labels = []
    for l in range(L_all):
        feats = layer_features[l]
        labels = cluster_layer(
            feats,
            layer_graphs[l],
            method=cfg.cluster_method,
            n_clusters_kmeans=cfg.n_clusters_kmeans,
            hdbscan_min_cluster_size=cfg.hdbscan_min_cluster_size
        )
        layer_cluster_labels.append(labels)

    # 8.9 Percolation statistics (φ, #clusters, χ) per layer.
    percolation = []
    for l in range(L_all):
        stats = percolation_stats(layer_graphs[l])
        percolation.append(stats)

    # 8.10 Common 2D manifold: UMAP fit once on the pool, then each layer is
    # transformed into the SAME plane so the z-stack is coherent.
    reducer2d = fit_umap_2d(pool_states,
                            n_neighbors=cfg.umap_n_neighbors,
                            min_dist=cfg.umap_min_dist,
                            metric=cfg.umap_metric)
    xy_by_layer = [reducer2d.transform(layers_np[l]) for l in range(L_all)]

    # OPTIONAL: orthogonal alignment across layers (helps if UMAP.transform still drifts)
    # for l in range(1, L_all):
    #     xy_by_layer[l] = orthogonal_align(xy_by_layer[l-1], xy_by_layer[l])

    # 8.11 Plotly 3D point+graph view: X,Y from UMAP; Z = layer index.
    fig = plotly_3d_layers(
        xy_layers=xy_by_layer,
        layer_tokens=[tokens for _ in range(L_all)],
        layer_cluster_labels=layer_cluster_labels,
        layer_uncertainty=layer_uncertainties,
        layer_graphs=layer_graphs,
        connect_token_trajectories=True,
        title="Qwen: 3D Cluster Formation (UMAP2D + Layer as Z, LoT metrics on hover)"
    )

    if save_artifacts:
        os.makedirs(cfg.out_dir, exist_ok=True)
        html_path = os.path.join(cfg.out_dir, cfg.plotly_html)
        fig.write_html(html_path)
        # Persist the percolation series and anchors for reproducibility.
        with open(os.path.join(cfg.out_dir, "percolation_stats.json"), "w") as f:
            json.dump(percolation, f, indent=2)
        np.save(os.path.join(cfg.out_dir, "anchors.npy"), anchors)

    return fig, {"percolation": percolation, "tokens": tokens}
676
+
677
@st.cache_resource(show_spinner=False)
def get_model_and_tok(model_name: str):
    """
    Load (and cache across Streamlit reruns, keyed by model_name) the model
    and tokenizer. Device/dtype are resolved here at call time: fp16 on CUDA,
    fp32 on CPU. Returns (model, tokenizer, device, dtype).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    model, tok = load_qwen(model_name, device, dtype)
    return model, tok, device, dtype
683
+
684
def main():
    """Streamlit entry point: sidebar controls → Config → run_pipeline → plots."""
    st.set_page_config(page_title="Qwen Layer Explorer", layout="wide")
    st.title("Qwen: 3D Token Embedding Explorer (Live Hidden States)")

    with st.sidebar:
        st.header("Model / Input")
        model_name = st.selectbox("Model", ["Qwen/Qwen1.5-0.5B", "Qwen/Qwen1.5-1.8B", "Qwen/Qwen1.5-4B"], index=1)
        max_length = st.slider("Max tokens", 16, 256, 64, step=16)

        st.header("Graph")
        graph_mode = st.selectbox("Graph mode", ["knn", "threshold"], index=0)
        # Only the active mode's widget is rendered; the other keeps a default.
        knn_k = st.slider("k (kNN)", 2, 50, 8) if graph_mode == "knn" else 8
        sim_threshold = st.slider("Similarity threshold", 0.0, 0.99, 0.70, step=0.01) if graph_mode == "threshold" else 0.70
        use_cosine = st.checkbox("Use cosine similarity", value=True)

        st.header("Anchors / LoT")
        anchor_k = st.slider("anchor_k", 4, 64, 16, step=1)
        anchor_temp = st.slider("anchor_temp", 0.05, 2.0, 0.7, step=0.05)

        st.header("UMAP")
        umap_n_neighbors = st.slider("n_neighbors", 5, 100, 30, step=1)
        umap_min_dist = st.slider("min_dist", 0.0, 0.99, 0.05, step=0.01)
        umap_metric = st.selectbox("metric", ["cosine", "euclidean"], index=0)

        st.header("Performance")
        fit_pool_per_layer = st.slider("fit_pool_per_layer", 64, 2048, 512, step=64)

        st.header("Outputs")
        save_artifacts = st.checkbox("Save artifacts to disk (HTML/CSV/NPZ)", value=False)

    prompt_col, run_col = st.columns([4, 1])
    with prompt_col:
        main_text = st.text_area(
            "Text to visualize (hidden states computed on this text)",
            value="Explain in one sentence what a transformer attention layer does.",
            height=140
        )
    with run_col:
        # Blank writes push the button down so it aligns with the text area.
        st.write("")
        st.write("")
        run_btn = st.button("Run", type="primary")

    cfg = Config(
        model_name=model_name,
        max_length=max_length,
        corpus=None,  # keep using DEFAULT_CORPUS for pooling unless you expose it
        graph_mode=graph_mode,
        knn_k=knn_k,
        sim_threshold=sim_threshold,
        use_cosine=use_cosine,
        anchor_k=anchor_k,
        anchor_temp=anchor_temp,
        umap_n_neighbors=umap_n_neighbors,
        umap_min_dist=umap_min_dist,
        umap_metric=umap_metric,
        fit_pool_per_layer=fit_pool_per_layer,
        # keep other defaults
    )

    if run_btn:
        if not main_text.strip():
            st.error("Please enter some text.")
            return

        with st.spinner("Loading model (cached after first run)..."):
            model, tok, device, dtype = get_model_and_tok(cfg.model_name)

        with st.spinner("Running pipeline (hidden states → features → UMAP → Plotly)..."):
            fig, outputs = run_pipeline(
                cfg=cfg,
                model=model,
                tok=tok,
                device=device,
                main_text=main_text,
                save_artifacts=save_artifacts,
            )

        st.plotly_chart(fig, use_container_width=True)

        st.success(f"Loaded {cfg.model_name} on {device} ({dtype})")

        # NOTE(review): these expanders must stay inside the run branch —
        # `outputs` only exists after run_pipeline has executed.
        with st.expander("Percolation summary"):
            percolation = outputs.get("percolation", [])
            for l, stt in enumerate(percolation):
                st.write(f"L={l:02d} | φ={stt['phi']:.3f} | #C={stt['num_clusters']} | χ={stt['chi']:.2f}")

        with st.expander("Debug: config"):
            st.json(asdict(cfg))
776
+
777
# ====== 9. Main =================================================================================
if __name__ == "__main__":
    # The app is inference-only: disable autograd globally before Streamlit runs.
    torch.set_grad_enabled(False)
    main()