Spaces:
Sleeping
Redesign Compare tab as Lookup; add UniRef metadata; update header
Browse files
- Rename "Compare" to "Lookup" and strip it down: input sequence +
reference-set radio (UniRef50 / CRISPR) + top-k + button -> single
hits table. Heatmaps, stats bars, activation histogram, and .npy
downloads are gone; users who need those should use the Distance
tab or the API.
- Add fetch_uniref_metadata() that calls the UniProt uniref endpoint
to return "cluster name — organism — N members" for each hit. In-
memory cache; 5 s timeout; benign fallback string on failure. The
hits table now has a "description" column so users can see what
each UniRef50 cluster actually is, not just the accession.
- Update the app header to describe functional distance prediction
rather than "compare protein embeddings".
- Add requests to requirements.txt (explicit dep, for the metadata
fetcher).
- CRISPR reference option is exposed in the radio but the handler
currently returns a "not yet available" message since the Step 6
SLURM build is queued.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- app.py +137 -45
- requirements.txt +1 -0
|
@@ -216,8 +216,53 @@ def compute_distance(seq_a, seq_b, aspect):
|
|
| 216 |
l2 = float(torch.norm(a - b, p=2, dim=-1).item())
|
| 217 |
return {"l2": l2, "cos_sim": cos_sim, "cos_dist": 1.0 - cos_sim}
|
| 218 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
def search_faiss(esm2_embedding, k=10):
|
| 220 |
-
"""Search FAISS index for top-k UniRef50 neighbors. Returns a DataFrame
|
|
|
|
| 221 |
import faiss
|
| 222 |
|
| 223 |
index, ids, _ = get_faiss()
|
|
@@ -225,17 +270,23 @@ def search_faiss(esm2_embedding, k=10):
|
|
| 225 |
faiss.normalize_L2(q)
|
| 226 |
scores, idxs = index.search(q, k)
|
| 227 |
|
| 228 |
-
|
|
|
|
| 229 |
for rank, (score, i) in enumerate(zip(scores[0], idxs[0]), 1):
|
| 230 |
if i < 0:
|
| 231 |
continue
|
| 232 |
uid = ids[i].decode() if isinstance(ids[i], (bytes, np.bytes_)) else str(ids[i])
|
| 233 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
rows.append({
|
| 235 |
"rank": rank,
|
| 236 |
"uniref50_id": uid,
|
| 237 |
-
"cosine": round(
|
| 238 |
-
"
|
|
|
|
| 239 |
})
|
| 240 |
return pd.DataFrame(rows)
|
| 241 |
|
|
@@ -526,7 +577,12 @@ with gr.Blocks(
|
|
| 526 |
title="Functional Distance",
|
| 527 |
css=".gradio-container { max-width: 100% !important; }"
|
| 528 |
) as demo:
|
| 529 |
-
gr.Markdown(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 530 |
|
| 531 |
with gr.Tab("Distance"):
|
| 532 |
gr.Markdown(
|
|
@@ -661,56 +717,92 @@ with gr.Blocks(
|
|
| 661 |
label="Example pairs (CRISPR / anti-CRISPR) — click to load",
|
| 662 |
)
|
| 663 |
|
| 664 |
-
with gr.Tab("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 665 |
with gr.Row():
|
| 666 |
-
with gr.Column(scale=1, min_width=
|
| 667 |
-
|
| 668 |
label="Protein Sequence",
|
| 669 |
placeholder="Paste protein sequence (amino acids)...",
|
| 670 |
-
lines=
|
| 671 |
value=EXAMPLE_PROTEIN,
|
| 672 |
-
info="FASTA or raw sequence
|
| 673 |
)
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
|
|
|
|
| 677 |
)
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
|
|
|
|
|
|
|
|
|
|
| 684 |
)
|
| 685 |
-
btn = gr.Button("compare embeddings", variant="primary")
|
| 686 |
-
output = gr.Markdown()
|
| 687 |
-
with gr.Row():
|
| 688 |
-
esm2_download = gr.File(label="ESM2 .npy")
|
| 689 |
-
twin_download = gr.File(label="Twin .npy")
|
| 690 |
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 694 |
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
|
| 703 |
-
|
| 704 |
-
|
|
|
|
| 705 |
)
|
| 706 |
|
| 707 |
-
btn.click(
|
| 708 |
-
process,
|
| 709 |
-
inputs=[seq_input, top_k_slider, twin_aspect_radio],
|
| 710 |
-
outputs=[output, esm2_download, twin_download, esm2_heatmap, twin_heatmap, comparison_plot, distribution_plot, hits_table],
|
| 711 |
-
api_name="compare"
|
| 712 |
-
)
|
| 713 |
-
|
| 714 |
with gr.Tab("API"):
|
| 715 |
gr.Markdown("""
|
| 716 |
### API
|
|
|
|
| 216 |
l2 = float(torch.norm(a - b, p=2, dim=-1).item())
|
| 217 |
return {"l2": l2, "cos_sim": cos_sim, "cos_dist": 1.0 - cos_sim}
|
| 218 |
|
| 219 |
+
# Process-lifetime cache of successfully fetched descriptions, keyed by UniRef id.
_uniref_meta_cache = {}


