genomenet Claude Opus 4.7 (1M context) committed on
Commit
63e92bf
·
1 Parent(s): b9b5e8a

Redesign Compare tab as Lookup; add UniRef metadata; update header

Browse files

- Rename "Compare" to "Lookup" and strip it down: input sequence +
reference-set radio (UniRef50 / CRISPR) + top-k + button -> single
hits table. Heatmaps, stats bars, activation histogram, and .npy
downloads are gone; users who need those should use the Distance
tab or the API.
- Add fetch_uniref_metadata() that calls the UniProt uniref endpoint
to return "cluster name — organism — N members" for each hit. In-
memory cache; 5 s timeout; benign fallback string on failure. The
hits table now has a "description" column so users can see what
each UniRef50 cluster actually is, not just the accession.
- Update the app header to describe functional distance prediction
rather than "compare protein embeddings".
- Add requests to requirements.txt (explicit dep, for the metadata
fetcher).
- CRISPR reference option is exposed in the radio but the handler
currently returns a "not yet available" message since the Step 6
SLURM build is queued.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Files changed (2) hide show
  1. app.py +137 -45
  2. requirements.txt +1 -0
app.py CHANGED
@@ -216,8 +216,53 @@ def compute_distance(seq_a, seq_b, aspect):
216
  l2 = float(torch.norm(a - b, p=2, dim=-1).item())
217
  return {"l2": l2, "cos_sim": cos_sim, "cos_dist": 1.0 - cos_sim}
218
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  def search_faiss(esm2_embedding, k=10):
220
- """Search FAISS index for top-k UniRef50 neighbors. Returns a DataFrame."""
 
221
  import faiss
222
 
223
  index, ids, _ = get_faiss()
@@ -225,17 +270,23 @@ def search_faiss(esm2_embedding, k=10):
225
  faiss.normalize_L2(q)
226
  scores, idxs = index.search(q, k)
227
 
228
- rows = []
 
229
  for rank, (score, i) in enumerate(zip(scores[0], idxs[0]), 1):
230
  if i < 0:
231
  continue
232
  uid = ids[i].decode() if isinstance(ids[i], (bytes, np.bytes_)) else str(ids[i])
233
- link = f"https://www.uniprot.org/uniref/{uid}"
 
 
 
 
234
  rows.append({
235
  "rank": rank,
236
  "uniref50_id": uid,
237
- "cosine": round(float(score), 4),
238
- "uniprot": link,
 
239
  })
240
  return pd.DataFrame(rows)
241
 
@@ -526,7 +577,12 @@ with gr.Blocks(
526
  title="Functional Distance",
527
  css=".gradio-container { max-width: 100% !important; }"
528
  ) as demo:
529
- gr.Markdown("# functional-distance\nCompare protein embeddings: ESM2 (pretrained) vs Twin (GO fine-tuned)")
 
 
 
 
 
530
 
531
  with gr.Tab("Distance"):
