the-puzzler committed on
Commit
082f6c4
·
1 Parent(s): 6ea791c

Color UMAPs by final logit

Browse files
Files changed (1) hide show
  1. app.py +6 -7
app.py CHANGED
@@ -516,7 +516,7 @@ def _infer_logits_and_final_embeddings(input_embeddings: np.ndarray, models: Loa
516
  return logits.squeeze(0).detach().cpu().numpy(), final_hidden.squeeze(0).detach().cpu().numpy()
517
 
518
 
519
- def _plot_umap(vectors: np.ndarray, labels: List[str], title: str):
520
  if len(vectors) < 2:
521
  raise gr.Error("UMAP needs at least 2 sequences.")
522
 
@@ -530,10 +530,9 @@ def _plot_umap(vectors: np.ndarray, labels: List[str], title: str):
530
  init="random" if n_points <= 3 else "spectral",
531
  )
532
  coords = reducer.fit_transform(vectors)
533
- norms = np.linalg.norm(vectors, axis=1)
534
  x_values = [float(value) for value in coords[:, 0]]
535
  y_values = [float(value) for value in coords[:, 1]]
536
- color_values = [float(value) for value in norms]
537
 
538
  fig = go.Figure(
539
  data=[
@@ -543,7 +542,7 @@ def _plot_umap(vectors: np.ndarray, labels: List[str], title: str):
543
  mode="markers",
544
  text=labels,
545
  customdata=np.array(color_values).reshape(-1, 1),
546
- hovertemplate="<b>%{text}</b><br>UMAP 1=%{x:.3f}<br>UMAP 2=%{y:.3f}<br>norm=%{customdata[0]:.3f}<extra></extra>",
547
  marker={
548
  "size": 10,
549
  "color": color_values,
@@ -551,7 +550,7 @@ def _plot_umap(vectors: np.ndarray, labels: List[str], title: str):
551
  "line": {"width": 0.6, "color": "#1d2a1f"},
552
  "opacity": 0.92,
553
  "showscale": True,
554
- "colorbar": {"title": "vector norm"},
555
  },
556
  )
557
  ]
@@ -620,8 +619,8 @@ def _analyze_records(records: List[dict], source_title: str, extra_summary: str
620
  input_embeddings = _embed_sequences(seqs, models)
621
  logits, final_embeddings = _infer_logits_and_final_embeddings(input_embeddings, models)
622
 
623
- input_umap = _plot_umap(input_embeddings, labels, "UMAP of Input DNA Embeddings")
624
- final_umap = _plot_umap(final_embeddings, labels, "UMAP of Final Transformer Embeddings")
625
  logits_hist = _plot_logits(logits, labels)
626
 
627
  rows = []
 
516
  return logits.squeeze(0).detach().cpu().numpy(), final_hidden.squeeze(0).detach().cpu().numpy()
517
 
518
 
519
+ def _plot_umap(vectors: np.ndarray, labels: List[str], logits: np.ndarray, title: str):
520
  if len(vectors) < 2:
521
  raise gr.Error("UMAP needs at least 2 sequences.")
522
 
 
530
  init="random" if n_points <= 3 else "spectral",
531
  )
532
  coords = reducer.fit_transform(vectors)
 
533
  x_values = [float(value) for value in coords[:, 0]]
534
  y_values = [float(value) for value in coords[:, 1]]
535
+ color_values = [float(value) for value in logits]
536
 
537
  fig = go.Figure(
538
  data=[
 
542
  mode="markers",
543
  text=labels,
544
  customdata=np.array(color_values).reshape(-1, 1),
545
+ hovertemplate="<b>%{text}</b><br>UMAP 1=%{x:.3f}<br>UMAP 2=%{y:.3f}<br>logit=%{customdata[0]:.4f}<extra></extra>",
546
  marker={
547
  "size": 10,
548
  "color": color_values,
 
550
  "line": {"width": 0.6, "color": "#1d2a1f"},
551
  "opacity": 0.92,
552
  "showscale": True,
553
+ "colorbar": {"title": "final logit"},
554
  },
555
  )
556
  ]
 
619
  input_embeddings = _embed_sequences(seqs, models)
620
  logits, final_embeddings = _infer_logits_and_final_embeddings(input_embeddings, models)
621
 
622
+ input_umap = _plot_umap(input_embeddings, labels, logits, "UMAP of Input DNA Embeddings")
623
+ final_umap = _plot_umap(final_embeddings, labels, logits, "UMAP of Final Transformer Embeddings")
624
  logits_hist = _plot_logits(logits, labels)
625
 
626
  rows = []