def _describe_uniref_entry(data):
    """Format one UniRef JSON record as "name — organism — N members".

    Components that are absent or empty are left out; if none are present the
    placeholder "(no metadata)" is returned.
    """
    name = (data.get("name") or "").replace("Cluster: ", "")
    rep = data.get("representativeMember", {}) or {}
    org = (rep.get("organismName") or "").strip()
    count = data.get("memberCount") or ""
    parts = [p for p in (name, org, (f"{count} members" if count else "")) if p]
    return " — ".join(parts) or "(no metadata)"


def fetch_uniref_metadata(uniref_ids):
    """Fetch a short human-readable description for each UniRef id.

    Each uncached id is looked up with one GET against the UniProt uniref
    endpoint (5 s timeout per request). Successful lookups are cached in
    memory for the life of the process.

    Args:
        uniref_ids: iterable of UniRef accession strings; duplicates are
            fetched only once.

    Returns:
        dict mapping id -> description. On success the description is
        "cluster name — organism — N members" (missing parts omitted, or
        "(no metadata)" if all are missing). On failure it is a placeholder,
        "(HTTP <status>)" or "(fetch error: <reason>)". This function is
        best-effort: request failures never raise to the caller.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order.
    unique_ids = list(dict.fromkeys(uniref_ids))
    out = {
        uid: _uniref_meta_cache[uid]
        for uid in unique_ids
        if uid in _uniref_meta_cache
    }
    missing = [uid for uid in unique_ids if uid not in out]
    if not missing:
        # Nothing to fetch: skip the import and any network work.
        return out

    import requests

    # One session for the whole batch so the connection to the host is reused.
    with requests.Session() as session:
        for uid in missing:
            try:
                r = session.get(
                    f"https://rest.uniprot.org/uniref/{uid}.json",
                    params={"fields": "name,organism,count"},
                    timeout=5,
                )
                if r.status_code == 200:
                    desc = _describe_uniref_entry(r.json())
                    # Only successful lookups are cached; failures stay
                    # uncached so a later call can retry them.
                    _uniref_meta_cache[uid] = desc
                else:
                    desc = f"(HTTP {r.status_code})"
            except Exception as e:
                # Deliberately broad: metadata is decorative and must not
                # break the search. str(e) can be empty, so fall back to the
                # exception class name instead of indexing an empty list.
                reason = (str(e).splitlines() or [type(e).__name__])[0][:60]
                desc = f"(fetch error: {reason})"
            out[uid] = desc
    return out
|
| 261 |
+
|
| 262 |
+
|
| 263 |
def search_faiss(esm2_embedding, k=10):
|
| 264 |
+
"""Search FAISS index for top-k UniRef50 neighbors. Returns a DataFrame
|
| 265 |
+
enriched with UniProt protein-name and organism metadata."""
|
| 266 |
import faiss
|
| 267 |
|
| 268 |
index, ids, _ = get_faiss()
|
|
|
|
| 270 |
faiss.normalize_L2(q)
|
| 271 |
scores, idxs = index.search(q, k)
|
| 272 |
|
| 273 |
+
# Decode ids first
|
| 274 |
+
hit_rows = []
|
| 275 |
for rank, (score, i) in enumerate(zip(scores[0], idxs[0]), 1):
|
| 276 |
if i < 0:
|
| 277 |
continue
|
| 278 |
uid = ids[i].decode() if isinstance(ids[i], (bytes, np.bytes_)) else str(ids[i])
|
| 279 |
+
hit_rows.append((rank, uid, float(score)))
|
| 280 |
+
|
| 281 |
+
meta = fetch_uniref_metadata([uid for _, uid, _ in hit_rows])
|
| 282 |
+
rows = []
|
| 283 |
+
for rank, uid, score in hit_rows:
|
| 284 |
rows.append({
|
| 285 |
"rank": rank,
|
| 286 |
"uniref50_id": uid,
|
| 287 |
+
"cosine": round(score, 4),
|
| 288 |
+
"description": meta.get(uid, ""),
|
| 289 |
+
"uniprot": f"https://www.uniprot.org/uniref/{uid}",
|
| 290 |
})
|
| 291 |
return pd.DataFrame(rows)
|
| 292 |
|
|
|
|
| 577 |
title="Functional Distance",
|
| 578 |
css=".gradio-container { max-width: 100% !important; }"
|
| 579 |
) as demo:
|
| 580 |
+
gr.Markdown(
|
| 581 |
+
"# functional-distance\n"
|
| 582 |
+
"Functional distance prediction for proteins — pairwise distance (Twin, "
|
| 583 |
+
"GO-contrastive fine-tune of ESM2) and nearest-neighbor lookup (ESM2 against "
|
| 584 |
+
"UniRef50, or Twin against a curated CRISPR reference set)."
|
| 585 |
+
)
|
| 586 |
|
| 587 |
with gr.Tab("Distance"):
|
| 588 |
gr.Markdown(
|
|
|
|
| 717 |
label="Example pairs (CRISPR / anti-CRISPR) — click to load",
|
| 718 |
)
|
| 719 |
|
| 720 |
+
with gr.Tab("Lookup"):
|
| 721 |
+
gr.Markdown(
|
| 722 |
+
"### Nearest-neighbor lookup\n"
|
| 723 |
+
"Paste a protein sequence. The selected reference set returns the most similar "
|
| 724 |
+
"proteins by cosine similarity.\n"
|
| 725 |
+
"- **UniRef50** — 4.3 M GO-annotated UniRef50 proteins, **ESM2** embeddings (FAISS).\n"
|
| 726 |
+
"- **CRISPR reference** — curated Cas + anti-CRISPR set, **Twin-BP** embeddings. "
|
| 727 |
+
"*(Will be available once Step 6 build completes.)*"
|
| 728 |
+
)
|
| 729 |
with gr.Row():
|
| 730 |
+
with gr.Column(scale=1, min_width=320):
|
| 731 |
+
lookup_seq_input = gr.Textbox(
|
| 732 |
label="Protein Sequence",
|
| 733 |
placeholder="Paste protein sequence (amino acids)...",
|
| 734 |
+
lines=6,
|
| 735 |
value=EXAMPLE_PROTEIN,
|
| 736 |
+
info="FASTA or raw sequence; > 1022 aa is truncated"
|
| 737 |
)
|
| 738 |
+
lookup_index_radio = gr.Radio(
|
| 739 |
+
choices=["UniRef50 (ESM2)", "CRISPR reference (Twin-BP)"],
|
| 740 |
+
value="UniRef50 (ESM2)",
|
| 741 |
+
label="Reference set",
|
| 742 |
)
|
| 743 |
+
lookup_top_k = gr.Slider(minimum=1, maximum=50, value=10, step=1, label="top-k")
|
| 744 |
+
lookup_btn = gr.Button("search", variant="primary")
|
| 745 |
+
with gr.Column(scale=2, min_width=400):
|
| 746 |
+
lookup_info = gr.Markdown()
|
| 747 |
+
lookup_hits = gr.Dataframe(
|
| 748 |
+
headers=["rank", "id", "cosine", "description", "link"],
|
| 749 |
+
datatype=["number", "str", "number", "str", "str"],
|
| 750 |
+
label="Nearest neighbors",
|
| 751 |
+
wrap=True,
|
| 752 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 753 |
|
| 754 |
+
_lookup_empty = pd.DataFrame(columns=["rank", "id", "cosine", "description", "link"])
|
| 755 |
+
|
| 756 |
+
def _lookup_handler(sequence, index_choice, top_k):
    """Run a nearest-neighbor lookup for the Lookup tab.

    Args:
        sequence: raw or FASTA-formatted protein sequence; None is treated
            as an empty string.
        index_choice: label of the selected reference set. Values starting
            with "UniRef50" trigger the FAISS search; anything else
            (including None) returns the "CRISPR reference not yet
            available" notice.
        top_k: number of neighbors requested (converted with int()).

    Returns:
        (markdown_message, hits_dataframe). The dataframe has columns
        rank / id / cosine / description / link and is the shared empty
        frame on validation failure, lookup failure, or when no hits exist.
    """
    sequence = strip_fasta_header((sequence or "").strip())
    valid, err = validate_protein(sequence)
    if not valid:
        return f"**Error**: {err}", _lookup_empty

    if len(sequence) > ESM2_MAX_LEN:
        trunc_note = (f"> ⚠️ Query truncated from {len(sequence)} to {ESM2_MAX_LEN} aa "
                      f"(ESM2 limit).\n\n")
    else:
        trunc_note = ""

    if not (index_choice or "").startswith("UniRef50"):
        # CRISPR reference: the index is not packaged with the app yet.
        return (
            f"{trunc_note}**CRISPR reference not yet available.**\n\n"
            "Step 6 (`scripts/analyses/crispr/build_reference.py`) is queued / running on "
            "SLURM. Once it finishes, the curated Cas + Acr embeddings will be packaged "
            "into this Space and this option will return top-k hits with family / "
            "organism / link columns.",
            _lookup_empty,
        )

    try:
        k = int(top_k)
        esm2_emb = embed_esm2(sequence)
        h = search_faiss(esm2_emb, k=k)
        info = (f"{trunc_note}"
                f"**UniRef50 (ESM2)** — top-{k} nearest GO-annotated clusters by "
                f"cosine similarity on L2-normalized 1280-dim ESM2 embeddings.")
        if h.empty:
            # A frame built from zero rows has no columns, so indexing it
            # below would raise KeyError and be misreported as a failure.
            return info, _lookup_empty
        df = pd.DataFrame({
            "rank": h["rank"],
            "id": h["uniref50_id"],
            "cosine": h["cosine"],
            "description": h["description"],
            "link": h["uniprot"],
        })
        return info, df
    except Exception as e:
        # str(e) is empty for message-less exceptions (e.g. a bare
        # AssertionError); fall back to the class name so the handler
        # itself cannot raise IndexError.
        reason = (str(e).splitlines() or [type(e).__name__])[0]
        return (f"{trunc_note}**FAISS lookup failed**: {reason}\n\n"
                f"The UniRef50 FAISS index may not be available yet.",
                _lookup_empty)
|
| 792 |
|
| 793 |
+
lookup_btn.click(
|
| 794 |
+
lambda i: (f"⏳ Searching {i}…", _lookup_empty),
|
| 795 |
+
inputs=[lookup_index_radio],
|
| 796 |
+
outputs=[lookup_info, lookup_hits],
|
| 797 |
+
show_progress="hidden",
|
| 798 |
+
).then(
|
| 799 |
+
_lookup_handler,
|
| 800 |
+
inputs=[lookup_seq_input, lookup_index_radio, lookup_top_k],
|
| 801 |
+
outputs=[lookup_info, lookup_hits],
|
| 802 |
+
api_name="lookup",
|
| 803 |
+
show_progress="minimal",
|
| 804 |
)
|
| 805 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 806 |
with gr.Tab("API"):
|
| 807 |
gr.Markdown("""
|
| 808 |
### API
|
|
@@ -8,3 +8,4 @@ plotly>=5.18.0
|
|
| 8 |
faiss-cpu>=1.7.4
|
| 9 |
huggingface_hub>=0.23.0
|
| 10 |
pandas>=2.0.0
|
|
|
|
|
|
| 8 |
faiss-cpu>=1.7.4
|
| 9 |
huggingface_hub>=0.23.0
|
| 10 |
pandas>=2.0.0
|
| 11 |
+
requests>=2.31.0
|