532
  gr.Markdown(
@@ -661,56 +717,92 @@ with gr.Blocks(
661
  label="Example pairs (CRISPR / anti-CRISPR) — click to load",
662
  )
663
 
664
- with gr.Tab("Compare"):
 
 
 
 
 
 
 
 
665
  with gr.Row():
666
- with gr.Column(scale=1, min_width=300):
667
- seq_input = gr.Textbox(
668
  label="Protein Sequence",
669
  placeholder="Paste protein sequence (amino acids)...",
670
- lines=5,
671
  value=EXAMPLE_PROTEIN,
672
- info="FASTA or raw sequence, max 1022 aa"
673
  )
674
- top_k_slider = gr.Slider(
675
- minimum=1, maximum=50, value=10, step=1,
676
- label="Nearest neighbors (top-k)"
 
677
  )
678
- twin_aspect_radio = gr.Radio(
679
- choices=["BP", "CC", "MF"],
680
- value=TWIN_DEFAULT_ASPECT,
681
- label="Twin GO aspect",
682
- info="Biological Process (BP), Cellular Component (CC), or Molecular Function (MF). "
683
- "First switch loads the aspect's checkpoint (~15 s)."
 
 
 
684
  )
685
- btn = gr.Button("compare embeddings", variant="primary")
686
- output = gr.Markdown()
687
- with gr.Row():
688
- esm2_download = gr.File(label="ESM2 .npy")
689
- twin_download = gr.File(label="Twin .npy")
690
 
691
- with gr.Column(scale=2, min_width=400):
692
- comparison_plot = gr.Plot(label="Statistics Comparison")
693
- distribution_plot = gr.Plot(label="Activation Distributions")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
694
 
695
- with gr.Row():
696
- esm2_heatmap = gr.Plot(label="ESM2 Embedding (1280-dim)")
697
- twin_heatmap = gr.Plot(label="Twin Embedding (1024-dim)")
698
-
699
- gr.Markdown("### Nearest UniRef50 neighbors (ESM2 embedding, cosine)")
700
- hits_table = gr.Dataframe(
701
- headers=["rank", "uniref50_id", "cosine", "uniprot"],
702
- datatype=["number", "str", "number", "str"],
703
- label="Top hits in GO-annotated UniRef50 (FAISS)",
704
- wrap=True,
 
705
  )
706
 
707
- btn.click(
708
- process,
709
- inputs=[seq_input, top_k_slider, twin_aspect_radio],
710
- outputs=[output, esm2_download, twin_download, esm2_heatmap, twin_heatmap, comparison_plot, distribution_plot, hits_table],
711
- api_name="compare"
712
- )
713
-
714
  with gr.Tab("API"):
715
  gr.Markdown("""
716
  ### API
 
216
  l2 = float(torch.norm(a - b, p=2, dim=-1).item())
217
  return {"l2": l2, "cos_sim": cos_sim, "cos_dist": 1.0 - cos_sim}
218
 
219
# Process-wide memo of successfully fetched descriptions, keyed by UniRef id.
# Failure strings are deliberately NOT stored here so a transient network or
# HTTP error is retried on the next lookup instead of sticking forever.
_uniref_meta_cache = {}


def fetch_uniref_metadata(uniref_ids):
    """Return a short human-readable description for each UniRef id.

    Every id that is not already cached is requested from the UniProt REST
    ``uniref`` endpoint and summarised as
    ``"cluster name — organism — N members"``; parts that are absent from the
    response are left out.

    Args:
        uniref_ids: Iterable of UniRef accession strings. Duplicates are
            fetched only once.

    Returns:
        dict mapping each requested id to its description. Ids that cannot be
        described map to a parenthesised diagnostic instead:
        ``"(no metadata)"`` (HTTP 200 but no usable fields),
        ``"(HTTP <status>)"`` (non-200 response) or
        ``"(fetch error: <message>)"`` (request or parsing raised).
    """
    out = {}
    missing = []
    # dict.fromkeys de-duplicates while preserving the caller's order.
    for uid in dict.fromkeys(uniref_ids):
        if uid in _uniref_meta_cache:
            out[uid] = _uniref_meta_cache[uid]
        else:
            missing.append(uid)
    if not missing:
        # Fully served from cache: no need to import requests or touch the network.
        return out

    import requests

    # One session for the whole batch so the HTTPS connection is reused
    # across the sequential per-id requests.
    with requests.Session() as session:
        for uid in missing:
            fetched = False
            try:
                r = session.get(
                    f"https://rest.uniprot.org/uniref/{uid}.json",
                    params={"fields": "name,organism,count"},
                    timeout=5,
                )
                if r.status_code == 200:
                    data = r.json()
                    name = (data.get("name") or "").replace("Cluster: ", "")
                    rep = data.get("representativeMember", {}) or {}
                    org = (rep.get("organismName") or "").strip()
                    count = data.get("memberCount") or ""
                    parts = [p for p in (name, org, (f"{count} members" if count else "")) if p]
                    desc = " — ".join(parts) or "(no metadata)"
                    fetched = True
                else:
                    desc = f"(HTTP {r.status_code})"
            except Exception as e:
                # Best-effort enrichment: never let a metadata failure break the
                # lookup. str(e) can be empty, in which case splitlines() returns
                # [] — fall back to the exception type name instead of indexing.
                first_line = (str(e).splitlines() or [type(e).__name__])[0]
                desc = f"(fetch error: {first_line[:60]})"
            if fetched:
                _uniref_meta_cache[uid] = desc
            out[uid] = desc
    return out
261
+
262
+
263
  def search_faiss(esm2_embedding, k=10):
264
+ """Search FAISS index for top-k UniRef50 neighbors. Returns a DataFrame
265
+ enriched with UniProt protein-name and organism metadata."""
266
  import faiss
267
 
268
  index, ids, _ = get_faiss()
 
270
  faiss.normalize_L2(q)
271
  scores, idxs = index.search(q, k)
272
 
273
+ # Decode ids first
274
+ hit_rows = []
275
  for rank, (score, i) in enumerate(zip(scores[0], idxs[0]), 1):
276
  if i < 0:
277
  continue
278
  uid = ids[i].decode() if isinstance(ids[i], (bytes, np.bytes_)) else str(ids[i])
279
+ hit_rows.append((rank, uid, float(score)))
280
+
281
+ meta = fetch_uniref_metadata([uid for _, uid, _ in hit_rows])
282
+ rows = []
283
+ for rank, uid, score in hit_rows:
284
  rows.append({
285
  "rank": rank,
286
  "uniref50_id": uid,
287
+ "cosine": round(score, 4),
288
+ "description": meta.get(uid, ""),
289
+ "uniprot": f"https://www.uniprot.org/uniref/{uid}",
290
  })
291
  return pd.DataFrame(rows)
292
 
 
577
  title="Functional Distance",
578
  css=".gradio-container { max-width: 100% !important; }"
579
  ) as demo:
580
+ gr.Markdown(
581
+ "# functional-distance\n"
582
+ "Functional distance prediction for proteins — pairwise distance (Twin, "
583
+ "GO-contrastive fine-tune of ESM2) and nearest-neighbor lookup (ESM2 against "
584
+ "UniRef50, or Twin against a curated CRISPR reference set)."
585
+ )
586
 
587
  with gr.Tab("Distance"):
588
  gr.Markdown(
 
717
  label="Example pairs (CRISPR / anti-CRISPR) — click to load",
718
  )
719
 
720
+ with gr.Tab("Lookup"):
721
+ gr.Markdown(
722
+ "### Nearest-neighbor lookup\n"
723
+ "Paste a protein sequence. The selected reference set returns the most similar "
724
+ "proteins by cosine similarity.\n"
725
+ "- **UniRef50** — 4.3 M GO-annotated UniRef50 proteins, **ESM2** embeddings (FAISS).\n"
726
+ "- **CRISPR reference** — curated Cas + anti-CRISPR set, **Twin-BP** embeddings. "
727
+ "*(Will be available once Step 6 build completes.)*"
728
+ )
729
  with gr.Row():
730
+ with gr.Column(scale=1, min_width=320):
731
+ lookup_seq_input = gr.Textbox(
732
  label="Protein Sequence",
733
  placeholder="Paste protein sequence (amino acids)...",
734
+ lines=6,
735
  value=EXAMPLE_PROTEIN,
736
+ info="FASTA or raw sequence; > 1022 aa is truncated"
737
  )
738
+ lookup_index_radio = gr.Radio(
739
+ choices=["UniRef50 (ESM2)", "CRISPR reference (Twin-BP)"],
740
+ value="UniRef50 (ESM2)",
741
+ label="Reference set",
742
  )
743
+ lookup_top_k = gr.Slider(minimum=1, maximum=50, value=10, step=1, label="top-k")
744
+ lookup_btn = gr.Button("search", variant="primary")
745
+ with gr.Column(scale=2, min_width=400):
746
+ lookup_info = gr.Markdown()
747
+ lookup_hits = gr.Dataframe(
748
+ headers=["rank", "id", "cosine", "description", "link"],
749
+ datatype=["number", "str", "number", "str", "str"],
750
+ label="Nearest neighbors",
751
+ wrap=True,
752
  )
 
 
 
 
 
753
 
754
+ _lookup_empty = pd.DataFrame(columns=["rank", "id", "cosine", "description", "link"])
755
+
756
def _lookup_handler(sequence, index_choice, top_k):
    """Handle a "search" click on the Lookup tab.

    Args:
        sequence: Raw text from the sequence box (FASTA or bare residues).
            ``None`` is treated as an empty string.
        index_choice: Label of the selected reference set; anything starting
            with ``"UniRef50"`` runs the FAISS search, every other value is
            treated as the CRISPR reference option.
        top_k: Number of neighbors requested; coerced with ``int()``.

    Returns:
        Tuple ``(markdown, dataframe)`` for ``lookup_info`` and ``lookup_hits``.
        On validation failure, lookup failure, or the CRISPR option, the
        dataframe is the shared empty ``_lookup_empty`` frame.
    """
    # A null payload (possible through the API endpoint) would otherwise raise
    # AttributeError on .strip(); pass an empty string on to validation instead.
    sequence = strip_fasta_header((sequence or "").strip())
    valid, err = validate_protein(sequence)
    if not valid:
        return f"**Error**: {err}", _lookup_empty

    # This handler only reports the truncation; it does not shorten the
    # sequence itself (presumably embed_esm2 does — TODO confirm).
    if len(sequence) > ESM2_MAX_LEN:
        trunc_note = (
            f"> ⚠️ Query truncated from {len(sequence)} to {ESM2_MAX_LEN} aa "
            f"(ESM2 limit).\n\n"
        )
    else:
        trunc_note = ""

    if not index_choice.startswith("UniRef50"):
        # CRISPR reference: placeholder until the reference set is shipped.
        return (
            f"{trunc_note}**CRISPR reference not yet available.**\n\n"
            "Step 6 (`scripts/analyses/crispr/build_reference.py`) is queued / running on "
            "SLURM. Once it finishes, the curated Cas + Acr embeddings will be packaged "
            "into this Space and this option will return top-k hits with family / "
            "organism / link columns.",
            _lookup_empty,
        )

    k = int(top_k)
    try:
        esm2_emb = embed_esm2(sequence)
        h = search_faiss(esm2_emb, k=k)
        # Re-label search_faiss columns to the generic names used by lookup_hits.
        df = pd.DataFrame({
            "rank": h["rank"],
            "id": h["uniref50_id"],
            "cosine": h["cosine"],
            "description": h["description"],
            "link": h["uniprot"],
        })
        info = (f"{trunc_note}"
                f"**UniRef50 (ESM2)** — top-{k} nearest GO-annotated clusters by "
                f"cosine similarity on L2-normalized 1280-dim ESM2 embeddings.")
        return info, df
    except Exception as e:
        # UI boundary: report the failure in the info panel rather than raising.
        # str(e) may be empty, making splitlines() return []; fall back to the
        # exception type name so this handler cannot raise IndexError itself.
        reason = (str(e).splitlines() or [type(e).__name__])[0]
        return (f"{trunc_note}**FAISS lookup failed**: {reason}\n\n"
                f"The UniRef50 FAISS index may not be available yet.",
                _lookup_empty)
792
 
793
+ lookup_btn.click(
794
+ lambda i: (f" Searching {i}…", _lookup_empty),
795
+ inputs=[lookup_index_radio],
796
+ outputs=[lookup_info, lookup_hits],
797
+ show_progress="hidden",
798
+ ).then(
799
+ _lookup_handler,
800
+ inputs=[lookup_seq_input, lookup_index_radio, lookup_top_k],
801
+ outputs=[lookup_info, lookup_hits],
802
+ api_name="lookup",
803
+ show_progress="minimal",
804
  )
805
 
 
 
 
 
 
 
 
806
  with gr.Tab("API"):
807
  gr.Markdown("""
808
  ### API
requirements.txt CHANGED
@@ -8,3 +8,4 @@ plotly>=5.18.0
8
  faiss-cpu>=1.7.4
9
  huggingface_hub>=0.23.0
10
  pandas>=2.0.0
 
 
8
  faiss-cpu>=1.7.4
9
  huggingface_hub>=0.23.0
10
  pandas>=2.0.0
11
+ requests>=2.31.0