import os
import math
import json
import warnings
from dataclasses import dataclass, asdict
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd
import torch
from torch import nn
import networkx as nx
import streamlit as st

# Transformers: AutoTokenizer covers Qwen even when a dedicated Qwen2Tokenizer is not present
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# Dimensionality reduction
import umap
from umap import UMAP

# Neighbors & clustering
from sklearn.neighbors import NearestNeighbors, KernelDensity
from sklearn.cluster import KMeans, DBSCAN
from sklearn.decomposition import PCA
from sklearn.metrics import pairwise_distances

# Plotly for interactive 3D
import plotly.graph_objects as go

import hashlib

# Optional libs (used if present)
try:
    import hdbscan  # Robust density-based clustering
    HAS_HDBSCAN = True
except Exception:
    HAS_HDBSCAN = False

try:
    import igraph as ig
    import leidenalg as la
    HAS_IGRAPH_LEIDEN = True
except Exception:
    HAS_IGRAPH_LEIDEN = False

try:
    import pyvista as pv  # Volume & isosurfaces (VTK)
    HAS_PYVISTA = True
except Exception:
    HAS_PYVISTA = False

from scipy.linalg import orthogonal_procrustes  # For optional per-layer orientation alignment
# ====== 1. Configuration =========================================================================
@dataclass
class Config:
    # Model
    model_name: str = "Qwen/Qwen1.5-1.8B"
    # device: str = "cuda" if torch.cuda.is_available() else "cpu"
    # dtype: torch.dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    # Tokenization / generation
    max_length: int = 64                 # truncate inputs for speed/memory
    # Data
    corpus: Optional[List[str]] = None   # if None, DEFAULT_CORPUS (defined below) is used
    # Graph building
    graph_mode: str = "threshold"        # {"knn", "threshold"}
    knn_k: int = 8                       # neighbors per token (used if graph_mode="knn")
    sim_threshold: float = 0.60          # used if graph_mode="threshold"
    use_cosine: bool = True
    # Anchors / LoT-style features (global)
    anchor_k: int = 16                   # number of global prototypes (KMeans on pooled states)
    anchor_temp: float = 0.7             # softmax temperature for converting distances to probs
    # Clustering per layer
    cluster_method: str = "auto"         # {"auto", "leiden", "hdbscan", "dbscan", "kmeans"}
    n_clusters_kmeans: int = 6           # fallback for kmeans
    hdbscan_min_cluster_size: int = 4
    # DR / embeddings
    umap_n_neighbors: int = 30
    umap_min_dist: float = 0.05
    umap_metric: str = "cosine"          # hidden states are directional → cosine works well
    use_global_3d_umap: bool = False     # if True, compute a single 3D manifold for all states
    # Pooling for UMAP fit
    fit_pool_per_layer: int = 512        # number of states sampled per layer to fit UMAP
    # Volume grid (MRI view)
    grid_res: int = 128                  # voxel resolution in x/y; z = num_layers
    kde_bandwidth: float = 0.15          # KDE bandwidth in manifold space (if using KDE)
    use_hist2d: bool = True              # if True, use histogram2d instead of KDE for speed
    # Output
    out_dir: str = "qwen_mri3d_outputs"
    plotly_html: str = "qwen_layers_3d.html"
    volume_npz: str = "qwen_density_volume.npz"   # saved if PyVista isn't available
    volume_screenshot: str = "qwen_volume.png"    # used if PyVista is available

    def validate(self):
        if self.graph_mode not in {"knn", "threshold"}:
            raise ValueError("graph_mode must be 'knn' or 'threshold'")
        if self.knn_k < 2:
            raise ValueError("knn_k must be >= 2")
        if self.anchor_k < 2:
            raise ValueError("anchor_k must be >= 2")
        if self.anchor_temp <= 0:
            raise ValueError("anchor_temp must be > 0")
# Default corpus (small and diverse; adjust freely)
DEFAULT_CORPUS = [
    "The cat sat on the mat and watched.",
    "Machine learning models process data using neural networks.",
    "Climate change affects ecosystems around the world.",
    "Quantum computers use superposition for parallel computation.",
    "The universe contains billions of galaxies.",
    "Artificial intelligence transforms how we work.",
    "DNA stores genetic information in cells.",
    "Ocean currents regulate Earth's climate system.",
    "Photosynthesis converts sunlight into chemical energy.",
    "Blockchain technology enables decentralized systems."
]
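# Usage sketch (illustrative only, not executed): constructing and validating a Config.
# Field names follow the dataclass above; the values shown here are arbitrary.
#
#   cfg = Config(model_name="Qwen/Qwen1.5-0.5B", graph_mode="knn", knn_k=10)
#   cfg.validate()                        # raises ValueError on inconsistent settings
#   texts = cfg.corpus or DEFAULT_CORPUS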
# ====== 2. Utilities =============================================================================
def seed_everything(seed: int = 42):
    """Determinism for reproducibility in layouts/UMAP/KMeans."""
    np.random.seed(seed)
    torch.manual_seed(seed)


def cosine_similarity_matrix(X: np.ndarray) -> np.ndarray:
    """Compute pairwise cosine similarity for the rows of X, shape (N, D)."""
    norms = np.linalg.norm(X, axis=1, keepdims=True) + 1e-8
    Xn = X / norms
    return Xn @ Xn.T


def build_knn_graph(coords: np.ndarray, k: int, metric: str = "cosine") -> nx.Graph:
    """
    Build an undirected kNN graph for the points in coords.
    coords: (N, D)
    """
    nbrs = NearestNeighbors(n_neighbors=min(k + 1, len(coords)), metric=metric)  # +1 to include self
    nbrs.fit(coords)
    distances, indices = nbrs.kneighbors(coords)
    G = nx.Graph()
    G.add_nodes_from(range(len(coords)))
    # Connect i to its top-k neighbors (index 0 is i itself, so skip it)
    for i in range(len(coords)):
        for j in indices[i, 1:]:
            G.add_edge(int(i), int(j))
    return G


def build_threshold_graph(H: np.ndarray, threshold: float, use_cosine: bool = True) -> nx.Graph:
    """
    Build a graph by thresholding pairwise similarities in the original hidden-state space.
    H: (N, D) hidden states for a single layer
    """
    if use_cosine:
        S = cosine_similarity_matrix(H)
    else:
        S = H @ H.T  # dot product
    N = S.shape[0]
    G = nx.Graph()
    G.add_nodes_from(range(N))
    for i in range(N):
        for j in range(i + 1, N):
            if S[i, j] > threshold:
                G.add_edge(i, j, weight=float(S[i, j]))
    return G
def percolation_stats(G: nx.Graph) -> Dict[str, float]:
    """
    Compute percolation observables (φ, #clusters, χ) as in the notebook.
    φ : fraction of nodes in the Giant Connected Component (GCC)
    χ : mean size of the components excluding the GCC
    """
    n = G.number_of_nodes()
    if n == 0:
        return dict(phi=0.0, num_clusters=0, chi=0.0, largest_component_size=0, component_sizes=[])
    comps = list(nx.connected_components(G))
    sizes = [len(c) for c in comps]
    if not sizes:
        return dict(phi=0.0, num_clusters=0, chi=0.0, largest_component_size=0, component_sizes=[])
    largest = max(sizes)
    phi = largest / n
    non_gcc_sizes = [s for s in sizes if s != largest]
    chi = float(np.mean(non_gcc_sizes)) if non_gcc_sizes else 0.0
    return dict(phi=float(phi),
                num_clusters=len(comps),
                chi=float(chi),
                largest_component_size=largest,
                component_sizes=sorted(sizes, reverse=True))
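# Usage sketch (illustrative only): a threshold graph on synthetic states and its percolation
# observables. A high threshold fragments the graph (low phi, many small components); a low
# threshold merges nodes into one giant component (phi near 1, chi near 0).
#
#   H_fake = np.random.randn(50, 32)                      # 50 fake token states, D=32
#   G_hi = build_threshold_graph(H_fake, threshold=0.9)   # sparse, fragmented
#   G_lo = build_threshold_graph(H_fake, threshold=0.1)   # dense, (near-)connected
#   print(percolation_stats(G_hi)["phi"], percolation_stats(G_lo)["phi"])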
def leiden_communities(G: nx.Graph) -> np.ndarray:
    """
    Community detection using Leiden (igraph), if available.
    Returns an array of cluster ids for nodes 0..N-1.
    """
    if not HAS_IGRAPH_LEIDEN:
        raise RuntimeError("igraph+leidenalg not available")
    # Convert nx → igraph
    mapping = {n: i for i, n in enumerate(G.nodes())}
    edges = [(mapping[u], mapping[v]) for u, v in G.edges()]
    ig_g = ig.Graph(n=len(mapping), edges=edges, directed=False)
    part = la.find_partition(ig_g, la.RBConfigurationVertexPartition)  # robust default
    labels = np.zeros(len(mapping), dtype=int)
    for cid, comm in enumerate(part):
        for node in comm:
            labels[node] = cid
    return labels


def cluster_layer(features: np.ndarray,
                  G: Optional[nx.Graph],
                  method: str,
                  n_clusters_kmeans: int = 6,
                  hdbscan_min_cluster_size: int = 4) -> np.ndarray:
    """
    Cluster one layer's states and return per-state cluster labels.
    - leiden : requires G (graph) and igraph/leidenalg
    - hdbscan: density-based clustering in feature space
    - dbscan : fallback density-based clustering (scikit-learn)
    - kmeans : fallback centroid clustering
    """
    method = method.lower()
    N = len(features)
    if method == "auto":
        # Prefer Leiden (graph) → HDBSCAN → KMeans
        if HAS_IGRAPH_LEIDEN and G is not None and G.number_of_edges() > 0:
            return leiden_communities(G)
        elif HAS_HDBSCAN and N >= 5:
            clusterer = hdbscan.HDBSCAN(min_cluster_size=hdbscan_min_cluster_size,
                                        metric='euclidean')
            labels = clusterer.fit_predict(features)
            # HDBSCAN: -1 = noise. Keep it as its own "noise" cluster id (or remap downstream).
            return labels
        else:
            km = KMeans(n_clusters=min(n_clusters_kmeans, max(2, N // 3)),
                        n_init="auto", random_state=42)
            return km.fit_predict(features)
    if method == "leiden":
        if G is None or not HAS_IGRAPH_LEIDEN:
            raise RuntimeError("Leiden requires a graph and igraph+leidenalg.")
        return leiden_communities(G)
    if method == "hdbscan":
        if not HAS_HDBSCAN:
            raise RuntimeError("hdbscan not installed")
        clusterer = hdbscan.HDBSCAN(min_cluster_size=hdbscan_min_cluster_size, metric='euclidean')
        return clusterer.fit_predict(features)
    if method == "dbscan":
        db = DBSCAN(eps=0.5, min_samples=4, metric='euclidean')
        return db.fit_predict(features)
    if method == "kmeans":
        km = KMeans(n_clusters=min(n_clusters_kmeans, max(2, N // 3)),
                    n_init="auto", random_state=42)
        return km.fit_predict(features)
    raise ValueError(f"Unknown cluster method: {method}")
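# Usage sketch (illustrative only): clustering one layer's anchor-distance features.
# "auto" falls back gracefully depending on which optional libraries are installed.
#
#   feats = np.random.rand(40, 16)                        # fake (T=40, K=16) feature matrix
#   G = build_knn_graph(feats, k=8, metric="euclidean")   # graph is only needed for Leiden
#   labels = cluster_layer(feats, G, method="auto")       # array of per-token cluster ids
#   print(np.unique(labels))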
def orthogonal_align(A_ref: np.ndarray, B: np.ndarray) -> np.ndarray:
    """
    Align B to A_ref with an orthogonal rotation (Procrustes),
    preserving geometry while removing arbitrary orientation flips.
    """
    R, _ = orthogonal_procrustes(B - B.mean(0), A_ref - A_ref.mean(0))
    return (B - B.mean(0)) @ R + A_ref.mean(0)


def entropy_from_probs(p: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Shannon entropy of each row of p, where p is (N, K) with rows summing to ~1."""
    return -np.sum(p * np.log(p + eps), axis=1)
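# Worked example (illustrative only): a uniform row over K anchors has the maximum entropy
# log(K), while a one-hot row has entropy near 0 — this is the "uncertainty" used below.
#
#   p = np.array([[0.25, 0.25, 0.25, 0.25],    # uniform → entropy = log(4) ≈ 1.386
#                 [1.00, 0.00, 0.00, 0.00]])   # one-hot → entropy ≈ 0
#   print(entropy_from_probs(p))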
# ====== 3. Model I/O (hidden states) =============================================================
@dataclass
class HiddenStatesBundle:
    """
    Encapsulates a single input's hidden states and metadata.
    hidden_layers: list of np.ndarray of shape (T, D), length = num_layers + 1 (incl. embedding)
    tokens       : list of token strings of length T
    """
    hidden_layers: List[np.ndarray]
    tokens: List[str]


def load_qwen(model_name: str, device: str, dtype: torch.dtype):
    """
    Load Qwen with output_hidden_states=True. AutoTokenizer is used for broader compatibility.
    """
    print(f"[Load] {model_name} on {device} ({dtype})")
    config = AutoConfig.from_pretrained(model_name, output_hidden_states=True)
    tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
    model.eval().to(device)
    if device == "cuda" and dtype == torch.float16:
        model = model.half()
    return model, tok


def extract_hidden_states(model, tokenizer, text: str, max_length: int, device: str) -> HiddenStatesBundle:
    """
    Run a single forward pass to collect all hidden states (incl. the embedding layer).
    Returns CPU numpy arrays to keep GPU memory low.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
    with torch.no_grad():
        out = model(**inputs)
    # Tuple length = num_layers + 1 (embedding)
    hs = [h[0].detach().float().cpu().numpy() for h in out.hidden_states]  # shapes: (T, D)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    return HiddenStatesBundle(hidden_layers=hs, tokens=tokens)
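# Usage sketch (illustrative only; requires the model weights to be downloadable):
#
#   model, tok = load_qwen("Qwen/Qwen1.5-0.5B", device="cpu", dtype=torch.float32)
#   bundle = extract_hidden_states(model, tok, "The cat sat on the mat.", max_length=32, device="cpu")
#   print(len(bundle.hidden_layers), bundle.hidden_layers[0].shape, bundle.tokens[:5])
#   # → (num_layers + 1) arrays of shape (T, D), plus the T token strings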
# ====== 4. LoT-style anchors & features ==========================================================
def fit_global_anchors(all_states_sampled: np.ndarray, K: int, random_state: int = 42) -> np.ndarray:
    """
    Fit KMeans cluster centroids on a pooled set of states (from many layers/texts).
    These centroids are "anchors" (LoT-like choices) used to build low-dim features:
        f(state) = [dist(state, anchor_j)]_{j=1..K}
    """
    print(f"[Anchors] Fitting {K} global centroids on {len(all_states_sampled)} states ...")
    kmeans = KMeans(n_clusters=K, n_init="auto", random_state=random_state)
    kmeans.fit(all_states_sampled)
    return kmeans.cluster_centers_  # (K, D)


def anchor_features(H: np.ndarray, anchors: np.ndarray, temperature: float = 1.0) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    For states H (N, D) and anchors A (K, D):
      - Compute Euclidean distances to each anchor → Dists (N, K)
      - Convert to soft probabilities with exp(-Dist/T), normalized row-wise → P (N, K)
      - Uncertainty = entropy(P) (cf. LoT Eq. (6))
      - Top anchor = argmin distance, used for "consistency"-style comparisons (cf. Eq. (5))
    Returns (Dists, P, entropy).
    """
    # Distances (N, K)
    dists = pairwise_distances(H, anchors, metric="euclidean")
    # Soft assignments
    logits = -dists / max(temperature, 1e-6)
    # Stable softmax
    logits = logits - logits.max(axis=1, keepdims=True)
    P = np.exp(logits)
    P /= P.sum(axis=1, keepdims=True) + 1e-12
    # Uncertainty (entropy)
    H_unc = entropy_from_probs(P)
    return dists, P, H_unc
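# Usage sketch (illustrative only): fit anchors on a pooled sample, then turn one layer's states
# into distance features, soft assignments, and per-token uncertainty. A lower temperature
# sharpens P (lower entropy); a higher one flattens it.
#
#   pool = np.random.randn(500, 64)                # stand-in for pooled hidden states
#   A = fit_global_anchors(pool, K=8)              # (8, 64) centroids
#   H_layer = np.random.randn(20, 64)              # one layer's (T=20, D=64) states
#   dists, P, unc = anchor_features(H_layer, A, temperature=0.7)
#   print(dists.shape, P.shape, unc.shape)         # (20, 8) (20, 8) (20,)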
# ====== 5. Dimensionality reduction / embeddings =================================================
def fit_umap_2d(pool: np.ndarray,
                n_neighbors: int = 30,
                min_dist: float = 0.05,
                metric: str = "cosine",
                random_state: int = 42) -> umap.UMAP:
    """
    Fit UMAP once on a diverse pool across layers to fix a common orientation.
    Each layer then calls .transform() to embed into the SAME 2D space → an "MRI stack".
    """
    reducer = umap.UMAP(n_components=2, n_neighbors=n_neighbors, min_dist=min_dist,
                        metric=metric, random_state=random_state)
    reducer.fit(pool)
    return reducer
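# Usage sketch (illustrative only): the "fit once, transform per layer" pattern that keeps every
# slice in the same 2D coordinate frame, so layers can be stacked along z.
#
#   pool = np.random.randn(1000, 64)                        # pooled states across layers/texts
#   reducer = fit_umap_2d(pool, n_neighbors=15, metric="euclidean")
#   xy_layer0 = reducer.transform(np.random.randn(20, 64))  # (20, 2), same frame for every layer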
def _corpus_fingerprint(texts, max_items=5, max_chars=4000) -> str:
    """Stable key so the cache invalidates if DEFAULT_CORPUS changes."""
    joined = "\n".join(texts[:max_items])
    joined = joined[:max_chars]
    return hashlib.sha256(joined.encode("utf-8")).hexdigest()
@st.cache_data
def get_pool_artifacts(
    model_name: str,
    max_length: int,
    anchor_k: int,
    anchor_temp: float,   # not strictly needed to fit anchors, but keeps the cache key aligned
    umap_n_neighbors: int,
    umap_min_dist: float,
    umap_metric: str,
    fit_pool_per_layer: int,
    corpus_hash: str,
):
    """
    Cached: build pooled hidden states on DEFAULT_CORPUS, then fit anchors and a UMAP reducer once.
    Returns:
        anchors  : (K, D) np.ndarray
        reducer2d: fitted UMAP reducer object (must be pickleable; umap-learn's UMAP is)
    """
    # Use the cached model loader (resource cache)
    model, tok, device, dtype = get_model_and_tok(model_name)
    texts = DEFAULT_CORPUS  # pooled set for stability
    pool_states = []
    for t in texts[: min(5, len(texts))]:
        b = extract_hidden_states(model, tok, t, max_length, device)
        for H in b.hidden_layers:
            T = len(H)
            take = min(fit_pool_per_layer, T)
            if take <= 0:
                continue
            idx = np.random.choice(T, size=take, replace=False)
            pool_states.append(H[idx])
    if not pool_states:
        # Fallback: this should rarely happen
        raise RuntimeError("Pool construction produced no states.")
    pool_states = np.vstack(pool_states)
    anchors = fit_global_anchors(pool_states, anchor_k)
    reducer2d = fit_umap_2d(
        pool_states,
        n_neighbors=umap_n_neighbors,
        min_dist=umap_min_dist,
        metric=umap_metric,
    )
    return anchors, reducer2d
def fit_umap_3d(all_states: np.ndarray,
                n_neighbors: int = 30,
                min_dist: float = 0.05,
                metric: str = "cosine",
                random_state: int = 42) -> np.ndarray:
    """
    Fit a global 3D UMAP embedding for all states at once (alternative to the slice stack).
    Returns coords_3d (N, 3) for the concatenated states passed in.
    """
    reducer = umap.UMAP(n_components=3, n_neighbors=n_neighbors, min_dist=min_dist,
                        metric=metric, random_state=random_state)
    return reducer.fit_transform(all_states)
# ====== 6. Volume construction (MRI) =============================================================
def stack_density_volume(xy_by_layer: List[np.ndarray],
                         grid_res: int,
                         use_hist2d: bool = True,
                         kde_bandwidth: float = 0.15) -> np.ndarray:
    """
    Construct a 3D volume by estimating 2D density on the (x, y) manifold per layer (slice).
    - If use_hist2d: fast uniform binning into grid_res x grid_res
    - Else: KDE (slower but smoother)
    Returns a volume of shape (grid_res, grid_res, L) where L = #layers.
    """
    L = len(xy_by_layer)
    vol = np.zeros((grid_res, grid_res, L), dtype=np.float32)
    # Determine global bounds across layers to keep the axes consistent
    non_empty = [xy for xy in xy_by_layer if len(xy) > 0]
    all_xy = np.vstack(non_empty) if non_empty else np.zeros((0, 2))
    if len(all_xy) == 0:
        return vol
    x_min, y_min = all_xy.min(axis=0)
    x_max, y_max = all_xy.max(axis=0)
    # Slight padding
    pad = 1e-6
    x_edges = np.linspace(x_min - pad, x_max + pad, grid_res + 1)
    y_edges = np.linspace(y_min - pad, y_max + pad, grid_res + 1)
    for l, XY in enumerate(xy_by_layer):
        if len(XY) == 0:
            continue
        if use_hist2d:
            H, _, _ = np.histogram2d(XY[:, 0], XY[:, 1], bins=[x_edges, y_edges], density=False)
            vol[:, :, l] = H.T  # histogram2d returns [x_bins, y_bins] → transpose to align
        else:
            kde = KernelDensity(bandwidth=kde_bandwidth, kernel="gaussian")
            kde.fit(XY)
            # Evaluate the KDE on grid centers
            xs = 0.5 * (x_edges[:-1] + x_edges[1:])
            ys = 0.5 * (y_edges[:-1] + y_edges[1:])
            xx, yy = np.meshgrid(xs, ys, indexing='xy')
            grid_points = np.column_stack([xx.ravel(), yy.ravel()])
            log_dens = kde.score_samples(grid_points)
            dens = np.exp(log_dens).reshape(grid_res, grid_res)
            vol[:, :, l] = dens
    # Normalize the volume to [0, 1] for rendering convenience
    if vol.max() > 0:
        vol = vol / vol.max()
    return vol
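# Usage sketch (illustrative only): stacking per-layer 2D embeddings into an "MRI" volume.
#
#   xy_layers = [np.random.randn(200, 2) for _ in range(25)]   # 25 layers of 2D points
#   vol = stack_density_volume(xy_layers, grid_res=64, use_hist2d=True)
#   print(vol.shape, vol.max())                                 # (64, 64, 25), normalized to 1.0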
def render_volume_with_pyvista(volume: np.ndarray,
                               out_png: str,
                               opacity="sigmoid") -> None:
    """
    Visualize the 3D volume using PyVista/VTK (if installed); save a screenshot.
    """
    if not HAS_PYVISTA:
        raise RuntimeError("PyVista is not installed; cannot render volume.")
    pl = pv.Plotter()
    # Wrap the NumPy array as VTK image data; PyVista expects z as the 3rd axis
    vol_vtk = pv.wrap(volume)
    pl.add_volume(vol_vtk, opacity=opacity, shade=True)
    pl.show(screenshot=out_png)  # headless environments still save a screenshot (with off-screen support)
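# Fallback sketch (illustrative only): if PyVista is unavailable, the raw volume can simply be
# saved to the path configured as Config.volume_npz and rendered elsewhere. The helper name below
# is hypothetical — it is not called anywhere in this app.
#
#   def save_volume_npz(volume: np.ndarray, out_npz: str) -> None:
#       np.savez_compressed(out_npz, volume=volume)   # reload with np.load(out_npz)["volume"]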
# ====== 7. 3D Plotly visualization ===============================================================
def plotly_3d_layers(xy_layers: List[np.ndarray],
                     layer_tokens: List[List[str]],
                     layer_cluster_labels: List[np.ndarray],
                     layer_uncertainty: List[np.ndarray],
                     layer_graphs: List[nx.Graph],
                     connect_token_trajectories: bool = True,
                     title: str = "Qwen: 3D Cluster Formation (UMAP2D + Layer as Z)") -> go.Figure:
    """
    Build an interactive 3D Plotly figure:
      - Nodes per layer at (x, y, z=layer)
      - Edge segments (kNN or threshold graph) per layer
      - Trajectory lines: connect the same token index across consecutive layers (optional)
      - Nodes colored by cluster label; hover shows token & uncertainty
    """
    fig_data = []
    # Build one node trace (and one edge trace) per layer
    for l, (xy, tokens, labels, unc, G) in enumerate(zip(xy_layers, layer_tokens, layer_cluster_labels, layer_uncertainty, layer_graphs)):
        if len(xy) == 0:
            continue
        x, y = xy[:, 0], xy[:, 1]
        z = np.full_like(x, l, dtype=float)
        # --- Nodes
        node_text = [f"layer={l} | idx={i}<br>token={tokens[i]}<br>cluster={int(labels[i])}<br>uncertainty={unc[i]:.3f}"
                     for i in range(len(tokens))]
        node_trace = go.Scatter3d(
            x=x, y=y, z=z,
            mode='markers',
            name=f"Layer {l}",
            marker=dict(
                size=4,
                opacity=0.7,
                color=labels,           # cluster ID → color scale
                colorscale='Viridis',
                showscale=(l == 0)      # show the scale once
            ),
            text=node_text,
            hovertemplate="%{text}<extra></extra>"
        )
        fig_data.append(node_trace)
        # --- Intra-layer edges (kNN or threshold)
        if G is not None and G.number_of_edges() > 0:
            edge_x, edge_y, edge_z = [], [], []
            for u, v in G.edges():
                edge_x += [x[u], x[v], None]
                edge_y += [y[u], y[v], None]
                edge_z += [z[u], z[v], None]
            edge_trace = go.Scatter3d(
                x=edge_x, y=edge_y, z=edge_z,
                mode='lines',
                line=dict(width=1),
                opacity=0.30,
                name=f"Edges L{l}"
            )
            fig_data.append(edge_trace)
    # --- Trajectories: connect the same token index across layers
    if connect_token_trajectories:
        # Only meaningful because the token count T is constant across layers for one text.
        # Draw faint polylines for each position i across l = 0..L-1.
        L = len(xy_layers)
        if L > 1:
            T = min(len(xy_layers[l]) for l in range(L))
            for i in range(T):
                xs = [xy_layers[l][i, 0] for l in range(L)]
                ys = [xy_layers[l][i, 1] for l in range(L)]
                zs = list(range(L))
                traj = go.Scatter3d(
                    x=xs, y=ys, z=zs,
                    mode='lines',
                    line=dict(width=1),
                    opacity=0.15,
                    name=f"traj_{i}",
                    hoverinfo='skip'
                )
                fig_data.append(traj)
    fig = go.Figure(data=fig_data)
    fig.update_layout(
        title=title,
        scene=dict(
            xaxis_title="UMAP X",
            yaxis_title="UMAP Y",
            zaxis_title="Layer (depth)"
        ),
        height=900,
        showlegend=False
    )
    return fig
# ====== 8. Orchestration =========================================================================
def run_pipeline(cfg: Config, model, tok, device, main_text: str, save_artifacts: bool = False):
    seed_everything(42)
    # 8.2 Collect hidden states for one representative text (detailed viz) + for the pool.
    #     You can extend this to many texts; we keep a single text for clarity & speed.
    texts = cfg.corpus or DEFAULT_CORPUS
    # print(f"[Input] Example text: {main_text!r}")
    # Hidden states for the main text
    main_bundle = extract_hidden_states(model, tok, main_text, cfg.max_length, device)
    layers_np: List[np.ndarray] = main_bundle.hidden_layers  # list of (T, D), length L_all = num_layers + 1
    tokens = main_bundle.tokens                               # list of length T
    # Cached pool artifacts (anchors + fitted UMAP reducer)
    corpus_hash = _corpus_fingerprint(texts)                  # texts is cfg.corpus or DEFAULT_CORPUS
    anchors, reducer2d = get_pool_artifacts(
        model_name=cfg.model_name,
        max_length=cfg.max_length,
        anchor_k=cfg.anchor_k,
        anchor_temp=cfg.anchor_temp,
        umap_n_neighbors=cfg.umap_n_neighbors,
        umap_min_dist=cfg.umap_min_dist,
        umap_metric=cfg.umap_metric,
        fit_pool_per_layer=cfg.fit_pool_per_layer,
        corpus_hash=corpus_hash,
    )
    L_all = len(layers_np)
    # print(f"[Hidden] Layers (incl. embedding): {L_all}, Tokens: {len(tokens)}")
    """
    # 8.3 Build a pool of states (across a few texts & layers) to fit anchors + UMAP
    pool_states = []
    # Sample across the first few texts to improve diversity (lightweight)
    for t in texts[: min(5, len(texts))]:
        b = extract_hidden_states(model, tok, t, cfg.max_length, device)
        # Take a subset from each layer to limit pool size
        for H in b.hidden_layers:
            T = len(H)
            take = min(cfg.fit_pool_per_layer, T)
            idx = np.random.choice(T, size=take, replace=False)
            pool_states.append(H[idx])
    pool_states = np.vstack(pool_states) if len(pool_states) else layers_np[-1]
    # print(f"[Pool] Pooled states for anchors/UMAP: {pool_states.shape}")
    # 8.4 Fit global anchors (LoT-style features)
    anchors = fit_global_anchors(pool_states, cfg.anchor_k)
    # Save anchors for reproducibility
    """
    # 8.5 Build per-layer features for the main text (LoT-style distances & uncertainty)
    layer_features = []       # list of (T, K)
    layer_uncertainties = []  # list of (T,)
    layer_top_anchor = []     # list of (T,) argmin ids
    for l, H in enumerate(layers_np):
        dists, P, H_unc = anchor_features(H, anchors, cfg.anchor_temp)
        layer_features.append(dists)                       # N x K distances (lower = closer)
        layer_uncertainties.append(H_unc)                  # N
        layer_top_anchor.append(np.argmin(dists, axis=1))  # closest anchor id per token
    # 8.6 Consistency metric (LoT Eq. (5)): does this layer's top anchor match the final layer's?
    final_top = layer_top_anchor[-1]
    layer_consistency = []
    for l in range(L_all):
        cons = (layer_top_anchor[l] == final_top).astype(np.int32)  # 1 if it matches, 0 otherwise
        layer_consistency.append(cons)
    # 8.7 Build per-layer graphs (kNN by default) on FEATURE space for stability
    layer_graphs = []
    for l in range(L_all):
        feats = layer_features[l]
        if cfg.graph_mode == "knn":
            G = build_knn_graph(feats, cfg.knn_k, metric="euclidean")  # kNN in feature space
        else:
            # Threshold graph in the original hidden space (as in the notebook)
            G = build_threshold_graph(layers_np[l], cfg.sim_threshold, use_cosine=cfg.use_cosine)
        layer_graphs.append(G)
    # 8.8 Cluster per layer
    layer_cluster_labels = []
    for l in range(L_all):
        feats = layer_features[l]
        labels = cluster_layer(
            feats,
            layer_graphs[l],
            method=cfg.cluster_method,
            n_clusters_kmeans=cfg.n_clusters_kmeans,
            hdbscan_min_cluster_size=cfg.hdbscan_min_cluster_size
        )
        layer_cluster_labels.append(labels)
    # 8.9 Percolation statistics (φ, #clusters, χ) per layer (as in the notebook)
    percolation = []
    for l in range(L_all):
        stats = percolation_stats(layer_graphs[l])
        percolation.append(stats)
    # 8.10 Common 2D manifold via UMAP (fit once on the pool), then transform each layer
    """
    reducer2d = fit_umap_2d(pool_states,
                            n_neighbors=cfg.umap_n_neighbors,
                            min_dist=cfg.umap_min_dist,
                            metric=cfg.umap_metric)
    """
    xy_by_layer = [reducer2d.transform(layers_np[l]) for l in range(L_all)]
    # OPTIONAL: orthogonal alignment across layers (helps if UMAP.transform still drifts)
    # for l in range(1, L_all):
    #     xy_by_layer[l] = orthogonal_align(xy_by_layer[l - 1], xy_by_layer[l])
    # 8.11 Plotly 3D point+graph view: X, Y from UMAP; Z = layer index
    fig = plotly_3d_layers(
        xy_layers=xy_by_layer,
        layer_tokens=[tokens for _ in range(L_all)],
        layer_cluster_labels=layer_cluster_labels,
        layer_uncertainty=layer_uncertainties,
        layer_graphs=layer_graphs,
        connect_token_trajectories=True,
        title="Qwen: 3D Cluster Formation (UMAP2D + Layer as Z, LoT metrics on hover)"
    )
    if save_artifacts:
        os.makedirs(cfg.out_dir, exist_ok=True)
        html_path = os.path.join(cfg.out_dir, cfg.plotly_html)
        fig.write_html(html_path)
        # Save the percolation series
        with open(os.path.join(cfg.out_dir, "percolation_stats.json"), "w") as f:
            json.dump(percolation, f, indent=2)
        np.save(os.path.join(cfg.out_dir, "anchors.npy"), anchors)
        # print("[Percolation] Saved per-layer stats → percolation_stats.json")
        # print(f"[Plotly] 3D HTML saved → {html_path}")
    return fig, {"percolation": percolation, "tokens": tokens}
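# Usage sketch (illustrative only): running the pipeline outside Streamlit, e.g. from a notebook.
# The cached helpers (get_model_and_tok / get_pool_artifacts) can still be called as plain
# functions in that setting.
#
#   cfg = Config(model_name="Qwen/Qwen1.5-0.5B", max_length=32)
#   model, tok, device, dtype = get_model_and_tok(cfg.model_name)
#   fig, outputs = run_pipeline(cfg, model, tok, device, "The cat sat on the mat.", save_artifacts=True)
#   fig.show()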
@st.cache_resource
def get_model_and_tok(model_name: str):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    model, tok = load_qwen(model_name, device, dtype)
    return model, tok, device, dtype
def main():
    st.set_page_config(page_title="Layer Explorer", layout="wide")
    st.title("3D Token Embedding Explorer (Live Hidden States)")

    with st.sidebar:
        st.header("Model / Input")
        model_name = st.selectbox("Model", ["Qwen/Qwen1.5-0.5B", "Qwen/Qwen1.5-1.8B", "Qwen/Qwen1.5-4B"], index=1)
        max_length = st.slider("Max tokens", 16, 256, 64, step=16)
        st.header("Graph")
        graph_mode = st.selectbox("Graph mode", ["knn", "threshold"], index=0)
        knn_k = st.slider("k (kNN)", 2, 50, 8) if graph_mode == "knn" else 8
        sim_threshold = st.slider("Similarity threshold", 0.0, 0.99, 0.70, step=0.01) if graph_mode == "threshold" else 0.70
        use_cosine = st.checkbox("Use cosine similarity", value=True)
        st.header("Anchors / LoT")
        anchor_k = st.slider("anchor_k", 4, 64, 16, step=1)
        anchor_temp = st.slider("anchor_temp", 0.05, 2.0, 0.7, step=0.05)
        st.header("UMAP")
        umap_n_neighbors = st.slider("n_neighbors", 5, 100, 30, step=1)
        umap_min_dist = st.slider("min_dist", 0.0, 0.99, 0.05, step=0.01)
        umap_metric = st.selectbox("metric", ["cosine", "euclidean"], index=0)
        st.header("Performance")
        fit_pool_per_layer = st.slider("fit_pool_per_layer", 64, 2048, 512, step=64)
        st.header("Outputs")
        save_artifacts = st.checkbox("Save artifacts to disk (HTML/CSV/NPZ)", value=False)

    prompt_col, run_col = st.columns([4, 1])
    with prompt_col:
        main_text = st.text_area(
            "Text to visualize (hidden states computed on this text)",
            value="Explain in one sentence what a transformer attention layer does.",
            height=140
        )
    with run_col:
        st.write("")
        st.write("")
        run_btn = st.button("Run", type="primary")

    cfg = Config(
        model_name=model_name,
        max_length=max_length,
        corpus=None,   # keep using DEFAULT_CORPUS for pooling unless you expose it in the UI
        graph_mode=graph_mode,
        knn_k=knn_k,
        sim_threshold=sim_threshold,
        use_cosine=use_cosine,
        anchor_k=anchor_k,
        anchor_temp=anchor_temp,
        umap_n_neighbors=umap_n_neighbors,
        umap_min_dist=umap_min_dist,
        umap_metric=umap_metric,
        fit_pool_per_layer=fit_pool_per_layer,
        # keep the other defaults
    )

    if run_btn:
        if not main_text.strip():
            st.error("Please enter some text.")
            return
        with st.spinner("Loading model (cached after first run)..."):
            model, tok, device, dtype = get_model_and_tok(cfg.model_name)
        # Optionally pass compute_volume to the pipeline (recommended),
        # e.g. run_pipeline(..., compute_volume=compute_volume)
        with st.spinner("Running pipeline (hidden states → features → UMAP → Plotly)..."):
            fig, outputs = run_pipeline(
                cfg=cfg,
                model=model,
                tok=tok,
                device=device,
                main_text=main_text,
                save_artifacts=save_artifacts,
            )
        st.plotly_chart(fig, use_container_width=True)
        st.success(f"Loaded {cfg.model_name} on {device} ({dtype})")
        with st.expander("Percolation summary"):
            percolation = outputs.get("percolation", [])
            for l, stt in enumerate(percolation):
                st.write(f"L={l:02d} | φ={stt['phi']:.3f} | #C={stt['num_clusters']} | χ={stt['chi']:.2f}")
        with st.expander("Debug: config"):
            st.json(asdict(cfg))


# ====== 9. Main ==================================================================================
if __name__ == "__main__":
    torch.set_grad_enabled(False)
    main()