Joshua Gray committed on
Commit
de372eb
·
1 Parent(s): c304b46

UMAP/Pool Performance boost

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +98 -22
src/streamlit_app.py CHANGED
@@ -30,6 +30,8 @@ from sklearn.metrics import pairwise_distances
30
  # Plotly for interactive 3D
31
  import plotly.graph_objects as go
32
 
 
 
33
  # Optional libs (use if present)
34
  try:
35
  import hdbscan # Robust density-based clustering
@@ -384,6 +386,63 @@ def fit_umap_2d(pool: np.ndarray,
384
  reducer.fit(pool)
385
  return reducer
386
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
387
 
388
  def fit_umap_3d(all_states: np.ndarray,
389
  n_neighbors: int = 30,
@@ -568,27 +627,44 @@ def run_pipeline(cfg: Config, model, tok, device, main_text: str, save_artifacts
568
  main_bundle = extract_hidden_states(model, tok, main_text, cfg.max_length, device)
569
  layers_np: List[np.ndarray] = main_bundle.hidden_layers # list of (T,D), length L_all = num_layers+1
570
  tokens = main_bundle.tokens # list of length T
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
571
  L_all = len(layers_np)
572
  #print(f"[Hidden] Layers (incl. embedding): {L_all}, Tokens: {len(tokens)}")
573
 
574
- # 8.3 Build a pool of states (across a few texts & layers) to fit anchors + UMAP
575
- pool_states = []
576
- # Sample across first few texts to improve diversity (lightweight)
577
- for t in texts[: min(5, len(texts))]:
578
- b = extract_hidden_states(model, tok, t, cfg.max_length, device)
579
- # Take a subset from each layer to limit pool size
580
- for H in b.hidden_layers:
581
- T = len(H)
582
- take = min(cfg.fit_pool_per_layer, T)
583
- idx = np.random.choice(T, size=take, replace=False)
584
- pool_states.append(H[idx])
585
- pool_states = np.vstack(pool_states) if len(pool_states) else layers_np[-1]
586
- #print(f"[Pool] Pooled states for anchors/UMAP: {pool_states.shape}")
587
-
588
- # 8.4 Fit global anchors (LoT-style features)
589
- anchors = fit_global_anchors(pool_states, cfg.anchor_k)
590
- # Save anchors for reproducibility
591
-
 
592
 
593
  # 8.5 Build per-layer features for main text (LoT-style distances & uncertainty)
594
  layer_features = [] # list of (T,K)
@@ -640,10 +716,10 @@ def run_pipeline(cfg: Config, model, tok, device, main_text: str, save_artifacts
640
 
641
 
642
  # 8.10 Common 2D manifold via UMAP (fit-once on the pool), then transform each layer
643
- reducer2d = fit_umap_2d(pool_states,
644
  n_neighbors=cfg.umap_n_neighbors,
645
  min_dist=cfg.umap_min_dist,
646
- metric=cfg.umap_metric)
647
  xy_by_layer = [reducer2d.transform(layers_np[l]) for l in range(L_all)]
648
 
649
  # OPTIONAL: orthogonal alignment across layers (helps if UMAP.transform still drifts)
@@ -682,8 +758,8 @@ def get_model_and_tok(model_name: str):
682
  return model, tok, device, dtype
683
 
684
  def main():
685
- st.set_page_config(page_title="Qwen Layer Explorer", layout="wide")
686
- st.title("Qwen: 3D Token Embedding Explorer (Live Hidden States)")
687
 
688
  with st.sidebar:
689
  st.header("Model / Input")
 
30
  # Plotly for interactive 3D
31
  import plotly.graph_objects as go
32
 
33
+ import hashlib
34
+
35
  # Optional libs (use if present)
36
  try:
37
  import hdbscan # Robust density-based clustering
 
386
  reducer.fit(pool)
387
  return reducer
388
 
389
+ def _corpus_fingerprint(texts, max_items=5, max_chars=4000) -> str:
390
+ """Stable key so cache invalidates if DEFAULT_CORPUS changes."""
391
+ joined = "\n".join(texts[:max_items])
392
+ joined = joined[:max_chars]
393
+ return hashlib.sha256(joined.encode("utf-8")).hexdigest()
394
+
395
@st.cache_data(show_spinner=False)
def get_pool_artifacts(
    model_name: str,
    max_length: int,
    anchor_k: int,
    anchor_temp: float,  # unused in the body; it only takes part in the cache key
    umap_n_neighbors: int,
    umap_min_dist: float,
    umap_metric: str,
    fit_pool_per_layer: int,
    corpus_hash: str,
):
    """Build the pooled hidden states once, then fit anchors and a 2D UMAP reducer.

    The result is memoized by ``st.cache_data``; every argument is part of the
    cache key.

    Args:
        model_name: Passed to ``get_model_and_tok`` to obtain model/tokenizer/device.
        max_length: Passed to ``extract_hidden_states`` for each pooled text.
        anchor_k: Passed to ``fit_global_anchors``.
        anchor_temp: Not used for fitting; kept so cache keys line up with the
            caller's configuration.
        umap_n_neighbors: ``n_neighbors`` for ``fit_umap_2d``.
        umap_min_dist: ``min_dist`` for ``fit_umap_2d``.
        umap_metric: ``metric`` for ``fit_umap_2d``.
        fit_pool_per_layer: Maximum number of token states sampled from each
            layer of each pooled text.
        corpus_hash: Fingerprint of the corpus; part of the cache key and the
            seed for the sampling RNG.

    Returns:
        ``(anchors, reducer2d)`` — the value returned by ``fit_global_anchors``
        (documented upstream as a ``(K, D)`` array) and the fitted reducer
        returned by ``fit_umap_2d``. The reducer must be pickleable because
        ``st.cache_data`` serializes return values.

    Raises:
        RuntimeError: If no hidden states could be pooled.
    """
    # NOTE(review): the pool below always comes from DEFAULT_CORPUS, while the
    # caller derives corpus_hash from its own `texts`. If those ever differ,
    # the cache key changes but the pooled content does not — confirm intent.
    model, tok, device, _dtype = get_model_and_tok(model_name)

    # Seed a local generator from the corpus fingerprint so the sampled pool
    # (and hence the cached anchors/reducer) is reproducible for a given key
    # and does not depend on, or disturb, the global NumPy RNG state.
    seed = int.from_bytes(corpus_hash.encode("utf-8")[:8], "big")
    rng = np.random.default_rng(seed)

    pool_chunks = []
    for text in DEFAULT_CORPUS[:5]:
        bundle = extract_hidden_states(model, tok, text, max_length, device)
        for layer_states in bundle.hidden_layers:
            n_tokens = len(layer_states)
            take = min(fit_pool_per_layer, n_tokens)
            if take <= 0:
                # Empty layer output or a non-positive budget: nothing to sample.
                continue
            picked = rng.choice(n_tokens, size=take, replace=False)
            pool_chunks.append(layer_states[picked])

    if not pool_chunks:
        # Should rarely happen; fail loudly instead of fitting on nothing.
        raise RuntimeError("Pool construction produced no states.")

    pool_states = np.vstack(pool_chunks)

    anchors = fit_global_anchors(pool_states, anchor_k)

    reducer2d = fit_umap_2d(
        pool_states,
        n_neighbors=umap_n_neighbors,
        min_dist=umap_min_dist,
        metric=umap_metric,
    )

    return anchors, reducer2d
445
+
446
 
447
  def fit_umap_3d(all_states: np.ndarray,
448
  n_neighbors: int = 30,
 
627
  main_bundle = extract_hidden_states(model, tok, main_text, cfg.max_length, device)
628
  layers_np: List[np.ndarray] = main_bundle.hidden_layers # list of (T,D), length L_all = num_layers+1
629
  tokens = main_bundle.tokens # list of length T
630
+
631
+ # Cached pool artifacts (anchors + fitted UMAP reducer)
632
+ corpus_hash = _corpus_fingerprint(texts) # texts is cfg.corpus or DEFAULT_CORPUS
633
+
634
+ anchors, reducer2d = get_pool_artifacts(
635
+ model_name=cfg.model_name,
636
+ max_length=cfg.max_length,
637
+ anchor_k=cfg.anchor_k,
638
+ anchor_temp=cfg.anchor_temp,
639
+ umap_n_neighbors=cfg.umap_n_neighbors,
640
+ umap_min_dist=cfg.umap_min_dist,
641
+ umap_metric=cfg.umap_metric,
642
+ fit_pool_per_layer=cfg.fit_pool_per_layer,
643
+ corpus_hash=corpus_hash,
644
+ )
645
+
646
  L_all = len(layers_np)
647
  #print(f"[Hidden] Layers (incl. embedding): {L_all}, Tokens: {len(tokens)}")
648
 
649
+ """
650
+ # 8.3 Build a pool of states (across a few texts & layers) to fit anchors + UMAP
651
+ pool_states = []
652
+ # Sample across first few texts to improve diversity (lightweight)
653
+ for t in texts[: min(5, len(texts))]:
654
+ b = extract_hidden_states(model, tok, t, cfg.max_length, device)
655
+ # Take a subset from each layer to limit pool size
656
+ for H in b.hidden_layers:
657
+ T = len(H)
658
+ take = min(cfg.fit_pool_per_layer, T)
659
+ idx = np.random.choice(T, size=take, replace=False)
660
+ pool_states.append(H[idx])
661
+ pool_states = np.vstack(pool_states) if len(pool_states) else layers_np[-1]
662
+ #print(f"[Pool] Pooled states for anchors/UMAP: {pool_states.shape}")
663
+
664
+ # 8.4 Fit global anchors (LoT-style features)
665
+ anchors = fit_global_anchors(pool_states, cfg.anchor_k)
666
+ # Save anchors for reproducibility
667
+ """
668
 
669
  # 8.5 Build per-layer features for main text (LoT-style distances & uncertainty)
670
  layer_features = [] # list of (T,K)
 
716
 
717
 
718
  # 8.10 Common 2D manifold via UMAP (fit-once on the pool), then transform each layer
719
+ """reducer2d = fit_umap_2d(pool_states,
720
  n_neighbors=cfg.umap_n_neighbors,
721
  min_dist=cfg.umap_min_dist,
722
+ metric=cfg.umap_metric)"""
723
  xy_by_layer = [reducer2d.transform(layers_np[l]) for l in range(L_all)]
724
 
725
  # OPTIONAL: orthogonal alignment across layers (helps if UMAP.transform still drifts)
 
758
  return model, tok, device, dtype
759
 
760
  def main():
761
+ st.set_page_config(page_title="Layer Explorer", layout="wide")
762
+ st.title("3D Token Embedding Explorer (Live Hidden States)")
763
 
764
  with st.sidebar:
765
  st.header("Model / Input")