# Repository page residue (not part of the program), kept here as comments:
#   genomenet's picture
#   Clarify UniRef50 lookup histogram
#   600e4c0
"""
Functional Distance - Compare ESM2 vs Twin protein embeddings
HuggingFace Spaces App
"""
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import html
import tempfile
import gradio as gr
import numpy as np
import pandas as pd
import torch
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# Model config
ESM2_MODEL = "esm2_t33_650M_UR50D"  # 650M params, 1280-dim
ESM2_DIM = 1280
TWIN_DIM = 1024  # 2 * projection_dim (512), two-tower concat

# FAISS index config (UniRef50 GO-annotated protein clusters)
# Repo ids are read from the environment. ESM2_FAISS_REPO_ID falls back to the
# FAISS_REPO_ID variable before using the built-in default.
ESM2_FAISS_REPO_ID = os.environ.get("ESM2_FAISS_REPO_ID", os.environ.get("FAISS_REPO_ID", "genomenet/esm2-uniref50-faiss"))
TWIN_BP_FAISS_REPO_ID = os.environ.get("TWIN_BP_FAISS_REPO_ID", "genomenet/twin-uniref50-faiss")
FAISS_REPO_ID = ESM2_FAISS_REPO_ID  # backwards-compatible name used in older logs/env configs
FAISS_IDS_FILE = "ids.npy"
FAISS_METADATA_FILE = "metadata.json"
# nprobe applied to the IVF part of a loaded index (see get_faiss); overridable via env.
FAISS_NPROBE = int(os.environ.get("FAISS_NPROBE", "32"))
# Per-index settings, keyed by the index_name that get_faiss() looks up.
FAISS_CONFIGS = {
    "esm2": {
        "repo_id": ESM2_FAISS_REPO_ID,
        "index_file": "esm2_uniref50.index",
        "label": "ESM2 baseline",
        "dim_info": "1280-dim ESM2",
    },
    "twin-bp": {
        "repo_id": TWIN_BP_FAISS_REPO_ID,
        "index_file": "twin_uniref50.index",
        "label": "genomenet-twin (BP)",
        "dim_info": "1024-dim Twin-BP",
    },
}

# Twin model config (3 aspect-specific checkpoints in one HF model repo)
TWIN_REPO_ID = os.environ.get("TWIN_REPO_ID", "genomenet/twin-point-1024")
# Checkpoint file name per GO aspect; the keys are the only aspects get_twin() accepts.
TWIN_CHECKPOINT_FILES = {
    "BP": "bp_cp_best.pt",  # Biological Process
    "CC": "cc_cp_best.pt",  # Cellular Component
    "MF": "mf_cp_best.pt",  # Molecular Function
}
TWIN_DEFAULT_ASPECT = os.environ.get("TWIN_DEFAULT_ASPECT", "BP")
TWIN_ESM_BACKBONE = os.environ.get("TWIN_ESM_BACKBONE", "facebook/esm2_t33_650M_UR50D")

# Model cache
# Filled lazily by get_esm2(); None means "not loaded yet".
_esm2_model = None
_esm2_alphabet = None
# Only one aspect cached at a time (each Twin is ~2.7 GB on CPU, can't fit all 3 on a cpu-basic Space)
_twin_cache = {"aspect": None, "model": None, "seq_len": None}
# Likewise a single FAISS index is held at a time; "name" is the FAISS_CONFIGS key it came from.
_faiss_cache = {"name": None, "index": None, "ids": None, "metadata": None}
def get_esm2():
    """Return the cached ``(model, alphabet)`` pair for ESM2, loading it on first use.

    The first call imports ``esm``, loads the checkpoint named by ``ESM2_MODEL``,
    switches it to eval mode and moves it to CUDA when a GPU is available.
    Later calls return the module-level cache untouched.
    """
    global _esm2_model, _esm2_alphabet
    if _esm2_model is not None:
        return _esm2_model, _esm2_alphabet

    # Imported lazily so the module can be imported without the esm package loaded.
    import esm

    print(f"Loading ESM2 model: {ESM2_MODEL}...")
    _esm2_model, _esm2_alphabet = esm.pretrained.load_model_and_alphabet(ESM2_MODEL)
    _esm2_model = _esm2_model.eval()
    if torch.cuda.is_available():
        _esm2_model = _esm2_model.cuda()
    print("ESM2 loaded.")
    return _esm2_model, _esm2_alphabet
def get_twin(aspect=None):
    """Return ``(model, seq_len)`` for the Twin checkpoint of one GO aspect.

    ``aspect`` is case-insensitive and defaults to ``TWIN_DEFAULT_ASPECT``; it
    must be a key of ``TWIN_CHECKPOINT_FILES`` or ``ValueError`` is raised.

    A single aspect is cached at a time. Asking for a different aspect drops
    the cached model first (each is ~2.7 GB; three won't fit on cpu-basic),
    then downloads the checkpoint from ``TWIN_REPO_ID`` and loads it.
    """
    global _twin_cache
    aspect = (aspect or TWIN_DEFAULT_ASPECT).upper()
    if aspect not in TWIN_CHECKPOINT_FILES:
        raise ValueError(f"Unknown aspect {aspect!r}; expected one of {list(TWIN_CHECKPOINT_FILES)}")

    already_loaded = _twin_cache["model"] is not None and _twin_cache["aspect"] == aspect
    if already_loaded:
        return _twin_cache["model"], _twin_cache["seq_len"]

    import gc
    import torch
    from huggingface_hub import hf_hub_download
    from twin_model import load_twin_model

    # Release the previous aspect before loading the next one. No local
    # reference to the old model is kept, so gc.collect() can reclaim it.
    if _twin_cache["model"] is not None:
        print(f"Evicting Twin/{_twin_cache['aspect']} to load Twin/{aspect}...")
        _twin_cache.update(model=None, seq_len=None, aspect=None)
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    filename = TWIN_CHECKPOINT_FILES[aspect]
    print(f"Downloading Twin/{aspect} checkpoint ({filename}) from {TWIN_REPO_ID}...")
    ckpt_path = hf_hub_download(repo_id=TWIN_REPO_ID, filename=filename, repo_type="model")

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model, seq_len, emb_dim = load_twin_model(ckpt_path, device, TWIN_ESM_BACKBONE)
    if emb_dim != TWIN_DIM:
        # Only a warning: the model is still cached and returned.
        print(f"  WARN Twin/{aspect} output dim is {emb_dim}, expected {TWIN_DIM}")

    _twin_cache.update(aspect=aspect, model=model, seq_len=seq_len)
    return model, seq_len
def get_faiss(index_name="esm2"):
    """Return ``(index, ids, metadata)`` for one of the configured FAISS indexes.

    ``index_name`` must be a key of ``FAISS_CONFIGS`` (``ValueError`` otherwise).
    Only one index is cached; requesting another one drops the cached index
    before the new files are downloaded from the configured dataset repo.
    ``metadata`` is ``None`` when the metadata file is absent from the snapshot.
    """
    global _faiss_cache
    cache_hit = _faiss_cache["index"] is not None and _faiss_cache["name"] == index_name
    if cache_hit:
        return _faiss_cache["index"], _faiss_cache["ids"], _faiss_cache["metadata"]
    if index_name not in FAISS_CONFIGS:
        raise ValueError(f"Unknown FAISS index {index_name!r}; expected one of {list(FAISS_CONFIGS)}")

    import gc
    import faiss
    import json
    from huggingface_hub import snapshot_download

    # Drop the previously loaded index first so both are never resident together.
    if _faiss_cache["index"] is not None:
        print(f"Evicting FAISS/{_faiss_cache['name']} to load FAISS/{index_name}...")
        _faiss_cache = {"name": None, "index": None, "ids": None, "metadata": None}
        gc.collect()

    cfg = FAISS_CONFIGS[index_name]
    index_file = cfg["index_file"]
    print(f"Downloading FAISS/{index_name} index from {cfg['repo_id']}...")
    local_dir = snapshot_download(
        repo_id=cfg["repo_id"],
        repo_type="dataset",
        allow_patterns=[index_file, FAISS_IDS_FILE, FAISS_METADATA_FILE],
    )

    print(f"Loading FAISS index from {local_dir}...")
    index = faiss.read_index(os.path.join(local_dir, index_file))
    # Best effort: if no IVF component can be extracted, nprobe is simply not applied.
    try:
        ivf = faiss.extract_index_ivf(index)
        ivf.nprobe = FAISS_NPROBE
        nprobe = ivf.nprobe
    except Exception:
        nprobe = "n/a"

    ids = np.load(os.path.join(local_dir, FAISS_IDS_FILE))

    meta_path = os.path.join(local_dir, FAISS_METADATA_FILE)
    metadata = None
    if os.path.exists(meta_path):
        with open(meta_path) as f:
            metadata = json.load(f)

    _faiss_cache = {"name": index_name, "index": index, "ids": ids, "metadata": metadata}
    print(f"FAISS/{index_name} ready: {index.ntotal:,} vectors, nprobe={nprobe}")
    return index, ids, metadata
def get_device():
    """Return ``"cuda"`` when torch reports an available GPU, otherwise ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
# Valid amino acids
AMINO_ACIDS = set("ACDEFGHIKLMNPQRSTVWY")
AMINO_ACIDS_EXTENDED = AMINO_ACIDS | set("XBZJOU")  # Include ambiguous
ESM2_MAX_LEN = 1022  # ESM2 position embedding limit; longer sequences are truncated


def validate_protein(sequence):
    """Check that *sequence* looks like a protein sequence.

    Returns ``(True, "")`` when valid, otherwise ``(False, reason)``.

    The check is case-insensitive and ignores all whitespace. A sequence is
    rejected when it is empty, contains characters outside
    ``AMINO_ACIDS_EXTENDED``, or has fewer than 10 residues. Long sequences
    are NOT rejected: both embedders truncate internally and the caller
    surfaces a note in the output instead.
    """
    if not sequence or not sequence.strip():
        return False, "Sequence is empty"

    # str.split() with no argument splits on every whitespace run, so tabs and
    # the "\r" of Windows line endings are dropped along with spaces/newlines.
    # (Removing only " " and "\n" reported "\r" as an invalid character.)
    cleaned = "".join(sequence.split()).upper()

    invalid = set(cleaned) - AMINO_ACIDS_EXTENDED
    if invalid:
        return False, f"Invalid characters: {invalid}"
    if len(cleaned) < 10:
        return False, f"Sequence too short: {len(cleaned)} < 10 aa"
    return True, ""
def strip_fasta_header(text):
    """Return the residues of FASTA-formatted *text* as one upper-case string.

    Every line whose first non-blank character is ``>`` is treated as a header
    and dropped; the remaining lines are stripped and concatenated. Plain
    sequences without a header pass through unchanged (apart from
    upper-casing and whitespace removal at line ends). Empty or ``None`` input
    yields ``""``.
    """
    if not text:
        return ""
    residues = []
    for raw_line in text.splitlines():
        # Strip BEFORE testing for ">" so an indented header line is still
        # recognised instead of being glued into the sequence.
        line = raw_line.strip()
        if not line or line.startswith(">"):
            continue
        residues.append(line)
    return "".join(residues).upper()
@torch.no_grad()
def embed_esm2(sequence):
    """Return the mean-pooled last-layer ESM2 representation of *sequence*.

    The input is cut to its first ``ESM2_MAX_LEN`` residues. The result is a
    NumPy array obtained by averaging the per-residue representations, with
    the BOS/EOS positions left out.
    """
    model, alphabet = get_esm2()
    to_tokens = alphabet.get_batch_converter()
    device = get_device()

    # ESM2 position embeddings cap at 1022; slicing is a no-op for shorter input.
    sequence = sequence[:ESM2_MAX_LEN]

    _, _, batch_tokens = to_tokens([("protein", sequence)])
    batch_tokens = batch_tokens.to(device)

    last_layer = model.num_layers
    results = model(batch_tokens, repr_layers=[last_layer], return_contacts=False)
    per_token = results["representations"][last_layer]

    # Token 0 is BOS, so the residues occupy positions 1 .. len(sequence).
    n_residues = len(sequence)
    residue_reprs = per_token[0, 1:n_residues + 1, :]
    return residue_reprs.mean(dim=0).cpu().numpy()
@torch.no_grad()
def embed_twin(sequence, aspect=None):
    """Return the Twin embedding of *sequence* for a GO aspect (BP/CC/MF).

    The model for *aspect* comes from ``get_twin``; the sequence is cleaned and
    tokenised with the ``twin_model`` helpers and run as a batch of one. The
    first output row is returned as a float32 NumPy array.
    """
    from twin_model import ensure_aa_sequence, preprocess_sequences_batch

    model, seq_len = get_twin(aspect)
    # Run on whatever device the loaded model's parameters live on.
    device = next(model.parameters()).device
    batch = preprocess_sequences_batch([ensure_aa_sequence(sequence)], seq_len, device)
    output = model(batch)  # (1, TWIN_DIM)
    first_row = output[0]
    return first_row.float().cpu().numpy()
@torch.no_grad()
def compute_distance(seq_a, seq_b, aspect):
    """Compute the Twin model's pairwise distance between two sequences.

    Both sequences are embedded in a single batch with the Twin model for
    *aspect*, and each embedding is L2-normalised (training convention).

    Returns a dict with:
      ``l2``       - Euclidean distance between the two unit vectors,
      ``cos_sim``  - their dot product (cosine similarity),
      ``cos_dist`` - ``1.0 - cos_sim``.
    """
    import torch.nn.functional as F
    from twin_model import ensure_aa_sequence, preprocess_sequences_batch

    model, seq_len = get_twin(aspect)
    device = next(model.parameters()).device

    pair = [ensure_aa_sequence(seq_a), ensure_aa_sequence(seq_b)]
    input_ids = preprocess_sequences_batch(pair, seq_len, device)
    emb = model(input_ids)  # (2, TWIN_DIM)

    # Keep a leading batch axis on each row so the reductions below use dim=-1.
    unit_a = F.normalize(emb[0:1], p=2, dim=-1)
    unit_b = F.normalize(emb[1:2], p=2, dim=-1)

    cosine = float((unit_a * unit_b).sum().item())
    euclid = float(torch.norm(unit_a - unit_b, p=2, dim=-1).item())

    result = {}
    result["l2"] = euclid
    result["cos_sim"] = cosine
    result["cos_dist"] = 1.0 - cosine
    return result
# Directory holding the packaged CRISPR reference set, resolved next to this file.
_CRISPR_REF_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "crispr_reference")
# Embedding file per search method.
_CRISPR_FILES = {
    "Twin-BP": "embeddings.npy",            # 1024-dim, from Step 6 Twin-BP build
    "Twin-MF": "embeddings_mf.npy",         # 1024-dim, best-fit aspect for CRISPR
    "ESM2 baseline": "embeddings_esm2.npy",  # 1280-dim, ESM2 mean-pool
}
# One lazily filled slot per method; "emb" stays None until the first load.
_crispr_ref = {name: {"emb": None, "meta": None} for name in _CRISPR_FILES}


def get_crispr_reference(method="Twin-BP"):
    """Return ``(embeddings, metadata)`` of the curated CRISPR reference set.

    Embeddings are loaded once per *method*, converted to float32 and
    L2-normalised row-wise; metadata is the ``metadata.csv`` table read with
    pandas. Raises ``FileNotFoundError`` when either file is missing and
    ``KeyError`` for a method that is not in ``_CRISPR_FILES``.
    """
    slot = _crispr_ref[method]
    if slot["emb"] is not None:
        return slot["emb"], slot["meta"]

    emb_path = os.path.join(_CRISPR_REF_DIR, _CRISPR_FILES[method])
    meta_path = os.path.join(_CRISPR_REF_DIR, "metadata.csv")
    if not os.path.exists(emb_path) or not os.path.exists(meta_path):
        raise FileNotFoundError(f"CRISPR reference ({method}) not packaged at {_CRISPR_REF_DIR}")

    emb = np.load(emb_path).astype(np.float32)
    # Normalise once here so that cosine similarity is a plain dot product
    # later; the epsilon guards against all-zero rows.
    row_norms = np.linalg.norm(emb, axis=1, keepdims=True) + 1e-9
    slot["emb"] = emb / row_norms
    slot["meta"] = pd.read_csv(meta_path)
    print(f"CRISPR reference ({method}) loaded: {emb.shape[0]} proteins, dim={emb.shape[1]}")
    return slot["emb"], slot["meta"]
def _cell_text(value):
if value is None:
return ""
try:
if pd.isna(value):
return ""
except Exception:
pass
return str(value)
def _crispr_display_id(row):
"""Readable ID for mixed Cas/Acr rows, whose internal ids are prefixed."""
source_acc = _cell_text(row.get("source_accession", ""))
original = _cell_text(row.get("original_acc", ""))
acc = _cell_text(row.get("acc", ""))
if source_acc:
return source_acc
if original:
return original.replace("cas__", "").replace("acr__", "")
return acc.replace("cas__", "").replace("acr__", "")
def _crispr_protein_link(row):
source_acc = _cell_text(row.get("source_accession", ""))
original = _cell_text(row.get("original_acc", ""))
acc = _cell_text(row.get("acc", ""))
clean = source_acc or original or acc
clean = clean.replace("cas__", "").replace("acr__", "")
if "__" in clean:
clean = clean.split("__")[-1]
if not clean or clean.startswith("local__"):
return ""
if clean.startswith(("WP_", "YP_", "NP_", "XP_", "AKG", "ERJ")) or clean.endswith(".1"):
return f"https://www.ncbi.nlm.nih.gov/protein/{clean}"
return f"https://www.uniprot.org/uniprotkb/{clean}"
def search_crispr(query_emb, k=25, negate=False, method="Twin-BP"):
    """Rank the curated CRISPR reference proteins by cosine similarity to a query.

    *query_emb* is flattened and L2-normalised, then compared against the
    (already normalised) reference embeddings of *method*. With ``negate`` the
    *k* least similar entries are returned instead of the *k* most similar.

    Returns a DataFrame with columns ``rank``, ``uniref50_id``, ``ref_id``,
    ``cosine``, ``description`` and ``uniprot``.
    """
    ref_emb, meta = get_crispr_reference(method)

    query = np.asarray(query_emb, dtype=np.float32).reshape(-1)
    query = query / (np.linalg.norm(query) + 1e-9)
    sims = ref_emb @ query  # (N,)

    # argsort is ascending: sort the similarities themselves for "least
    # similar first", their negation for "most similar first".
    sort_key = sims if negate else -sims
    order = np.argsort(sort_key)[:k]

    def _as_row(rank, idx):
        entry = meta.iloc[idx]
        family = _cell_text(entry.get("family", ""))
        typ = _cell_text(entry.get("type", ""))
        organism = _cell_text(entry.get("organism", ""))
        return {
            "rank": rank,
            "uniref50_id": _crispr_display_id(entry),
            "ref_id": entry["acc"],
            "cosine": round(float(sims[idx]), 4),
            "description": f"{family} · {typ} · {organism}",
            "uniprot": _crispr_protein_link(entry),
        }

    return pd.DataFrame([_as_row(rank, idx) for rank, idx in enumerate(order, 1)])
# File names (looked up inside _CRISPR_REF_DIR) of the pre-computed 2D UMAP
# coordinates, one per embedding method.
_CRISPR_UMAP_FILES = {
    "ESM2 baseline": "umap_coords_esm2.npy",
    "Twin-MF": "umap_coords_mf.npy",
    "Twin-BP": "umap_coords_bp.npy",
}
# Same, for the 3D UMAP coordinates.
_CRISPR_UMAP3D_FILES = {
    "ESM2 baseline": "umap_coords_3d_esm2.npy",
    "Twin-MF": "umap_coords_3d_mf.npy",
    "Twin-BP": "umap_coords_3d_bp.npy",
}
# method -> loaded coordinate array; an entry is added only after a successful load.
_crispr_umap_cache = {}
_crispr_umap3d_cache = {}
def get_crispr_umap(method):
    """Return the pre-computed 2D UMAP coordinates for *method*, or ``None``.

    ``None`` is returned when the method has no configured file or the file is
    not present on disk. Successful loads are cached as float32 arrays.
    """
    try:
        return _crispr_umap_cache[method]
    except KeyError:
        pass

    fname = _CRISPR_UMAP_FILES.get(method)
    if fname is None:
        return None
    coords_path = os.path.join(_CRISPR_REF_DIR, fname)
    if not os.path.exists(coords_path):
        # Missing files are not cached, so a later call checks the disk again.
        return None

    coords = np.load(coords_path).astype(np.float32)
    _crispr_umap_cache[method] = coords
    return coords
def get_crispr_umap_3d(method):
    """Return the pre-computed 3D UMAP coordinates for *method*, or ``None``.

    ``None`` is returned when the method has no configured file or the file is
    not present on disk. Successful loads are cached as float32 arrays.
    """
    try:
        return _crispr_umap3d_cache[method]
    except KeyError:
        pass

    fname = _CRISPR_UMAP3D_FILES.get(method)
    if fname is None:
        return None
    coords_path = os.path.join(_CRISPR_REF_DIR, fname)
    if not os.path.exists(coords_path):
        # Missing files are not cached, so a later call checks the disk again.
        return None

    coords = np.load(coords_path).astype(np.float32)
    _crispr_umap3d_cache[method] = coords
    return coords
# Unknown-protein data files, resolved relative to this file:
# data/unknown_proteins/results.json and data/unknown_proteins/sequences.fasta.
_UNKNOWN_RESULTS_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "data", "unknown_proteins", "results.json"
)
_UNKNOWN_FASTA_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "data", "unknown_proteins", "sequences.fasta"
)
# Lazily filled by get_unknown_results() / get_unknown_sequences();
# None means "not loaded yet".
_unknown_results = None
_unknown_sequences = None
def get_unknown_results():
    """Return the pre-computed characterization results for uncharacterized
    proteins (see scripts/analyses/crispr/precompute_unknowns.py).

    The JSON file at ``_UNKNOWN_RESULTS_PATH`` is parsed on first use and
    cached. When the file does not exist an empty placeholder record is
    cached and returned instead.
    """
    global _unknown_results
    if _unknown_results is None:
        import json

        if os.path.exists(_UNKNOWN_RESULTS_PATH):
            with open(_UNKNOWN_RESULTS_PATH) as f:
                _unknown_results = json.load(f)
        else:
            _unknown_results = {"n_queries": 0, "records": [], "generated_at": None}
    return _unknown_results
def get_unknown_sequences():
    """Return ``{fasta_id: sequence}`` for the unknown-protein query FASTA file.

    The id is the first whitespace-delimited token after ``>``; sequence lines
    of a record are concatenated. The file at ``_UNKNOWN_FASTA_PATH`` is parsed
    once and cached; a missing file yields (and caches) an empty dict.
    Records whose header carries no id, and sequence lines that appear before
    the first header, are skipped.
    """
    global _unknown_sequences
    if _unknown_sequences is not None:
        return _unknown_sequences

    seqs = {}
    if os.path.exists(_UNKNOWN_FASTA_PATH):
        current = None
        chunks = []
        with open(_UNKNOWN_FASTA_PATH) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                if line.startswith(">"):
                    # A new header closes the record collected so far.
                    if current is not None:
                        seqs[current] = "".join(chunks)
                    # A bare ">" has no tokens; indexing [0] on the empty
                    # split used to raise IndexError. Such a record is skipped.
                    header_fields = line[1:].split(maxsplit=1)
                    current = header_fields[0] if header_fields else None
                    chunks = []
                else:
                    chunks.append(line)
        # Flush the last record, which has no following header to close it.
        if current is not None:
            seqs[current] = "".join(chunks)

    _unknown_sequences = seqs
    return _unknown_sequences
def _wrap_sequence_for_html(seq, width=80):
seq = seq or ""
return "\n".join(seq[i:i + width] for i in range(0, len(seq), width))
def _crispr_verdict(top_k_rows, top_cos, threshold):
"""Return (label, confidence) given ranked hits."""
from collections import Counter
if not top_k_rows:
return "Unknown — no hits", "low"
if top_cos < threshold:
return (f"Unlikely CRISPR-related (top cosine {top_cos:+.3f} < threshold "
f"{threshold:.2f})"), "low"
fams = [r["family"] for r in top_k_rows]
c = Counter(fams)
top_fam, top_votes = c.most_common(1)[0]
ratio = top_votes / len(fams)
if ratio >= 0.6:
conf = "high" if (ratio >= 0.8 and top_cos > threshold + 0.15) else "medium"
return f"Predicted: **{top_fam}** ({top_votes}/{len(fams)} of top hits)", conf
groups = [r["group"] for r in top_k_rows]
gc = Counter(groups)
top_group, top_gvotes = gc.most_common(1)[0]
if top_gvotes / len(groups) >= 0.7:
tag = "Cas protein" if top_group == "cas" else "anti-CRISPR protein"
return (f"CRISPR-related: **{tag}**; specific family unclear "
f"(top mix: {dict(c.most_common(3))})"), "medium"
return (f"CRISPR-related, but signal is mixed "
f"(top family votes: {dict(c.most_common(3))})"), "medium"
# uid -> description, for lookups that gave a definitive answer.
_uniref_meta_cache = {}


def _is_transient_http_status(status_code):
    """True for statuses worth retrying later (rate limiting, server errors)."""
    return status_code == 429 or status_code >= 500


def _first_line(text, limit):
    """First line of *text* cut to *limit* characters; ``""`` for empty text."""
    lines = text.splitlines()
    return lines[0][:limit] if lines else ""


def _describe_uniparc(upi, requests):
    """Describe a UniParc entry from its cross-references -> (description, cacheable)."""
    r2 = requests.get(f"https://rest.uniprot.org/uniparc/{upi}.json", timeout=5)
    if r2.status_code != 200:
        return "(UniParc fetch failed)", not _is_transient_http_status(r2.status_code)
    xrefs = (r2.json() or {}).get("uniParcCrossReferences", []) or []
    name = org = ""
    # Take the first protein name and the first organism found across the xrefs.
    for cr in xrefs:
        if not name and cr.get("proteinName"):
            name = cr["proteinName"]
        if not org and (cr.get("organism") or {}).get("scientificName"):
            org = cr["organism"]["scientificName"]
        if name and org:
            break
    parts = [p for p in (name, org, "(UniParc)") if p]
    desc = " — ".join(parts) if (name or org) else "(UniParc entry — no annotation)"
    return desc, True


def _describe_uniref(uid, requests):
    """Look up one UniRef id -> (description, cacheable)."""
    # NOTE(review): the request asks for fields "name,organism,count" while the
    # parsing below reads "representativeMember" and "memberCount" — confirm
    # that the endpoint returns those keys for this field selection.
    r = requests.get(
        f"https://rest.uniprot.org/uniref/{uid}.json",
        params={"fields": "name,organism,count"},
        timeout=5,
    )
    if r.status_code == 200:
        data = r.json()
        name = (data.get("name") or "").replace("Cluster: ", "")
        rep = data.get("representativeMember", {}) or {}
        org = (rep.get("organismName") or "").strip()
        count = data.get("memberCount") or ""
        parts = [p for p in (name, org, (f"{count} members" if count else "")) if p]
        return " — ".join(parts) or "(no metadata)", True
    if r.status_code == 404 and uid.startswith("UniRef50_UPI"):
        # UniParc representative — no UniProtKB entry, but UniParc has xrefs with
        # protein name + organism from RefSeq/EMBL/etc.
        return _describe_uniparc(uid[len("UniRef50_"):], requests)
    return ("(not in UniRef — cluster may have been updated)",
            not _is_transient_http_status(r.status_code))


def fetch_uniref_metadata(uniref_ids):
    """Fetch cluster name + representative organism for a list of UniRef50 IDs.

    Uses the UniProt uniref endpoint (with a UniParc fallback for
    ``UniRef50_UPI...`` ids that return 404). Returns a dict
    ``{id -> description}`` where the description is the available parts
    (name, organism, member count) joined with " — ".

    Every requested id gets an entry: when nothing could be fetched the value
    is a parenthesised placeholder such as ``"(no metadata)"`` or
    ``"(fetch error: ...)"`` — never an exception.

    Definitive answers are cached in memory for the life of the process.
    Transient failures (network exceptions, HTTP 429/5xx) are reported but
    NOT cached, so a later call retries them instead of showing the error
    until the process restarts.
    """
    import requests

    out = {}
    missing = {}  # insertion-ordered set: each uncached id is fetched once
    for uid in uniref_ids:
        if uid in _uniref_meta_cache:
            out[uid] = _uniref_meta_cache[uid]
        else:
            missing[uid] = None

    for uid in missing:
        try:
            desc, cacheable = _describe_uniref(uid, requests)
        except Exception as e:
            # Best effort by design: a failed lookup must not break the search
            # that asked for metadata. An empty exception message is tolerated.
            desc, cacheable = f"(fetch error: {_first_line(str(e), 60)})", False
        if cacheable:
            _uniref_meta_cache[uid] = desc
        out[uid] = desc
    return out
def search_faiss(query_embedding, k=10, negate=False, fetch_metadata=True, index_name="esm2"):
    """Search FAISS index for top-k UniRef50 neighbors.

    The query is L2-normalised on a private copy, so the caller's array is
    left untouched.

    If `negate` is True, searches with -q: returns the most anti-correlated
    (least similar) proteins. Returned cosine values are still reported with
    respect to the original q (i.e., we flip the sign after search).

    When `fetch_metadata` is False, skips the per-hit UniProt API call (useful
    when you need a larger pool of cosine values for a distribution plot).

    Returns a DataFrame with columns ``rank``, ``uniref50_id``, ``cosine``,
    ``description`` and ``uniprot``. Slots for which the index reports a
    negative id are skipped, so fewer than *k* rows may come back.
    """
    import faiss

    index, ids, _ = get_faiss(index_name)

    # faiss.normalize_L2 rewrites its argument in place. np.asarray(...)[None, :]
    # is only a view when the input is already float32, so normalising it
    # silently rescaled the caller's embedding (and everything the caller did
    # with it afterwards). Work on an explicit copy instead.
    q = np.array(query_embedding, dtype=np.float32, copy=True).reshape(1, -1)
    faiss.normalize_L2(q)
    if negate:
        q = -q
    scores, idxs = index.search(q, k)

    hit_rows = []
    for rank, (score, i) in enumerate(zip(scores[0], idxs[0]), 1):
        if i < 0:
            continue
        uid = ids[i].decode() if isinstance(ids[i], (bytes, np.bytes_)) else str(ids[i])
        # Scores were computed against -q when negating; flip back to cos(q, hit).
        cos_real = -float(score) if negate else float(score)
        hit_rows.append((rank, uid, cos_real))

    meta = fetch_uniref_metadata([uid for _, uid, _ in hit_rows]) if fetch_metadata else {}

    rows = [
        {
            "rank": rank,
            "uniref50_id": uid,
            "cosine": round(cos, 4),
            "description": meta.get(uid, ""),
            "uniprot": f"https://www.uniprot.org/uniref/{uid}",
        }
        for rank, uid, cos in hit_rows
    ]
    return pd.DataFrame(rows)
def compute_stats(embedding):
    """Summarise an embedding vector as a dict of plain Python numbers.

    Keys: ``l2_norm``, ``mean``, ``std``, ``sparsity`` (fraction of entries
    with absolute value below 0.1), ``entropy`` and ``dim`` (number of entries).
    """
    values = np.array(embedding)

    # NOTE(review): "entropy" is -sum(d * log d) over the non-empty bins of a
    # 50-bin *density* histogram, without weighting by bin width — so it
    # depends on the value range and can be negative. Confirm this is the
    # intended quantity before comparing it across models.
    density, _edges = np.histogram(values, bins=50, density=True)
    occupied = density[density > 0]
    entropy = -np.sum(occupied * np.log(occupied + 1e-10))

    small_fraction = np.mean(np.abs(values) < 0.1)

    stats = {}
    stats['l2_norm'] = float(np.linalg.norm(values))
    stats['mean'] = float(np.mean(values))
    stats['std'] = float(np.std(values))
    stats['sparsity'] = float(small_fraction)
    stats['entropy'] = float(entropy)
    stats['dim'] = len(values)
    return stats
def create_embedding_heatmap(embedding, title, cols=32):
    """Render an embedding vector as a rows x *cols* heatmap and return the figure.

    The vector is laid out row by row; the unused tail of the last row is
    filled with NaN. The colour scale is symmetric around zero, spanning the
    largest absolute value in the embedding (at least 0.01).

    The returned Matplotlib figure is already detached from pyplot, so the
    caller may render or save it but does not need to close it.
    """
    n_dims = len(embedding)
    rows = int(np.ceil(n_dims / cols))
    padded = np.full(rows * cols, np.nan)
    padded[:n_dims] = embedding
    grid = padded.reshape(rows, cols)
    # Symmetric limits keep zero at the centre of the diverging colormap.
    vmax = max(abs(np.nanmin(embedding)), abs(np.nanmax(embedding)), 0.01)

    fig, ax = plt.subplots(figsize=(10, max(3, rows * 0.3)))
    im = ax.imshow(grid, cmap='RdBu_r', vmin=-vmax, vmax=vmax, aspect='auto')
    # Use the figure's own methods rather than pyplot's "current figure"
    # helpers, which act on whichever figure was created last.
    fig.colorbar(im, ax=ax, shrink=0.8)
    ax.set_title(f'{title} ({n_dims} dims)')
    ax.set_xlabel('Dimension')
    fig.tight_layout()

    # pyplot keeps every figure it creates alive until it is closed; in a
    # long-running app that is one leaked figure per call. Closing only
    # unregisters it — the Figure object stays usable for rendering.
    plt.close(fig)
    return fig
def create_comparison_plot(esm2_stats, twin_stats):
    """Build a grouped bar chart comparing the ESM2 and Twin summary statistics.

    Both arguments are dicts with the keys ``l2_norm``, ``mean``, ``std``,
    ``sparsity`` and ``entropy``. Returns a Plotly figure with one bar trace
    per model.
    """
    metrics = ['L2 Norm', 'Mean', 'Std', 'Sparsity', 'Entropy']
    stat_keys = ['l2_norm', 'mean', 'std', 'sparsity', 'entropy']  # same order as `metrics`
    esm2_vals = [esm2_stats[key] for key in stat_keys]
    twin_vals = [twin_stats[key] for key in stat_keys]

    fig = make_subplots(rows=1, cols=1)
    series = (
        ('ESM2', esm2_vals, '#3b82f6'),
        ('Twin', twin_vals, '#10b981'),
    )
    for label, vals, color in series:
        bar = go.Bar(
            name=label,
            x=metrics,
            y=vals,
            marker_color=color,
            text=[f'{v:.3f}' for v in vals],
            textposition='outside',
        )
        fig.add_trace(bar)

    fig.update_layout(
        barmode='group',
        height=350,
        legend=dict(orientation='h', y=1.1),
        margin=dict(l=40, r=20, t=50, b=40),
    )
    return fig
def create_distribution_plot(esm2_emb, twin_emb):
    """Overlay histograms of the activation values of the two embeddings.

    Returns a Plotly figure with one semi-transparent 50-bin histogram per
    model.
    """
    fig = go.Figure()
    series = (
        ('ESM2', esm2_emb, '#3b82f6'),
        ('Twin', twin_emb, '#10b981'),
    )
    for label, values, color in series:
        histogram = go.Histogram(
            x=values,
            name=label,
            opacity=0.7,
            marker_color=color,
            nbinsx=50,
        )
        fig.add_trace(histogram)

    layout = dict(
        barmode='overlay',
        xaxis_title='Activation value',
        yaxis_title='Count',
        height=300,
        legend=dict(orientation='h', y=1.1),
        margin=dict(l=40, r=20, t=40, b=40),
    )
    fig.update_layout(**layout)
    return fig
# CRISPR-associated + anti-CRISPR sequences (UniProt, reviewed/representative).
# This app is framed around CRISPR biology: Twin is a general GO-contrastive model,
# but the intended use case is annotation of Cas and anti-CRISPR (Acr) genes.
# Sequences longer than 1022 aa (SpCas9, SaCas9, FnCas12a, LshCas13a) are truncated
# by the model's seq_len=1024; the N-terminal region still carries most of the signal.
# NOTE(review): accession Q46901 with 502 residues looks like the E. coli
# Cascade subunit CasA/Cse1 rather than Cas1 — please confirm against UniProt
# and, if so, correct the label below (the variable name is kept because other
# parts of the file may reference it).
_CAS1_ECOLI = (  # Q46901, E. coli, 502 aa — adaptation, universal
    "MNLLIDNWIPVRPRNGGKVQIINLQSLYCSRDQWRLSLPRDDMELAALALLVCIGQIIAPAKDDVEFRHRIMNPLTEDEF"
    "QQLIAPWIDMFYLNHAEHPFMQTKGVKANDVTPMEKLLAGVSGATNCAFVNQPGQGEALCGGCTAIALFNQANQAPGFGG"
    "GFKSGLRGGTPVTTFVRGIDLRSTVLLNVLTLPRLQKQFPNESHTENQPTWIKPIKSNESIPASSIGFVRGLFWQPAHIE"
    "LCDPIGIGKCSCCGQESNLRYTGFLKEKFTFTVNGLWPHPHSPCLVTVKKGEVEEKFLAFTTSAPSWTQISRVVVDKIIQ"
    "NENGNRVAAVVNQFRNIAPQSPLELIMGGYRNNQASILERRHDVLMFNQGWQQYGNVINEIVTVGLGYKTALRKALYTFA"
    "EGFKNKDFKGAGVSVHETAERHFYRQSELLIPDVLANVNFSQADEVIADLRDKLHQLCEMLFNQSVAPYAHHPKLISTLA"
    "LARATLYKHLRELKPQGGPSNG"
)
_CAS2_ECOLI = ( # P45956, E. coli, 94 aa — adaptation, universal
"MSMLVVVTENVPPRLRGRLAIWLLEVRAGVYVGDVSAKIREMIWEQIAGLAEEGNVVMAWATNTETGFEFQTFGLNRRTP"
"VDLDGLRLVSFLPV"
)
_CAS3_ECOLI = ( # P38036, E. coli, 888 aa — Type I interference nuclease/helicase
"MEPFKYICHYWGKSSKSLTKGNDIHLLIYHCLDVAAVADCWWDQSVVLQNTFCRNEMLSKQRVKAWLLFFIALHDIGKFD"
"IRFQYKSAESWLKLNPATPSLNGPSTQMCRKFNHGAAGLYWFNQDSLSEQSLGDFFSFFDAAPHPYESWFPWVEAVTGHH"
"GFILHSQDQDKSRWEMPASLASYAAQDKQAREEWISVLEALFLTPAGLSINDIPPDCSSLLAGFCSLADWLGSWTTTNTF"
"LFNEDAPSDINALRTYFQDRQQDASRVLELSGLVSNKRCYEGVHALLDNGYQPRQLQVLVDALPVAPGLTVIEAPTGSGK"
"TETALAYAWKLIDQQIADSVIFALPTQATANAMLTRMEASASHLFSSPNLILAHGNSRFNHLFQSIKSRAITEQGQEEAW"
"VQCCQWLSQSNKKVFLGQIGVCTIDQVLISVLPVKHRFIRGLGIGRSVLIVDEVHAYDTYMNGLLEAVLKAQADVGGSVI"
"LLSATLPMKQKQKLLDTYGLHTDPVENNSAYPLINWRGVNGAQRFDLLAHPEQLPPRFSIQPEPICLADMLPDLTMLERM"
"IAAANAGAQVCLICNLVDVAQVCYQRLKELNNTQVDIDLFHARFTLNDRREKENRVISNFGKNGKRNVGRILVATQVVEQ"
"SLDVDFDWLITQHCPADLLFQRLGRLHRHHRKYRPAGFEIPVATILLPDGEGYGRHEHIYSNVRVMWRTQQHIEELNGAS"
"LFFPDAYRQWLDSIYDDAEMDEPEWVGNGMDKFESAECEKRFKARKVLQWAEEYSLQDNDETILAVTRDGEMSLPLLPYV"
"QTSSGKQLLDGQVYEDLSHEQQYEALALNRVNVPFTWKRSFSEVVDEDGLLWLEGKQNLDGWVWQGNSIVITYTGDEGMT"
"RVIPANPK"
)
_SPCAS9 = ( # Q99ZW2, S. pyogenes, 1368 aa — Type II-A effector (truncated by model to 1022 aa)
"MDKKYSIGLDIGTNSVGWAVITDEYKVPSKKFKVLGNTDRHSIKKNLIGALLFDSGETAEATRLKRTARRRYTRRKNRIC"
"YLQEIFSNEMAKVDDSFFHRLEESFLVEEDKKHERHPIFGNIVDEVAYHEKYPTIYHLRKKLVDSTDKADLRLIYLALAH"
"MIKFRGHFLIEGDLNPDNSDVDKLFIQLVQTYNQLFEENPINASGVDAKAILSARLSKSRRLENLIAQLPGEKKNGLFGN"
"LIALSLGLTPNFKSNFDLAEDAKLQLSKDTYDDDLDNLLAQIGDQYADLFLAAKNLSDAILLSDILRVNTEITKAPLSAS"
"MIKRYDEHHQDLTLLKALVRQQLPEKYKEIFFDQSKNGYAGYIDGGASQEEFYKFIKPILEKMDGTEELLVKLNREDLLR"
"KQRTFDNGSIPHQIHLGELHAILRRQEDFYPFLKDNREKIEKILTFRIPYYVGPLARGNSRFAWMTRKSEETITPWNFEE"
"VVDKGASAQSFIERMTNFDKNLPNEKVLPKHSLLYEYFTVYNELTKVKYVTEGMRKPAFLSGEQKKAIVDLLFKTNRKVT"
"VKQLKEDYFKKIECFDSVEISGVEDRFNASLGTYHDLLKIIKDKDFLDNEENEDILEDIVLTLTLFEDREMIEERLKTYA"
"HLFDDKVMKQLKRRRYTGWGRLSRKLINGIRDKQSGKTILDFLKSDGFANRNFMQLIHDDSLTFKEDIQKAQVSGQGDSL"
"HEHIANLAGSPAIKKGILQTVKVVDELVKVMGRHKPENIVIEMARENQTTQKGQKNSRERMKRIEEGIKELGSQILKEHP"
"VENTQLQNEKLYLYYLQNGRDMYVDQELDINRLSDYDVDHIVPQSFLKDDSIDNKVLTRSDKNRGKSDNVPSEEVVKKMK"
"NYWRQLLNAKLITQRKFDNLTKAERGGLSELDKAGFIKRQLVETRQITKHVAQILDSRMNTKYDENDKLIREVKVITLKS"
"KLVSDFRKDFQFYKVREINNYHHAHDAYLNAVVGTALIKKYPKLESEFVYGDYKVYDVRKMIAKSEQEIGKATAKYFFYS"
"NIMNFFKTEITLANGEIRKRPLIETNGETGEIVWDKGRDFATVRKVLSMPQVNIVKKTEVQTGGFSKESILPKRNSDKLI"
"ARKKDWDPKKYGGFDSPTVAYSVLVVAKVEKGKSKKLKSVKELLGITIMERSSFEKNPIDFLEAKGYKEVKKDLIIKLPK"
"YSLFELENGRKRMLASAGELQKGNELALPSKYVNFLYLASHYEKLKGSPEDNEQKQLFVEQHKHYLDEIIEQISEFSKRV"
"ILADANLDKVLSAYNKHRDKPIREQAENIIHLFTLTNLGAPAAFKYFDTTIDRKRYTSTKEVLDATLIHQSITGLYETRI"
"DLSQLGGD"
)
_SACAS9 = ( # J7RUA5, S. aureus, 1053 aa — Type II-A effector (smaller Cas9 ortholog)
"MKRNYILGLDIGITSVGYGIIDYETRDVIDAGVRLFKEANVENNEGRRSKRGARRLKRRRRHRIQRVKKLLFDYNLLTDH"
"SELSGINPYEARVKGLSQKLSEEEFSAALLHLAKRRGVHNVNEVEEDTGNELSTKEQISRNSKALEEKYVAELQLERLKK"
"DGEVRGSINRFKTSDYVKEAKQLLKVQKAYHQLDQSFIDTYIDLLETRRTYYEGPGEGSPFGWKDIKEWYEMLMGHCTYF"
"PEELRSVKYAYNADLYNALNDLNNLVITRDENEKLEYYEKFQIIENVFKQKKKPTLKQIAKEILVNEEDIKGYRVTSTGK"
"PEFTNLKVYHDIKDITARKEIIENAELLDQIAKILTIYQSSEDIQEELTNLNSELTQEEIEQISNLKGYTGTHNLSLKAI"
"NLILDELWHTNDNQIAIFNRLKLVPKKVDLSQQKEIPTTLVDDFILSPVVKRSFIQSIKVINAIIKKYGLPNDIIIELAR"
"EKNSKDAQKMINEMQKRNRQTNERIEEIIRTTGKENAKYLIEKIKLHDMQEGKCLYSLEAIPLEDLLNNPFNYEVDHIIP"
"RSVSFDNSFNNKVLVKQEENSKKGNRTPFQYLSSSDSKISYETFKKHILNLAKGKGRISKTKKEYLLEERDINRFSVQKD"
"FINRNLVDTRYATRGLMNLLRSYFRVNNLDVKVKSINGGFTSFLRRKWKFKKERNKGYKHHAEDALIIANADFIFKEWKK"
"LDKAKKVMENQMFEEKQAESMPEIETEQEYKEIFITPHQIKHIKDFKDYKYSHRVDKKPNRELINDTLYSTRKDDKGNTL"
"IVNNLNGLYDKDNDKLKKLINKSPEKLLMYHHDPQTYQKLKLIMEQYGDEKNPLYKYYEETGNYLTKYSKKDNGPVIKKI"
"KYYGNKLNAHLDITDDYPNSRNKVVKLSLKPYRFDVYLDNGVYKFVTVKNLDVIKKENYYEVNSKCYEEAKKLKKISNQA"
"EFIASFYNNDLIKINGELYRVIGVNNDLLNRIEVNMIDITYREYLENMNDKRPPRIIKTIASKTQSIKKYSTDILGNLYE"
"VKSKKHPQIIKKG"
)
_FNCAS12A = ( # A0Q7Q2, F. novicida, 1300 aa — Type V-A effector (Cpf1, truncated by model)
"MSIYQEFVNKYSLSKTLRFELIPQGKTLENIKARGLILDDEKRAKDYKKAKQIIDKYHQFFIEEILSSVCISEDLLQNYS"
"DVYFKLKKSDDDNLQKDFKSAKDTIKKQISEYIKDSEKFKNLFNQNLIDAKKGQESDLILWLKQSKDNGIELFKANSDIT"
"DIDEALEIIKSFKGWTTYFKGFHENRKNVYSSNDIPTSIIYRIVDDNLPKFLENKAKYESLKDKAPEAINYEQIKKDLAE"
"ELTFDIDYKTSEVNQRVFSLDEVFEIANFNNYLNQSGITKFNTIIGGKFVNGENTKRKGINEYINLYSQQINDKTLKKYK"
"MSVLFKQILSDTESKSFVIDKLEDDSDVVTTMQSFYEQIAAFKTVEEKSIKETLSLLFDDLKAQKLDLSKIYFKNDKSLT"
"DLSQQVFDDYSVIGTAVLEYITQQIAPKNLDNPSKKEQELIAKKTEKAKYLSLETIKLALEEFNKHRDIDKQCRFEEILA"
"NFAAIPMIFDEIAQNKDNLAQISIKYQNQGKKDLLQASAEDDVKAIKDLLDQTNNLLHKLKIFHISQSEDKANILDKDEH"
"FYLVFEECYFELANIVPLYNKIRNYITQKPYSDEKFKLNFENSTLANGWDKNKEPDNTAILFIKDDKYYLGVMNKKNNKI"
"FDDKAIKENKGEGYKKIVYKLLPGANKMLPKVFFSAKSIKFYNPSEDILRIRNHSTHTKNGSPQKGYEKFEFNIEDCRKF"
"IDFYKQSISKHPEWKDFGFRFSDTQRYNSIDEFYREVENQGYKLTFENISESYIDSVVNQGKLYLFQIYNKDFSAYSKGR"
"PNLHTLYWKALFDERNLQDVVYKLNGEAELFYRKQSIPKKITHPAKEAIANKNKDNPKKESVFEYDLIKDKRFTEDKFFF"
"HCPITINFKSSGANKFNDEINLLLKEKANDVHILSIDRGERHLAYYTLVDGKGNIIKQDTFNIIGNDRMKTNYHDKLAAI"
"EKDRDSARKDWKKINNIKEMKEGYLSQVVHEIAKLVIEYNAIVVFEDLNFGFKRGRFKVEKQVYQKLEKMLIEKLNYLVF"
"KDNEFDKTGGVLRAYQLTAPFETFKKMGKQTGIIYYVPAGFTSKICPVTGFVNQLYPKYESVSKSQEFFSKFDKICYNLD"
"KGYFEFSFDYKNFGDKAAKGKWTIASFGSRLINFRNSDKNHNWDTREVYPTKELEKLLKDYSIEYGHGECIKAAICGESD"
"KKFFAKLTSVLNTILQMRNSKTGTELDYLISPVADVNGNFFDSRQAPKNMPQDADANGAYHIGLKGLMLLGRIKNNQEGK"
"KLNLVIKNEEYFEFVQNRNN"
)
_LSHCAS13A = ( # P0DOC6, L. shahii, 1389 aa — Type VI-A effector (RNA-targeting, truncated)
"MGNLFGHKRWYEVRDKKDFKIKRKVKVKRNYDGNKYILNINENNNKEKIDNNKFIRKYINYKKNDNILKEFTRKFHAGNI"
"LFKLKGKEGIIRIENNDDFLETEEVVLYIEAYGKSEKLKALGITKKKIIDEAIRQGITKDDKKIEIKRQENEEEIEIDIR"
"DEYTNKTLNDCSIILRIIENDELETKKSIYEIFKNINMSLYKIIEKIIENETEKVFENRYYEEHLREKLLKDDKIDVILT"
"NFMEIREKIKSNLEILGFVKFYLNVGGDKKKSKNKKMLVEKILNINVDLTVEDIADFVIKELEFWNITKRIEKVKKVNNE"
"FLEKRRNRTYIKSYVLLDKHEKFKIERENKKDKIVKFFVENIKNNSIKEKIEKILAEFKIDELIKKLEKELKKGNCDTEI"
"FGIFKKHYKVNFDSKKFSKKSDEEKELYKIIYRYLKGRIEKILVNEQKVRLKKMEKIEIEKILNESILSEKILKRVKQYT"
"LEHIMYLGKLRHNDIDMTTVNTDDFSRLHAKEELDLELITFFASTNMELNKIFSRENINNDENIDFFGGDREKNYVLDKK"
"ILNSKIKIIRDLDFIDNKNNITNNFIRKFTKIGTNERNRILHAISKERDLQGTQDDYNKVINIIQNLKISDEEVSKALNL"
"DVVFKDKKNIITKINDIKISEENNNDIKYLPSFSKVLPEILNLYRNNPKNEPFDTIETEKIVLNALIYVNKELYKKLILE"
"DDLEENESKNIFLQELKKTLGNIDEIDENIIENYYKNAQISASKGNNKAIKKYQKKVIECYIGYLRKNYEELFDFSDFKM"
"NIQEIKKQIKDINDNKTYERITVKTSDKTIVINDDFEYIISIFALLNSNAVINKIRNRFFATSVWLNTSEYQNIIDILDE"
"IMQLNTLRNECITENWNLNLEEFIQKMKEIEKDFDDFKIQTKKEIFNNYYEDIKNNILTEFKDDINGCDVLEKKLEKIVI"
"FDDETKFEIDKKSNILQDEQRKLSNINKKDLKKKVDQYIKDKDQEIKSKILCRIIFNSDFLKKYKKEIDNLIEDMESENE"
"NKFQEIYYPKERKNELYIYKKNLFLNIGNPNFDKIYGLISNDIKMADAKFLFNIDGKNIRKNKISEIDAILKNLNDKLNG"
"YSKEYKEKYIKKLKENDDFFAKNIQNKNYKSFEKDYNRVSEYKKIRDLVEFNYLNKIESYLIDINWKLAIQMARFERDMH"
"YIVNGLRELGIIKLSGYNTGISRAYPKRNGSDGFYTTTAYYKFFDEESYKKFEKICYGFGIDLSENSEINKPENESIRNY"
"ISHFYIVRNPFADYSIAEQIDRVSNLLSYSTRYNNSTYASVFEVFKKDVNLDYDELKKKFKLIGNNDILERLMKPKKVSV"
"LELESYNSDYIKNLIIELLTKIENTNDTL"
)
_ACRIIA4 = ( # A0A0E0UT28, Listeria monocytogenes phage, 87 aa — anti-Cas9 (inhibits SpCas9)
"MNISELIREIKNKDYAVRLEGTDDNSITKLIIDVDNDGNEYVISESKNESIAEKFASTFKNGWNKEYEDEEEFYNDMQSI"
"ILKSELN"
)
# Default sequence shown in the Compare tab and as Protein A in the Distance tab
EXAMPLE_PROTEIN = _SPCAS9
def process(sequence: str, top_k: int = 10, twin_aspect: str = "BP"):
    """Process protein sequence, compare embeddings, and search FAISS.

    Steps: strip FASTA headers, validate, embed with ESM2 and with Twin (for
    *twin_aspect*), compute summary statistics, save both embeddings as
    ``.npy`` files in the temp directory, look up the *top_k* nearest
    UniRef50 neighbours of the ESM2 embedding, and build the plots.

    Returns an 8-tuple:
    ``(summary_markdown, esm2_npy_path, twin_npy_path, esm2_heatmap,
    twin_heatmap, comparison_plot, distribution_plot, hits_dataframe)``.
    For an invalid sequence the summary carries the error, the six middle
    items are ``None`` and the hits table is empty.
    """
    sequence = strip_fasta_header(sequence.strip())
    empty_df = pd.DataFrame(columns=["rank", "uniref50_id", "cosine", "uniprot"])
    valid, error = validate_protein(sequence)
    if not valid:
        return f"**Error**: {error}", None, None, None, None, None, None, empty_df

    # Compute embeddings
    esm2_emb = embed_esm2(sequence)
    twin_emb = embed_twin(sequence, aspect=twin_aspect)

    # Compute stats
    esm2_stats = compute_stats(esm2_emb)
    twin_stats = compute_stats(twin_emb)

    # Save embeddings
    esm2_path = os.path.join(tempfile.gettempdir(), "esm2_embedding.npy")
    twin_path = os.path.join(tempfile.gettempdir(), "twin_embedding.npy")
    np.save(esm2_path, esm2_emb)
    np.save(twin_path, twin_emb)

    # Nearest neighbors in UniRef50 (GO-annotated subset).
    # FAISS index may not be uploaded yet (SLURM job still building) — in that
    # case, surface the error in the hits table but keep all embeddings/plots.
    try:
        # Hand the search a private copy: it L2-normalises its query, and the
        # heatmap/distribution plots below must show the raw embedding that
        # the statistics and the saved file describe.
        hits_df = search_faiss(np.array(esm2_emb, copy=True), k=int(top_k))
    except Exception as e:
        print(f"FAISS search failed: {e}")
        # str(e) can be empty, in which case splitlines() is [] and indexing
        # it would raise inside this handler; fall back to repr(e).
        message_lines = str(e).splitlines() or [repr(e)]
        hits_df = pd.DataFrame([{
            "rank": 0,
            "uniref50_id": "(FAISS index not available)",
            "cosine": 0.0,
            "uniprot": message_lines[0][:200],
        }])

    trunc_note = (f"\n> ⚠️ Sequence truncated from {len(sequence)} to {ESM2_MAX_LEN} aa "
                  f"(ESM2 position-embedding limit). Twin also truncates internally.\n"
                  if len(sequence) > ESM2_MAX_LEN else "")
    summary = f"""### Results
{trunc_note}
| | ESM2 | Twin ({twin_aspect}) |
|---|---|---|
| Dimension | {esm2_stats['dim']} | {twin_stats['dim']} |
| L2 Norm | {esm2_stats['l2_norm']:.2f} | {twin_stats['l2_norm']:.2f} |
| Entropy | {esm2_stats['entropy']:.2f} | {twin_stats['entropy']:.2f} |
| Sparsity | {esm2_stats['sparsity']:.1%} | {twin_stats['sparsity']:.1%} |
Sequence: {len(sequence)} aa
"""

    # Create visualizations
    esm2_heatmap = create_embedding_heatmap(esm2_emb, "ESM2 Embedding")
    twin_heatmap = create_embedding_heatmap(twin_emb, f"Twin Embedding ({twin_aspect})")
    comparison_plot = create_comparison_plot(esm2_stats, twin_stats)
    distribution_plot = create_distribution_plot(esm2_emb, twin_emb)
    return summary, esm2_path, twin_path, esm2_heatmap, twin_heatmap, comparison_plot, distribution_plot, hits_df
# Build interface
# When the URL contains a hash like #benchmark or a query parameter like
# ?tab=benchmark, auto-click the tab whose label starts with that name once the
# app has loaded (retrying every 200 ms for a few seconds). Allows deep-linking
# to any tab.
_TAB_HASH_JS = """
function() {
function pickTab(name) {
if (!name) return;
const want = name.toLowerCase();
const tryOnce = () => {
const btns = document.querySelectorAll('button[role="tab"]');
for (const b of btns) {
if ((b.textContent || "").trim().toLowerCase().startsWith(want)) {
b.click();
return true;
}
}
return false;
};
let attempts = 0;
const t = setInterval(() => {
if (tryOnce() || ++attempts > 20) clearInterval(t);
}, 200);
}
const hash = (window.location.hash || "").replace(/^#/, "");
const params = new URLSearchParams(window.location.search || "");
pickTab(hash || params.get("tab"));
}
"""
with gr.Blocks(
title="Functional Distance",
css=".gradio-container { max-width: 100% !important; }",
js=_TAB_HASH_JS,
) as demo:
gr.Markdown(
"# functional-distance\n"
"Functional distance prediction for proteins — pairwise distance (Twin, "
"GO-contrastive fine-tune of ESM2) and nearest-neighbor lookup (ESM2 or "
"Twin-BP against UniRef50, plus CRISPR-focused lookup against a curated "
"Cas/Acr reference set)."
)
with gr.Tab("Distance"):
gr.Markdown(
"### Twin pairwise distance\n"
"Compute the Twin model's **native trained distance** between two protein "
"sequences under the selected GO aspect. "
"**L2 distance** is on L2-normalized 1024-dim embeddings (the training convention, "
"∈ [0, 2]); equivalent to `sqrt(2 - 2·cos_sim)`. Interpret relative to other pairs, "
"not as an absolute threshold."
)
with gr.Row():
with gr.Column():
dist_seq_a = gr.Textbox(
label="Protein A",
placeholder="Paste protein sequence (amino acids)...",
lines=5,
value=EXAMPLE_PROTEIN,
)
dist_seq_b = gr.Textbox(
label="Protein B",
placeholder="Paste protein sequence (amino acids)...",
lines=5,
value=_SACAS9,
info="Default: SaCas9 (S. aureus). Pair with the SpCas9 default → ortholog test."
)
dist_aspect_radio = gr.Radio(
choices=["BP", "CC", "MF"],
value=TWIN_DEFAULT_ASPECT,
label="Twin GO aspect",
)
dist_btn = gr.Button("compute distance", variant="primary")
with gr.Column():
dist_output = gr.HTML()
def _distance_bar(value, v_min, v_max, labels, gradient_stops, fmt="{:+.4f}"):
"""Horizontal bar with a position marker, labels below."""
pct = max(0.0, min(100.0, (value - v_min) / (v_max - v_min) * 100.0))
return f"""
<div style='margin:18px 0;'>
<div style='display:flex;justify-content:space-between;margin-bottom:6px;font-size:14px;'>
<span style='font-weight:600;'>{labels['title']}</span>
<span style='font-family:ui-monospace,monospace;font-weight:600;'>{fmt.format(value)}</span>
</div>
<div style='position:relative;height:22px;background:linear-gradient(to right,{gradient_stops});border-radius:4px;'>
<div style='position:absolute;left:{pct:.2f}%;top:-5px;width:3px;height:32px;background:#111;transform:translateX(-50%);box-shadow:0 0 0 2px #fff;'></div>
</div>
<div style='display:flex;justify-content:space-between;font-size:11px;color:#888;margin-top:4px;'>
<span>{labels['left']}</span>
<span>{labels['mid']}</span>
<span>{labels['right']}</span>
</div>
</div>"""
def _distance_handler(seq_a, seq_b, aspect):
    """Compute the Twin distance between two proteins and render it as HTML.

    Args:
        seq_a: First protein as raw sequence or FASTA text; surrounding
            whitespace is trimmed and the text is passed through
            ``strip_fasta_header``. ``None`` is treated as an empty string.
        seq_b: Second protein, same handling as ``seq_a``.
        aspect: GO aspect forwarded to ``compute_distance`` and shown in the
            heading.

    Returns:
        An HTML fragment. If either sequence fails ``validate_protein`` it is
        a single red error line naming the first failing sequence. Otherwise
        it contains a heading, an optional truncation notice (for sequences
        longer than ``ESM2_MAX_LEN``), the L2 and cosine distance bars, and a
        footer with the cosine similarity and both sequence lengths.
    """
    from html import escape  # function-scope import keeps this block self-contained

    # `or ""` guards against None, which has no .strip().
    seq_a = strip_fasta_header((seq_a or "").strip())
    seq_b = strip_fasta_header((seq_b or "").strip())
    named = (("A", seq_a), ("B", seq_b))

    for name, seq in named:
        valid, err = validate_protein(seq)
        if not valid:
            # The message ends up inside HTML, so escape it.
            return (
                f"<div style='color:#dc2626;font-weight:600;'>"
                f"Error in sequence {name}: {escape(str(err))}</div>"
            )

    # Same "from X to Y" wording as the other truncation notices in this file;
    # the two numbers previously ran together with nothing between them.
    trunc_notes = [
        f"Protein {name} truncated from {len(seq)} to {ESM2_MAX_LEN} aa"
        for name, seq in named
        if len(seq) > ESM2_MAX_LEN
    ]
    trunc_html = (f"<div style='background:#fff7ed;border-left:3px solid #f97316;"
                  f"padding:8px 12px;margin:8px 0;font-size:12px;color:#9a3412;'>"
                  f"⚠️ {'; '.join(trunc_notes)} (ESM2 position-embedding limit).</div>"
                  if trunc_notes else "")

    d = compute_distance(seq_a, seq_b, aspect)

    # Green = similar, red = dissimilar
    gradient = "#4ade80 0%,#facc15 50%,#f87171 100%"
    l2_bar = _distance_bar(
        d["l2"], 0.0, 2.0,
        labels={"title": "L2 distance (L2-normalized)",
                "left": "0 · identical", "mid": "√2 ≈ 1.41 · orthogonal", "right": "2 · opposite"},
        gradient_stops=gradient,
        fmt="{:.4f}",
    )
    cos_bar = _distance_bar(
        d["cos_dist"], 0.0, 2.0,
        labels={"title": "cosine distance (1 − cos_sim)",
                "left": "0 · identical", "mid": "1 · orthogonal", "right": "2 · opposite"},
        gradient_stops=gradient,
        fmt="{:.4f}",
    )
    return (
        f"<h3 style='margin-top:0;'>Twin/{escape(str(aspect))} distance</h3>"
        f"{trunc_html}"
        f"{l2_bar}{cos_bar}"
        f"<p style='font-size:12px;color:#666;margin-top:16px;'>"
        f"cosine similarity = {d['cos_sim']:+.4f}&nbsp; · &nbsp;"
        f"sequences: A = {len(seq_a)} aa, B = {len(seq_b)} aa</p>"
    )
# Two-step click chain. Step 1 immediately replaces dist_output with a
# "computing" placeholder; its lambda receives the selected aspect but does not
# use it. Step 2 (.then) runs _distance_handler on both sequences and the
# aspect and overwrites the placeholder with the result. Only step 2 is given
# an api_name ("distance").
dist_btn.click(
lambda a: "<div style='padding:16px;color:#666;'>⏳ Computing Twin distance…"
"<br><span style='font-size:12px;'>First run of an aspect: ~15 s to load the checkpoint. Subsequent calls: &lt;1 s.</span></div>",
inputs=[dist_aspect_radio],
outputs=[dist_output],
show_progress="hidden",
).then(
_distance_handler,
inputs=[dist_seq_a, dist_seq_b, dist_aspect_radio],
outputs=[dist_output],
api_name="distance",
show_progress="minimal",
)
# CRISPR-themed example pairs. Click to populate the two sequence boxes.
# Sequences declared as module-level constants (_SPCAS9, _CAS1_ECOLI, ...)
# since they're also used as Compare-tab defaults.
gr.Examples(
examples=[
# [seq_a, seq_b, aspect]
[_SPCAS9, _SPCAS9, "BP"], # sanity: identical -> L2 ~ 0
[_SPCAS9, _SACAS9, "BP"], # orthologs (both Type II-A Cas9)
[_CAS1_ECOLI, _CAS2_ECOLI, "BP"], # adaptation complex partners
[_SPCAS9, _FNCAS12A, "BP"], # Type II vs Type V (both DNA-targeting)
[_CAS1_ECOLI, _CAS3_ECOLI, "BP"], # Type I adaptation vs interference
[_SPCAS9, _LSHCAS13A, "BP"], # different substrate: DNA vs RNA
[_ACRIIA4, _SPCAS9, "BP"], # anti-CRISPR vs its target
],
inputs=[dist_seq_a, dist_seq_b, dist_aspect_radio],
example_labels=[
"Identical (sanity) — SpCas9 / SpCas9",
"Orthologs — SpCas9 / SaCas9 (both Type II-A)",
"Adaptation partners — Cas1 / Cas2 (E. coli)",
"Different CRISPR types — SpCas9 (II) / FnCas12a (V) — DNA-targeting",
"Same pathway, different role — Cas1 (adaptation) / Cas3 (interference), Type I",
"Different substrate — SpCas9 (DNA) / LshCas13a (RNA)",
"Anti-CRISPR vs target — AcrIIA4 / SpCas9",
],
label="Example pairs (CRISPR / anti-CRISPR) — click to load",
)
with gr.Tab("Characterize Unknown Proteins"):
gr.Markdown(
"### Characterize unknown proteins\n"
"Paste a protein sequence, choose an embedding, then compare it either to the "
"large UniRef50 FAISS background or to the curated CRISPR protein reference. "
"For the CRISPR reference this page now also reports a heuristic family verdict "
"and places the query into the same Cas/Acr reference projection used by the benchmark.\n"
"- **UniRef50** — broad nearest-neighbour lookup across the GO-annotated UniRef50 subset "
"using either ESM2 or the Twin-BP index.\n"
"- **CRISPR curated** — mixed **254-protein** reference: 175 Pfam-verified Cas proteins "
"plus 79 Anti-CRISPRdb v3/local Acr proteins."
)
with gr.Row():
with gr.Column(scale=1, min_width=320):
lookup_seq_input = gr.Textbox(
label="Protein Sequence",
placeholder="Paste protein sequence (amino acids)...",
lines=6,
value=EXAMPLE_PROTEIN,
info="FASTA or raw sequence; > 1022 aa is truncated"
)
lookup_method_radio = gr.Radio(
choices=["ESM2 baseline", "genomenet-twin (MF)", "genomenet-twin (BP)"],
value="ESM2 baseline",
label="Method",
)
lookup_index_radio = gr.Radio(
choices=["UniRef50", "CRISPR curated"],
value="UniRef50",
label="Reference set",
)
lookup_btn = gr.Button("search", variant="primary")
with gr.Column(scale=2, min_width=400):
lookup_info = gr.Markdown()
lookup_verdict = gr.Markdown()
lookup_plot = gr.Plot(label="CRISPR reference placement")
lookup_hits = gr.HTML()
_lookup_empty_html = ""
def _hits_bar_cell(cos, width=180):
"""Red -> yellow -> green gradient; black marker at the cosine position (0..1)."""
pct = max(0.0, min(100.0, float(cos) * 100.0))
return (
f"<div style='display:flex;align-items:center;gap:10px;'>"
f"<div style='position:relative;width:{width}px;height:12px;"
f"background:linear-gradient(to right,#f87171 0%,#facc15 50%,#4ade80 100%);"
f"border-radius:3px;'>"
f"<div style='position:absolute;left:{pct:.2f}%;top:-3px;width:2px;height:18px;"
f"background:#111;box-shadow:0 0 0 2px #fff;transform:translateX(-50%);'></div>"
f"</div>"
f"<span style='font-family:ui-monospace,monospace;font-size:12px;min-width:52px;'>{cos:.4f}</span>"
f"</div>"
)
def _bar_scale_normalized(cos, lo, hi):
"""Position cos on a scale with explicit lo/hi endpoints (for the per-row marker)."""
if hi <= lo:
return 50.0
return max(0.0, min(100.0, (float(cos) - lo) / (hi - lo) * 100.0))
def _hits_table_section(df, title, lo, hi):
"""Render a hits DataFrame as an HTML table. Bars use a shared [lo, hi] scale."""
if df.empty:
return f"<h4>{title}</h4><p style='color:#888;'>(empty)</p>"
rows = []
for _, r in df.iterrows():
pct = _bar_scale_normalized(r["cosine"], lo, hi)
bar = (
f"<div style='display:flex;align-items:center;gap:10px;'>"
f"<div style='position:relative;width:180px;height:12px;"
f"background:linear-gradient(to right,#f87171 0%,#facc15 50%,#4ade80 100%);border-radius:3px;'>"
f"<div style='position:absolute;left:{pct:.2f}%;top:-3px;width:2px;height:18px;"
f"background:#111;box-shadow:0 0 0 2px #fff;transform:translateX(-50%);'></div>"
f"</div>"
f"<span style='font-family:ui-monospace,monospace;font-size:12px;min-width:60px;'>{r['cosine']:+.4f}</span>"
f"</div>"
)
link = (f"<a href='{r['link']}' target='_blank' rel='noopener noreferrer' "
f"style='color:#2563eb;text-decoration:none;'>↗</a>"
if str(r.get("link", "")).startswith("http") else "")
desc = str(r.get("description", "") or "")
if len(desc) > 140:
desc = desc[:140] + "…"
rows.append(
f"<tr>"
f"<td style='padding:5px 10px;color:#555;'>{r['rank']}</td>"
f"<td style='padding:5px 10px;font-family:ui-monospace,monospace;white-space:nowrap;'>{r['id']}</td>"
f"<td style='padding:5px 10px;'>{bar}</td>"
f"<td style='padding:5px 10px;font-size:13px;color:#374151;'>{desc}</td>"
f"<td style='padding:5px 10px;'>{link}</td>"
f"</tr>"
)
return (
f"<h4 style='margin:18px 0 6px 0;'>{title}</h4>"
f"<table style='width:100%;border-collapse:collapse;font-size:13px;'>"
f"<thead><tr style='text-align:left;border-bottom:1px solid #e5e7eb;color:#6b7280;font-size:12px;'>"
f"<th style='padding:5px 10px;'>#</th><th style='padding:5px 10px;'>ID</th>"
f"<th style='padding:5px 10px;'>similarity</th><th style='padding:5px 10px;'>description</th>"
f"<th style='padding:5px 10px;'></th></tr></thead>"
f"<tbody>{''.join(rows)}</tbody></table>"
)
def _distribution_svg(top_display, bot_display, top_context, bot_context,
width=760, height=150, n_bins=64):
"""Stacked histogram of every retrieved top/bottom FAISS hit.
Metadata is fetched only for the displayed table rows, but the
histogram includes the full retrieved extreme pool.
"""
all_v = list(top_display) + list(bot_display) + list(top_context) + list(bot_context)
if not all_v:
return ""
lo, hi = min(all_v), max(all_v)
pad = (hi - lo) * 0.05 if hi > lo else 0.01
x_lo, x_hi = lo - pad, hi + pad
def _bin(v):
return min(n_bins - 1, max(0, int((v - x_lo) / (x_hi - x_lo) * n_bins)))
top_counts = [0] * n_bins
top_context_counts = [0] * n_bins
bot_counts = [0] * n_bins
bot_context_counts = [0] * n_bins
for v in top_display:
top_counts[_bin(v)] += 1
for v in top_context:
top_context_counts[_bin(v)] += 1
for v in bot_display:
bot_counts[_bin(v)] += 1
for v in bot_context:
bot_context_counts[_bin(v)] += 1
max_c = max(
t + tc + b + bc
for t, tc, b, bc in zip(top_counts, top_context_counts, bot_counts, bot_context_counts)
) or 1
bar_w = width / n_bins
usable_h = height - 42
bars = []
for i in range(n_bins):
y = height - 28
segments = [
(bot_context_counts[i] / max_c * usable_h, "#fecaca"),
(bot_counts[i] / max_c * usable_h, "#ef4444"),
(top_context_counts[i] / max_c * usable_h, "#bbf7d0"),
(top_counts[i] / max_c * usable_h, "#16a34a"),
]
for h, colour in segments:
if h > 0:
bars.append(
f"<rect x='{i*bar_w:.2f}' y='{y-h:.2f}' width='{max(bar_w-1, 1):.2f}' "
f"height='{h:.2f}' fill='{colour}' />"
)
y -= h
total_n = len(all_v)
axis_y = height - 27
labels = (
f"<line x1='0' y1='{axis_y}' x2='{width}' y2='{axis_y}' stroke='#9ca3af' stroke-width='1' />"
f"<text x='0' y='{height-9}' fill='#6b7280' font-size='10' font-family='sans-serif'>{x_lo:+.2f}</text>"
f"<text x='{width/2}' y='{height-9}' text-anchor='middle' fill='#6b7280' font-size='10' font-family='sans-serif'>cosine similarity across retrieved extreme hits (n={total_n})</text>"
f"<text x='{width}' y='{height-9}' text-anchor='end' fill='#6b7280' font-size='10' font-family='sans-serif'>{x_hi:+.2f}</text>"
)
legend = (
f"<div style='display:flex;align-items:center;gap:14px;flex-wrap:wrap;font-size:11px;color:#6b7280;margin-top:4px;'>"
f"<span style='display:inline-flex;align-items:center;gap:5px;'>"
f"<span style='width:10px;height:10px;background:#16a34a;display:inline-block;border-radius:2px;'></span>"
f"top {len(top_display)} with metadata/table</span>"
f"<span style='display:inline-flex;align-items:center;gap:5px;'>"
f"<span style='width:10px;height:10px;background:#bbf7d0;display:inline-block;border-radius:2px;'></span>"
f"next {len(top_context)} top hits</span>"
f"<span style='display:inline-flex;align-items:center;gap:5px;'>"
f"<span style='width:10px;height:10px;background:#ef4444;display:inline-block;border-radius:2px;'></span>"
f"bottom {len(bot_display)} with metadata/table</span>"
f"<span style='display:inline-flex;align-items:center;gap:5px;'>"
f"<span style='width:10px;height:10px;background:#fecaca;display:inline-block;border-radius:2px;'></span>"
f"next {len(bot_context)} anti-correlated hits</span>"
f"</div>"
f"<div style='font-size:11px;color:#6b7280;margin-top:3px;'>"
f"All retrieved hits are plotted here; metadata is fetched only for the table rows. "
f"This is an extreme-neighbour view, not a full scan of every UniRef50 cluster.</div>"
)
return (
f"<div style='margin:8px 0 14px 0;'>"
f"<svg width='100%' viewBox='0 0 {width} {height}' preserveAspectRatio='none' "
f"style='display:block;background:#fafafa;border:1px solid #e5e7eb;border-radius:4px;'>"
f"{''.join(bars)}{labels}</svg>"
f"{legend}</div>"
)
# Number of rows listed (with metadata) in each of the top and bottom hits
# tables; the CRISPR path additionally caps it at half the reference size.
DISPLAY_K = 25
# Hits requested per FAISS search; rows beyond the first DISPLAY_K are only
# used as context in the histogram.
POOL_K = 500 # larger FAISS fetch (no metadata) for the distribution histogram
def _lookup_method_spec(method_choice):
if method_choice.startswith("genomenet-twin"):
aspect = "MF" if "(MF)" in method_choice else "BP"
return {
"is_twin": True,
"aspect": aspect,
"method_key": f"Twin-{aspect}",
"method_label": f"genomenet-twin ({aspect})",
"threshold": 0.70,
"dim_info": f"1024-dim Twin-{aspect}",
}
return {
"is_twin": False,
"aspect": None,
"method_key": "ESM2 baseline",
"method_label": "ESM2 baseline",
"threshold": 0.90,
"dim_info": "1280-dim ESM2",
}
def _crispr_verdict_card(label, conf, top_cos, threshold, method_label):
conf_colour = {"high": "#059669", "medium": "#d97706", "low": "#dc2626"}.get(conf, "#6b7280")
return (
f"### CRISPR characterization ({method_label})\n"
f"<div style='border-left:4px solid {conf_colour};padding:10px 14px;"
f"background:#fafafa;border-radius:4px;margin:6px 0;'>"
f"<div style='font-size:15px;'><b>{label}</b></div>"
f"<div style='font-size:12px;color:#6b7280;margin-top:4px;'>"
f"confidence: <b style='color:{conf_colour}'>{conf.upper()}</b> · "
f"top-1 cosine = {top_cos:+.3f} · heuristic threshold = {threshold:.2f}"
f"</div></div>"
)
def _make_crispr_query_plot(method_key, method_label, ref_meta, sims, top):
    """Draw the query inside the precomputed 2-D CRISPR reference projection.

    The query has no projected coordinates of its own. It is placed at a
    weighted centroid of its five best reference hits, with weights
    ``softmax(8 * similarity)``, and connected to the three best hits with
    dashed lines.

    Args:
        method_key: Key passed to ``get_crispr_umap`` to fetch the reference
            coordinates.
        method_label: Human-readable method name used in the plot title.
        ref_meta: Reference metadata with ``family`` and ``group`` columns.
            Rows are matched to the coordinate rows by position.
        sims: Array of query-to-reference similarities, one per reference row.
        top: Array of reference row positions, best hit first.

    Returns:
        A matplotlib ``Figure``, or ``None`` when ``get_crispr_umap`` returns
        ``None`` for ``method_key``.
    """
    xy_ref = get_crispr_umap(method_key)
    if xy_ref is None:
        return None

    nearest = top[:5]
    w = np.exp((sims[nearest] - sims[nearest].max()) * 8.0)
    w /= w.sum()
    q_xy = (xy_ref[nearest] * w[:, None]).sum(axis=0)

    from matplotlib.lines import Line2D

    families = ref_meta["family"].tolist()
    groups = ref_meta["group"].tolist()
    all_fams = list(dict.fromkeys(families))  # unique, first-seen order
    cmap = plt.get_cmap("tab20", len(all_fams))
    fam_colour = {f: cmap(i) for i, f in enumerate(all_fams)}
    group_marker = {"cas": "o", "acr": "^"}

    fig, ax = plt.subplots(figsize=(9, 6.5))
    # Index xy_ref by row position rather than by DataFrame index label, so a
    # non-default index on ref_meta cannot pair a protein with the wrong point.
    for row, (family, group) in enumerate(zip(families, groups)):
        ax.scatter(
            xy_ref[row, 0], xy_ref[row, 1],
            color=fam_colour[family],
            marker=group_marker.get(str(group), "o"),
            s=48, edgecolor="black", linewidth=0.35, alpha=0.85, zorder=2,
        )
    for i in top[:3]:
        ax.plot([q_xy[0], xy_ref[i, 0]], [q_xy[1], xy_ref[i, 1]],
                color="#444", lw=0.6, linestyle="--", alpha=0.6, zorder=3)
    ax.scatter(q_xy[0], q_xy[1], marker="*", s=320, color="#111",
               edgecolor="gold", linewidth=2.0, zorder=10)

    handles = [
        Line2D([0], [0], marker="o", color="w", markerfacecolor=fam_colour[f],
               markeredgecolor="black", markersize=7, label=f)
        for f in all_fams
    ]
    handles.append(Line2D([0], [0], marker="*", color="w", markerfacecolor="#111",
                          markeredgecolor="gold", markersize=14, label="query"))
    ax.legend(handles=handles, loc="center left", bbox_to_anchor=(1.02, 0.5),
              fontsize=7.2, frameon=False)
    ax.set_xlabel("projection-1")
    ax.set_ylabel("projection-2")
    ax.set_title(f"Query placement in the mixed Cas/Acr reference ({method_label})")
    ax.grid(alpha=0.25, linestyle=":")
    fig.tight_layout()

    # pyplot keeps every figure it creates alive until it is closed. Closing
    # here only drops pyplot's reference, so figures do not pile up across
    # requests; the returned Figure can still be rendered by the caller.
    plt.close(fig)
    return fig
def _lookup_handler(sequence, method_choice, index_choice):
    """Run one nearest-neighbour lookup and build the four UI outputs.

    Args:
        sequence: Query protein as raw sequence or FASTA text; ``None`` is
            treated as an empty string.
        method_choice: Label of the "Method" radio, interpreted by
            ``_lookup_method_spec``.
        index_choice: Label of the "Reference set" radio. Values starting
            with ``"UniRef50"`` select the FAISS search; anything else selects
            the curated CRISPR reference.

    Returns:
        A 4-tuple ``(info_markdown, verdict_markdown, figure_or_None,
        hits_html)``. The verdict and the figure are only filled in on the
        CRISPR path. An invalid sequence, and any exception raised inside
        either lookup branch, is reported through ``info_markdown`` with the
        other three outputs empty.
    """

    def _first_line(exc):
        # str(exc) is empty for e.g. a bare `raise ValueError()`; fall back to
        # the class name so the error path itself cannot raise IndexError.
        lines = str(exc).splitlines()
        return lines[0] if lines else type(exc).__name__

    sequence = strip_fasta_header((sequence or "").strip())
    valid, err = validate_protein(sequence)
    if not valid:
        return f"**Error**: {err}", "", None, _lookup_empty_html

    trunc_note = (f"> ⚠️ Query truncated from {len(sequence)} to {ESM2_MAX_LEN} aa "
                  f"(ESM2 limit).\n\n" if len(sequence) > ESM2_MAX_LEN else "")
    spec = _lookup_method_spec(method_choice)

    if index_choice.startswith("UniRef50"):
        try:
            if spec["is_twin"]:
                if spec["aspect"] != "BP":
                    return (
                        f"{trunc_note}**Twin-MF × UniRef50** is not packaged yet.\n\n"
                        "The full UniRef50 FAISS index currently available for the Twin model "
                        "uses the **BP** aspect. Use `genomenet-twin (BP)` for full UniRef50 "
                        "lookup, or use Twin-MF against the curated CRISPR reference.",
                        "",
                        None,
                        _lookup_empty_html,
                    )
                query_emb = embed_twin(sequence, aspect="BP")
                faiss_index_name = "twin-bp"
            else:
                query_emb = embed_esm2(sequence)
                faiss_index_name = "esm2"
            faiss_cfg = FAISS_CONFIGS[faiss_index_name]

            # Large pool (no metadata) for context, small display-subset (with metadata).
            top_pool = search_faiss(query_emb, k=POOL_K, fetch_metadata=False,
                                    index_name=faiss_index_name)
            bot_pool = search_faiss(query_emb, k=POOL_K, negate=True, fetch_metadata=False,
                                    index_name=faiss_index_name)

            # Metadata is fetched for the displayed rows only.
            top_head = top_pool.head(DISPLAY_K)
            bot_head = bot_pool.head(DISPLAY_K)
            meta = fetch_uniref_metadata(
                top_head["uniref50_id"].tolist() + bot_head["uniref50_id"].tolist()
            )

            def _to_display(head_df):
                ids = head_df["uniref50_id"].tolist()
                return pd.DataFrame({
                    "rank": list(range(1, len(ids) + 1)),
                    "id": ids,
                    "cosine": head_df["cosine"].tolist(),
                    "description": [meta.get(i, "") for i in ids],
                    "link": ["https://www.uniprot.org/uniref/" + i for i in ids],
                })

            top_df = _to_display(top_head)
            bot_df = _to_display(bot_head)

            # Histogram pools
            top_25_cos = top_head["cosine"].tolist()
            top_ctx_cos = top_pool["cosine"].iloc[DISPLAY_K:].tolist()
            bot_25_cos = bot_head["cosine"].tolist()
            bot_ctx_cos = bot_pool["cosine"].iloc[DISPLAY_K:].tolist()

            # Shared bar scale across both tables.
            all_cos = top_df["cosine"].tolist() + bot_df["cosine"].tolist()
            lo, hi = (min(all_cos), max(all_cos)) if all_cos else (0.0, 1.0)

            # Two trailing spaces before "\n" = Markdown hard line break.
            info = (
                f"{trunc_note}"
                f"**UniRef50 × {faiss_cfg['label']}** — top-{DISPLAY_K} most similar + bottom-{DISPLAY_K} "
                f"most anti-correlated clusters (cosine on L2-normalized {faiss_cfg['dim_info']} embeddings). "
                f"The histogram plots all {len(top_pool) + len(bot_pool)} retrieved extreme hits "
                f"({len(top_pool)} top-neighbour hits + {len(bot_pool)} anti-correlated hits); "
                f"metadata is fetched only for the table rows.  \n"
                f"Top range: **{top_df['cosine'].max():+.3f}** – **{top_df['cosine'].min():+.3f}** · "
                f"Bottom range: **{bot_df['cosine'].max():+.3f}** – **{bot_df['cosine'].min():+.3f}**"
            )
            # Named hits_html so the imported `html` module is not shadowed.
            hits_html = (
                _distribution_svg(top_25_cos, bot_25_cos, top_ctx_cos, bot_ctx_cos)
                + _hits_table_section(top_df, f"🟢 Top-{DISPLAY_K} most similar", lo, hi)
                + _hits_table_section(bot_df, f"🔴 Bottom-{DISPLAY_K} most anti-correlated", lo, hi)
            )
            return info, "", None, hits_html
        except Exception as e:
            # UI boundary: report instead of crashing the request, but leave a
            # trace in the server log.
            print(f"UniRef50 lookup failed: {e}")
            return (f"{trunc_note}**FAISS lookup failed**: {_first_line(e)}\n\n"
                    f"The UniRef50 FAISS index may not be available yet.",
                    "",
                    None,
                    _lookup_empty_html)
    else:  # CRISPR curated reference
        try:
            if spec["is_twin"]:
                query_emb = embed_twin(sequence, aspect=spec["aspect"])
            else:
                query_emb = embed_esm2(sequence)
            method_label = spec["method_label"]
            method_key = spec["method_key"]
            emb_dim_info = spec["dim_info"]

            # One call yields both the embeddings and the metadata.
            ref_emb, ref_meta = get_crispr_reference(method_key)
            n_ref = ref_emb.shape[0]

            q = np.asarray(query_emb, dtype=np.float32).reshape(-1)
            q = q / (np.linalg.norm(q) + 1e-9)
            sims = ref_emb @ q
            order = np.argsort(-sims)

            k_eff = min(DISPLAY_K, n_ref // 2)
            top_idx = order[:k_eff]
            # Sliced from the front: order[-0:] would be the whole array when k_eff == 0.
            bot_idx = order[n_ref - k_eff:][::-1]

            def _display_rows(indices):
                rows = []
                for rank, idx in enumerate(indices, 1):
                    m = ref_meta.iloc[idx]
                    rows.append({
                        "rank": rank,
                        "id": _crispr_display_id(m),
                        "cosine": round(float(sims[idx]), 4),
                        "description": (
                            f"{_cell_text(m.get('family', ''))} · "
                            f"{_cell_text(m.get('type', ''))} · "
                            f"{_cell_text(m.get('organism', ''))}"
                        ),
                        "link": _crispr_protein_link(m),
                    })
                return pd.DataFrame(rows)

            top_df, bot_df = _display_rows(top_idx), _display_rows(bot_idx)

            # Context for histogram = the middle proteins (not in top or bottom k)
            middle = [float(sims[i]) for i in order[k_eff:n_ref - k_eff]]
            all_cos = top_df["cosine"].tolist() + bot_df["cosine"].tolist()
            lo, hi = (min(all_cos), max(all_cos)) if all_cos else (0.0, 1.0)

            # The verdict heuristic looks at the ten best hits.
            top10_idx = order[:10]
            top_rows = []
            for rank, idx in enumerate(top10_idx, 1):
                m = ref_meta.iloc[idx]
                top_rows.append({
                    "rank": rank,
                    "id": _crispr_display_id(m),
                    "cosine": float(sims[idx]),
                    "family": _cell_text(m.get("family", "")),
                    "group": _cell_text(m.get("group", "")),
                    "organism": _cell_text(m.get("organism", "")),
                    "name": _cell_text(m.get("name", "")),
                })
            top_cos = top_rows[0]["cosine"]
            label, conf = _crispr_verdict(top_rows, top_cos, spec["threshold"])
            verdict_md = _crispr_verdict_card(label, conf, top_cos, spec["threshold"], method_label)
            fig = _make_crispr_query_plot(method_key, method_label, ref_meta, sims, top10_idx)

            # The middle proteins are passed to _distribution_svg as its
            # top-context series, which is drawn light green (not grey).
            info = (
                f"{trunc_note}"
                f"**CRISPR curated × {method_label}** — {n_ref}-protein mixed Cas/Acr set "
                f"(175 Pfam-verified Cas + 79 Anti-CRISPRdb v3/local Acr proteins). "
                f"Top-{k_eff} most similar + bottom-{k_eff} most anti-correlated (cosine on "
                f"{emb_dim_info} embeddings). Histogram also shows the {len(middle)} "
                f"middle proteins (light green).  \n"
                f"Top range: **{top_df['cosine'].max():+.3f}** – **{top_df['cosine'].min():+.3f}** · "
                f"Bottom range: **{bot_df['cosine'].max():+.3f}** – **{bot_df['cosine'].min():+.3f}**"
            )
            hits_html = (
                _distribution_svg(top_df["cosine"].tolist(), bot_df["cosine"].tolist(), middle, [])
                + _hits_table_section(top_df, f"🟢 Top-{k_eff} most similar in CRISPR set", lo, hi)
                + _hits_table_section(bot_df, f"🔴 Bottom-{k_eff} most anti-correlated in CRISPR set", lo, hi)
            )
            return info, verdict_md, fig, hits_html
        except Exception as e:
            print(f"CRISPR reference lookup failed: {e}")
            return (f"{trunc_note}**CRISPR reference lookup failed**: {_first_line(e)}\n\n"
                    f"The curated CRISPR embeddings may not be packaged into this Space yet.",
                    "",
                    None,
                    _lookup_empty_html)
# Two-step click chain. Step 1 immediately shows a "Searching <method> ×
# <reference>…" note in lookup_info and clears the verdict, plot and hits
# outputs. Step 2 (.then) runs _lookup_handler and fills all four outputs.
# Only step 2 is given an api_name ("lookup").
lookup_btn.click(
lambda m, r: (f"⏳ Searching {m} × {r}…", "", None, _lookup_empty_html),
inputs=[lookup_method_radio, lookup_index_radio],
outputs=[lookup_info, lookup_verdict, lookup_plot, lookup_hits],
show_progress="hidden",
).then(
_lookup_handler,
inputs=[lookup_seq_input, lookup_method_radio, lookup_index_radio],
outputs=[lookup_info, lookup_verdict, lookup_plot, lookup_hits],
api_name="lookup",
show_progress="minimal",
)
gr.Examples(
examples=[
# UniRef50 × ESM2 (full-scale FAISS search over 4.3M clusters)
[_CAS1_ECOLI, "ESM2 baseline", "UniRef50"],
[_SPCAS9, "ESM2 baseline", "UniRef50"],
[_ACRIIA4, "ESM2 baseline", "UniRef50"],
[_CAS1_ECOLI, "genomenet-twin (BP)", "UniRef50"],
[_SPCAS9, "genomenet-twin (BP)", "UniRef50"],
# CRISPR curated × Twin-MF (recommended for top-hit characterization)
[_CAS1_ECOLI, "genomenet-twin (MF)", "CRISPR curated"],
[_SPCAS9, "genomenet-twin (MF)", "CRISPR curated"],
[_FNCAS12A, "genomenet-twin (MF)", "CRISPR curated"],
[_LSHCAS13A, "genomenet-twin (MF)", "CRISPR curated"],
[_ACRIIA4, "genomenet-twin (MF)", "CRISPR curated"],
# CRISPR curated × ESM2 (sequence-level baseline for comparison)
[_SPCAS9, "ESM2 baseline", "CRISPR curated"],
[_ACRIIA4, "ESM2 baseline", "CRISPR curated"],
],
inputs=[lookup_seq_input, lookup_method_radio, lookup_index_radio],
example_labels=[
"Cas1 × UniRef50 (ESM2) — adaptation, E. coli",
"SpCas9 × UniRef50 (ESM2) — Type II-A effector",
"AcrIIA4 × UniRef50 (ESM2) — anti-CRISPR",
"Cas1 × UniRef50 (Twin-BP) — full functional-index lookup",
"SpCas9 × UniRef50 (Twin-BP) — full functional-index lookup",
"Cas1 × CRISPR (Twin-MF) — adaptation, E. coli",
"SpCas9 × CRISPR (Twin-MF) — Type II-A effector",
"FnCas12a × CRISPR (Twin-MF) — Type V-A effector",
"LshCas13a × CRISPR (Twin-MF) — Type VI-A, RNA-targeting",
"AcrIIA4 × CRISPR (Twin-MF) — anti-Cas9 inhibitor",
"SpCas9 × CRISPR (ESM2) — sequence-level baseline",
"AcrIIA4 × CRISPR (ESM2) — sequence-level baseline",
],
label="Example queries — click to load a query × method × reference combo",
)
if False:
gr.Markdown(
"### Characterize an unknown protein\n\n"
"Designed for triaging uncharacterised ORFs: is this protein CRISPR-adjacent at "
"all, and if so, which family does it look like? Two sections:\n"
"1. **Pre-computed gallery** — we fetched 24 genuinely uncharacterised proteins "
"from UniProt (phage and bacterial hypotheticals, prime candidates for novel "
"Cas / anti-CRISPR discoveries) and ran both ESM2 and genomenet-twin (MF) on each. "
"Results are shown below in an interactive 3-D UMAP plus a full table.\n"
"2. **Bring your own sequence** — paste any protein at the bottom to run the same "
"analysis interactively."
)
# --- Section 1: pre-computed gallery -------------------------------
# Summary stats
gallery_data = get_unknown_results()
n_unknown = gallery_data.get("n_queries", 0)
if n_unknown > 0:
recs = gallery_data["records"]
def _count_by_verdict(method_key):
    """Tally the gallery records in ``recs`` for one method.

    Returns ``(predicted, with_signal)``: the number of records whose verdict
    text contains "Predicted", and the number whose confidence is anything
    other than "low".
    """
    predicted = 0
    with_signal = 0
    for record in recs:
        outcome = record[method_key]
        if "Predicted" in outcome["verdict"]:
            predicted += 1
        if outcome["confidence"] != "low":
            with_signal += 1
    return predicted, with_signal
pred_mf, cris_mf = _count_by_verdict("twin-mf")
pred_es, cris_es = _count_by_verdict("esm2")
gr.Markdown(
f"#### 1. Pre-computed gallery · {n_unknown} uncharacterised proteins\n\n"
f"| method | 'Predicted: &lt;family&gt;' | any CRISPR signal (above threshold) |\n"
f"|---|---|---|\n"
f"| genomenet-twin (MF) | **{pred_mf}** / {n_unknown} | **{cris_mf}** / {n_unknown} |\n"
f"| ESM2 baseline | {pred_es} / {n_unknown} | {cris_es} / {n_unknown} |\n\n"
f"Twin-MF flags substantially more phage hypotheticals as CRISPR-related than "
f"ESM2 — expected given the benchmark results, and biologically plausible: "
f"phages often carry anti-CRISPR proteins to evade the host's CRISPR-Cas "
f"defence. Rebuild this gallery from "
f"`scripts/analyses/crispr/precompute_unknowns.py` (GPU job, ~3 min)."
)
gr.Markdown(
"**Signal-strength benchmark — ESM2 vs Twin-MF on the same unknowns.** "
"Three views of the same data: histogram of top-1 cosines, cumulative "
"\"fraction of unknowns with cosine ≥ x\", and verdict-confidence bucket "
"counts. Twin-MF consistently shifts the distribution rightward and "
"promotes more unknowns into the medium/high-confidence buckets than "
"ESM2 does — the direct answer to *how much more signal does Twin find "
"than ESM2 on actual unknowns*."
)
gr.Image(
value="data/benchmark/crispr_unknowns_distribution.png",
label=None, show_label=False, container=False,
)
gallery_method_radio = gr.Radio(
choices=["genomenet-twin (MF)", "ESM2 baseline"],
value="genomenet-twin (MF)",
label="Gallery view method",
info="Switches the 3-D UMAP overview and the verdict column in the table below.",
)
gallery_plot = gr.Plot(label="3-D UMAP — reference (by family) + unknowns (by verdict)")
gallery_table = gr.HTML(
"<div style='padding:12px;color:#6b7280;'>Click <b>load gallery</b> to render the pre-computed unknown-protein overview.</div>"
)
gallery_load_btn = gr.Button("load gallery", variant="secondary")
def _make_gallery_3d_plot(method_choice):
    """Build the 3-D overview: reference proteins plus the pre-computed unknowns.

    Reference proteins are drawn as one trace per family. Each unknown from
    ``gallery_data["records"]`` is drawn as a diamond coloured by verdict
    confidence and placed at a ``softmax(8 * cosine)``-weighted centroid of
    those of its five best hits that can be found in the reference (matched
    on the ``acc`` column). Unknowns with fewer than three such hits are
    skipped.

    Args:
        method_choice: Radio label; labels starting with ``"genomenet"``
            select Twin-MF, everything else the ESM2 baseline.

    Returns:
        A plotly ``Figure``; an empty one when ``get_crispr_umap_3d`` returns
        ``None`` for the method.
    """
    # Uses the module-level `go` (plotly.graph_objects) import.
    is_twin = method_choice.startswith("genomenet")
    method_key = "Twin-MF" if is_twin else "ESM2 baseline"
    mkey = "twin-mf" if is_twin else "esm2"

    xyz = get_crispr_umap_3d(method_key)
    if xyz is None:
        return go.Figure()
    _, ref_meta = get_crispr_reference(method_key)

    family_col = ref_meta["family"]
    fams = list(dict.fromkeys(family_col.tolist()))  # unique, first-seen order
    # Fixed 20-colour categorical palette, reused cyclically.
    palette = [
        "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b",
        "#e377c2", "#7f7f7f", "#bcbd22", "#17becf", "#aec7e8", "#ffbb78",
        "#98df8a", "#ff9896", "#c5b0d5", "#c49c94", "#f7b6d3", "#c7c7c7",
        "#dbdb8d", "#9edae5",
    ]
    fam_colour = {f: palette[i % len(palette)] for i, f in enumerate(fams)}

    traces = []
    for fam in fams:
        mask = (family_col == fam).values
        sub = ref_meta[mask]
        hover = [
            f"<b>{acc}</b><br>{name}<br>{organism}<br>family: {fam}"
            for acc, name, organism in zip(sub["acc"], sub["name"], sub["organism"])
        ]
        traces.append(go.Scatter3d(
            x=xyz[mask, 0], y=xyz[mask, 1], z=xyz[mask, 2],
            mode="markers",
            marker=dict(size=4, color=fam_colour[fam], line=dict(width=0.5, color="black")),
            name=fam, legendgroup="ref", legendgrouptitle_text="reference (families)",
            hovertext=hover, hoverinfo="text",
        ))

    # Unknowns: approximate 3-D position as weighted centroid of top-5 reference 3-D coords
    n_unknown_plotted = 0
    if gallery_data["records"]:
        # Accession -> row position in ref_meta / xyz, built once. The first
        # occurrence wins. Positions (not index labels) are what xyz needs.
        acc_to_row = {}
        for row, acc in enumerate(ref_meta["acc"].tolist()):
            acc_to_row.setdefault(acc, row)

        ux, uy, uz, htxt, ucol = [], [], [], [], []
        conf_colour = {"low": "#dc2626", "medium": "#d97706", "high": "#059669"}
        for r in gallery_data["records"]:
            result = r[mkey]
            # Keep each matched reference row together with its own cosine so
            # the weights stay aligned when one of the five ids is not found.
            matched = [
                (acc_to_row[t["id"]], t["cosine"])
                for t in result["top_10"][:5]
                if t["id"] in acc_to_row
            ]
            if len(matched) < 3:
                continue
            rows = [row for row, _ in matched]
            hit_sims = np.array([cosine for _, cosine in matched], dtype=float)
            w = np.exp((hit_sims - hit_sims.max()) * 8.0)
            w /= w.sum()
            pos = (xyz[rows] * w[:, None]).sum(axis=0)
            ux.append(pos[0])
            uy.append(pos[1])
            uz.append(pos[2])
            verdict_short = result["verdict"].split(";")[0][:80]
            htxt.append(f"<b>{r['acc']}</b> ({result['confidence'].upper()})<br>"
                        f"{r['name']}<br>{r['organism']}<br>"
                        f"<b>verdict:</b> {verdict_short}<br>"
                        f"top-1: {result['top_10'][0]['family']} (cos = {result['top_cosine']:+.3f})")
            # Grey fallback for an unexpected confidence value instead of a KeyError.
            ucol.append(conf_colour.get(result["confidence"], "#6b7280"))
        n_unknown_plotted = len(ux)
        traces.append(go.Scatter3d(
            x=ux, y=uy, z=uz, mode="markers",
            marker=dict(size=9, symbol="diamond", color=ucol,
                        line=dict(width=1.5, color="gold")),
            name="unknowns (by confidence)", legendgroup="unk",
            legendgrouptitle_text="unknowns (verdict confidence)",
            hovertext=htxt, hoverinfo="text",
        ))

    fig = go.Figure(data=traces)
    fig.update_layout(
        scene=dict(xaxis_title="UMAP-1", yaxis_title="UMAP-2", zaxis_title="UMAP-3"),
        height=600, margin=dict(l=0, r=0, t=30, b=0),
        # Counts come from the data actually plotted instead of fixed numbers.
        title=(f"3-D UMAP · reference ({len(ref_meta)}) + "
               f"unknowns ({n_unknown_plotted}) · {method_choice}"),
        legend=dict(itemsizing="constant", groupclick="toggleitem"),
    )
    return fig
def _make_gallery_table(method_choice):
    """Render the precomputed gallery as an HTML table sorted by top-1 cosine.

    Records are read from the enclosing ``gallery_data["records"]``. Every
    record field is HTML-escaped before interpolation, matching the escaping
    already done by ``_make_unknown_screen_table``.

    Args:
        method_choice: Label of the method radio. A label starting with
            ``"genomenet"`` selects each record's ``"twin-mf"`` result;
            anything else selects ``"esm2"``.

    Returns:
        An HTML string: a heading, one table row per record, and a footnote
        describing the source queries.

    Raises:
        KeyError: If a record lacks a field read here, or its confidence is
            not one of ``"high"``, ``"medium"``, ``"low"``.
    """
    mkey = "twin-mf" if method_choice.startswith("genomenet") else "esm2"
    # Built once instead of once per row.
    conf_colours = {"high": "#059669", "medium": "#d97706", "low": "#dc2626"}
    # Sort by top cosine descending
    recs_sorted = sorted(gallery_data["records"], key=lambda r: -r[mkey]["top_cosine"])
    rows = []
    for r in recs_sorted:
        res = r[mkey]
        conf = res["confidence"]
        col = conf_colours[conf]
        top1 = res["top_10"][0]
        # Record fields are data, not markup: escape them so characters such as
        # '<', '&' or quotes cannot break the table or the href attribute.
        href = html.escape(str(r["uniprot"]))
        acc = html.escape(str(r["acc"]))
        organism = html.escape(str(r["organism"])[:42])
        length = html.escape(str(r["length"]))
        source = html.escape(str(r["source"]))
        verdict_short = html.escape(str(res["verdict"])[:90])
        top_family = html.escape(str(top1["family"]))
        rows.append(
            f"<tr>"
            f"<td style='padding:6px 10px;font-family:ui-monospace,monospace;white-space:nowrap;'>"
            f"<a href='{href}' target='_blank' style='color:#2563eb;text-decoration:none;'>{acc}</a></td>"
            f"<td style='padding:6px 10px;font-size:12px;color:#374151;'>{organism}</td>"
            f"<td style='padding:6px 10px;font-size:12px;color:#6b7280;'>{length}</td>"
            f"<td style='padding:6px 10px;font-size:12px;color:#6b7280;'>{source}</td>"
            f"<td style='padding:6px 10px;'><span style='background:{col};color:white;"
            f"padding:2px 6px;border-radius:3px;font-size:11px;'>{conf.upper()}</span></td>"
            f"<td style='padding:6px 10px;font-size:12px;'>{verdict_short}</td>"
            f"<td style='padding:6px 10px;font-family:ui-monospace,monospace;font-size:12px;'>"
            f"{top_family} · {res['top_cosine']:+.3f}</td>"
            f"</tr>"
        )
    return (
        f"<h4 style='margin:12px 0 6px 0;'>Full gallery (sorted by top-1 cosine, "
        f"{html.escape(str(method_choice))})</h4>"
        f"<table style='width:100%;border-collapse:collapse;font-size:13px;'>"
        f"<thead><tr style='text-align:left;border-bottom:1px solid #e5e7eb;color:#6b7280;font-size:12px;'>"
        f"<th style='padding:6px 10px;'>UniProt</th>"
        f"<th style='padding:6px 10px;'>organism</th>"
        f"<th style='padding:6px 10px;'>aa</th>"
        f"<th style='padding:6px 10px;'>source</th>"
        f"<th style='padding:6px 10px;'>confidence</th>"
        f"<th style='padding:6px 10px;'>verdict</th>"
        f"<th style='padding:6px 10px;'>top hit · cosine</th>"
        f"</tr></thead>"
        f"<tbody>{''.join(rows)}</tbody></table>"
        f"<p style='font-size:11px;color:#6b7280;margin-top:8px;'>"
        f"Source queries used to fetch the unknowns: "
        f"phage hypotheticals ≤200 aa, phage hypotheticals 200–500 aa, "
        f"Streptococcus-phage and Pseudomonas-phage hypotheticals via "
        f"<code>virus_host_id</code> on UniProt, plus two bacterial-hypothetical "
        f"control sets from bacteria with experimental evidence (E. coli, B. subtilis, "
        f"M. tuberculosis). See "
        f"<code>scripts/analyses/crispr/precompute_unknowns.py</code>."
        f"</p>"
    )
# Changing the method radio re-renders both gallery outputs (3-D plot and table)
# for the newly selected method.
gallery_method_radio.change(
fn=lambda m: (_make_gallery_3d_plot(m), _make_gallery_table(m)),
inputs=[gallery_method_radio],
outputs=[gallery_plot, gallery_table],
show_progress="minimal",
)
# The load button runs the same callback with the currently selected method.
gallery_load_btn.click(
fn=lambda m: (_make_gallery_3d_plot(m), _make_gallery_table(m)),
inputs=[gallery_method_radio],
outputs=[gallery_plot, gallery_table],
show_progress="minimal",
)
# --- Section 2: bring-your-own-sequence ----------------------------
# Two-column layout: inputs (sequence, method, button) in the first column,
# outputs (verdict markdown, placement plot, hit table) in the second.
gr.Markdown("#### 2. Bring your own sequence")
with gr.Row():
with gr.Column(scale=1, min_width=320):
# Free-text query; pre-filled with EXAMPLE_PROTEIN.
char_seq_input = gr.Textbox(
label="Protein sequence",
placeholder="Paste protein sequence (amino acids)...",
lines=6,
value=EXAMPLE_PROTEIN,
info="FASTA or raw sequence; > 1022 aa is truncated",
)
# The handler branches on whether the chosen label starts with "genomenet-twin".
char_method_radio = gr.Radio(
choices=["genomenet-twin (MF)", "ESM2 baseline"],
value="genomenet-twin (MF)",
label="Method",
info="Twin-MF is the recommended aspect for CRISPR classification.",
)
char_btn = gr.Button("characterize", variant="primary")
with gr.Column(scale=2, min_width=400):
# The three outputs written by the characterize click chain.
char_verdict = gr.Markdown()
char_plot = gr.Plot(label="Query placement on CRISPR UMAP")
char_hits = gr.HTML()
def _characterize_handler(sequence, method_choice):
    """Characterize one query protein against the CRISPR reference set.

    The query is embedded with Twin-MF or ESM2, L2-normalised, and ranked
    against the reference embeddings by dot product. The ten best hits feed
    ``_crispr_verdict`` and an HTML table; when a 2-D UMAP is available the
    query is drawn at the softmax-weighted centroid of its five best hits.

    Args:
        sequence: Raw or FASTA-formatted protein sequence. ``None`` is treated
            as an empty string so validation reports it instead of crashing.
        method_choice: Method radio label. A label starting with
            ``"genomenet-twin"`` selects Twin-MF (threshold 0.70); anything
            else selects the ESM2 baseline (threshold 0.90).

    Returns:
        A 3-tuple ``(verdict_markdown, figure_or_None, hits_table_html)``.
        On validation failure: ``("**Error**: ...", None, "")``.
    """
    sequence = strip_fasta_header((sequence or "").strip())
    valid, err = validate_protein(sequence)
    if not valid:
        return (f"**Error**: {err}", None, "")
    trunc_note = (f"> Query truncated from {len(sequence)} to {ESM2_MAX_LEN} aa "
                  f"(ESM2 limit).\n\n" if len(sequence) > ESM2_MAX_LEN else "")
    if method_choice.startswith("genomenet-twin"):
        query_emb = embed_twin(sequence, aspect="MF")
        method_key = "Twin-MF"
        threshold = 0.70
        method_label = "genomenet-twin (MF)"
    else:
        query_emb = embed_esm2(sequence)
        method_key = "ESM2 baseline"
        threshold = 0.90
        method_label = "ESM2 baseline"

    # Rank the reference
    ref_emb, ref_meta = get_crispr_reference(method_key)
    q = query_emb / (np.linalg.norm(query_emb) + 1e-9)  # epsilon guards a zero vector
    sims = ref_emb @ q
    order = np.argsort(-sims)
    top = order[:10]
    top_rows = []
    for rank, idx in enumerate(top, 1):
        m = ref_meta.iloc[idx]
        top_rows.append({
            "rank": rank, "id": m["acc"], "cosine": float(sims[idx]),
            "family": m["family"], "group": m["group"],
            "organism": m["organism"], "name": m["name"],
        })
    top_cos = top_rows[0]["cosine"]
    label, conf = _crispr_verdict(top_rows, top_cos, threshold)
    conf_colour = {"high": "#059669", "medium": "#d97706", "low": "#dc2626"}.get(conf, "#6b7280")
    verdict_md = (
        f"{trunc_note}"
        f"### Verdict ({method_label})\n"
        f"<div style='border-left:4px solid {conf_colour};padding:10px 14px;"
        f"background:#fafafa;border-radius:4px;margin:6px 0;'>"
        f"<div style='font-size:15px;'><b>{label}</b></div>"
        f"<div style='font-size:12px;color:#6b7280;margin-top:4px;'>"
        f"confidence: <b style='color:{conf_colour}'>{conf.upper()}</b> · "
        f"top-1 cosine = {top_cos:+.3f} · threshold = {threshold:.2f}"
        f"</div></div>"
    )

    # UMAP scatter (only available for ESM2 baseline and Twin-MF)
    xy_ref = get_crispr_umap(method_key)
    fig = None
    if xy_ref is not None:
        from matplotlib.figure import Figure
        from matplotlib.lines import Line2D

        # Weighted centroid of top-5 neighbours
        nearest = top[:5]
        w = np.exp((sims[nearest] - sims[nearest].max()) * 8.0)
        w /= w.sum()
        q_xy = (xy_ref[nearest] * w[:, None]).sum(axis=0)

        all_fams = list(dict.fromkeys(ref_meta["family"].tolist()))
        cmap = plt.get_cmap("tab20", len(all_fams))
        fam_colour = {f: cmap(i) for i, f in enumerate(all_fams)}
        group_marker = {"cas": "o", "acr": "^"}

        # Build the figure without pyplot's global registry: a pyplot figure
        # created per request is never closed, and plt.tight_layout() acts on
        # whichever figure is "current", which may belong to another request.
        fig = Figure(figsize=(9, 6.5))
        ax = fig.subplots()
        # Index xy_ref by position, consistent with the iloc/argsort indexing above.
        for pos, (_, r) in enumerate(ref_meta.iterrows()):
            ax.scatter(xy_ref[pos, 0], xy_ref[pos, 1],
                       color=fam_colour[r["family"]],
                       marker=group_marker.get(r["group"], "o"),
                       s=55, edgecolor="black", linewidth=0.4, alpha=0.85, zorder=2)
        # Lines to top-3 neighbours
        for i in top[:3]:
            ax.plot([q_xy[0], xy_ref[i, 0]], [q_xy[1], xy_ref[i, 1]],
                    color="#444", lw=0.6, linestyle="--", alpha=0.6, zorder=3)
        # Query star
        ax.scatter(q_xy[0], q_xy[1], marker="*", s=320, color="#111",
                   edgecolor="gold", linewidth=2.0, zorder=10)
        handles = [Line2D([0], [0], marker="o", color="w", markerfacecolor=fam_colour[f],
                          markeredgecolor="black", markersize=8, label=f) for f in all_fams]
        handles.append(Line2D([0], [0], marker="*", color="w", markerfacecolor="#111",
                              markeredgecolor="gold", markersize=14, label="query"))
        ax.legend(handles=handles, loc="center left", bbox_to_anchor=(1.02, 0.5),
                  fontsize=7.5, frameon=False)
        ax.set_xlabel("UMAP-1")
        ax.set_ylabel("UMAP-2")
        ax.set_title(f"Query placed at softmax-weighted top-5 centroid ({method_label})")
        ax.grid(alpha=0.25, linestyle=":")
        fig.tight_layout()

    # Top-10 HTML table
    all_cos = [r["cosine"] for r in top_rows]
    lo, hi = min(all_cos), max(all_cos)

    def _row(r):
        # Marker position: cosine rescaled to 0-100 % within the top-10 range.
        pct = max(0.0, min(100.0, (r["cosine"] - lo) / (hi - lo + 1e-9) * 100.0))
        # Reference metadata is data, not markup: escape before interpolating.
        rid = html.escape(str(r["id"]))
        family = html.escape(str(r["family"]))
        organism = html.escape(str(r["organism"])[:50])
        link = (f"<a href='https://www.uniprot.org/uniprotkb/{rid}' "
                f"target='_blank' style='color:#2563eb;text-decoration:none;'>↗</a>")
        bar = (f"<div style='display:flex;align-items:center;gap:10px;'>"
               f"<div style='position:relative;width:160px;height:10px;"
               f"background:linear-gradient(to right,#f87171,#facc15,#4ade80);border-radius:2px;'>"
               f"<div style='position:absolute;left:{pct:.1f}%;top:-3px;width:2px;height:16px;"
               f"background:#111;box-shadow:0 0 0 2px #fff;transform:translateX(-50%);'></div>"
               f"</div>"
               f"<span style='font-family:ui-monospace,monospace;font-size:12px;min-width:54px;'>{r['cosine']:+.4f}</span>"
               f"</div>")
        return (f"<tr><td style='padding:5px 10px;color:#555;'>{r['rank']}</td>"
                f"<td style='padding:5px 10px;font-family:ui-monospace,monospace;'>{rid}</td>"
                f"<td style='padding:5px 10px;'>{bar}</td>"
                f"<td style='padding:5px 10px;font-weight:600;'>{family}</td>"
                f"<td style='padding:5px 10px;font-size:12px;color:#374151;'>{organism}</td>"
                f"<td style='padding:5px 10px;'>{link}</td></tr>")

    table = (f"<h4 style='margin:12px 0 6px 0;'>Top-10 nearest reference proteins</h4>"
             f"<table style='width:100%;border-collapse:collapse;font-size:13px;'>"
             f"<thead><tr style='text-align:left;border-bottom:1px solid #e5e7eb;color:#6b7280;font-size:12px;'>"
             f"<th style='padding:5px 10px;'>#</th><th style='padding:5px 10px;'>ID</th>"
             f"<th style='padding:5px 10px;'>cosine</th><th style='padding:5px 10px;'>family</th>"
             f"<th style='padding:5px 10px;'>organism</th><th style='padding:5px 10px;'></th></tr></thead>"
             f"<tbody>{''.join(_row(r) for r in top_rows)}</tbody></table>")
    return verdict_md, fig, table
# Two-step click chain: first write a placeholder message and clear the plot and
# table, then run the handler, which fills the same three outputs.
char_btn.click(
lambda m: (f"<div style='padding:16px;color:#666;'>⏳ Characterizing against CRISPR reference with {m}…</div>", None, ""),
inputs=[char_method_radio],
outputs=[char_verdict, char_plot, char_hits],
show_progress="hidden",
).then(
_characterize_handler,
inputs=[char_seq_input, char_method_radio],
outputs=[char_verdict, char_plot, char_hits],
api_name="characterize",
show_progress="minimal",
)
# Example rows populate the sequence box and method radio; each entry in
# example_labels corresponds positionally to one entry in examples.
gr.Examples(
examples=[
[_SPCAS9, "genomenet-twin (MF)"],
[_FNCAS12A, "genomenet-twin (MF)"],
[_LSHCAS13A, "genomenet-twin (MF)"],
[_CAS1_ECOLI, "genomenet-twin (MF)"],
[_ACRIIA4, "genomenet-twin (MF)"],
[_SPCAS9, "ESM2 baseline"],
[_ACRIIA4, "ESM2 baseline"],
],
inputs=[char_seq_input, char_method_radio],
example_labels=[
"SpCas9 — expected: Cas9 (Type II-A effector)",
"FnCas12a — expected: Cas12 (Type V-A effector)",
"LshCas13a — expected: Cas13 (Type VI-A, RNA-targeting)",
"Cas1 (E. coli) — expected: Cas1 (adaptation)",
"AcrIIA4 — expected: AcrIIA family (anti-Cas9)",
"SpCas9 — same query, ESM2 method (compare)",
"AcrIIA4 — same query, ESM2 method (compare)",
],
label="Example queries — each known protein; expected verdict is in the label",
)
with gr.Tab("Benchmark"):
gr.Markdown("""
## CRISPR protein benchmark
**Question.** Does the GO-supervised `genomenet-twin` embedding recover useful
CRISPR protein structure beyond the pretrained ESM2 representation?
This tab evaluates only **protein sequences**. It does not use the CRISPR array
detector and does not run DIAMOND, MMseqs, BLAST, or any other sequence-search
baseline. The benchmark asks whether cosine geometry in each embedding space
recovers known CRISPR protein labels.
The first UniProt mixed benchmark is now superseded: broad text queries pulled
known contaminants, especially acriflavin-resistance proteins from `AcrIF*`, and
a few Cas entries were mislabeled. The current benchmark therefore separates the
problem into clean reference views and then recombines them for the main figure.
### Reference sets
| reference view | proteins | source and purpose |
|---|---:|---|
| **Mixed Cas + Acr** | **254** | Main overview. Concatenates 175 Pfam-verified Cas proteins with 79 Anti-CRISPRdb v3 Acr proteins. |
| **Acr-only** | **79** | Anti-CRISPR family retrieval. Uses typed Anti-CRISPRdb v3 records plus local curated AcrIB examples. |
| **Cas-only** | **175** | Within-Cas subfamily resolution. Uses UniProt Pfam cross-references for Cas1-Cas12. |
### How to read the metrics
| metric | interpretation |
|---|---|
| **within − between** | Mean within-family cosine minus between-family cosine. High values mean block structure is strong, but this is not by itself a retrieval metric. |
| **5-NN recall** | For each protein, fraction of its nearest neighbours that share the family label, adjusted for small families. This is the practical "find more proteins like this" metric. |
| **LOO top-1** | Leave-one-out nearest-neighbour family assignment accuracy. This asks whether the closest other protein has the correct family. |
| **LOO MRR** | Leave-one-out mean reciprocal rank of the first correct-family neighbour. Higher means correct-family proteins appear earlier in the ranked list. |
| **AUC family** | Pairwise same-family versus different-family discrimination from cosine similarity. |
| **AUC group** | Pairwise broader-group discrimination. In the mixed reference this is Cas versus Acr; in the Acr-only reference it is inhibited CRISPR class. |
| **silhouette** | Scale-normalized cluster quality. Positive values mean proteins are closer to their own family than to neighbouring families. |
The important point is that these metrics answer different biological questions.
Nearest-neighbour metrics evaluate annotation and retrieval. Pairwise AUCs test
global ranking of same-label pairs. Silhouette tests cluster geometry. A single
"winner" is therefore less informative than the pattern across metrics.
""")
gr.Markdown("""
---
## Figure 1. Mixed Cas + anti-CRISPR benchmark
The mixed reference combines Pfam-verified Cas proteins and Anti-CRISPRdb v3 Acr
proteins without re-embedding. This is the closest view to a general CRISPR
protein benchmark: the model must separate Cas from Acr while also preserving
family-level structure inside each group.
""")
gr.Image(value="data/benchmark/crispr_umap_esm2_vs_twin.png", label="Mixed reference projection", show_label=True, container=False)
gr.Markdown("""
**Figure 1a. Two-dimensional projection of the mixed reference.** Each point is
one protein and colours encode family labels. The projection is a visualization
aid, not the primary statistic: distances are distorted by dimensionality
reduction. The useful readout is whether same-label proteins form coherent
neighbourhoods and whether Cas and Acr regions are visibly separated.
""")
gr.Image(value="data/benchmark/crispr_esm2_vs_twin_heatmap.png", label="Mixed reference heatmap", show_label=True, container=False)
gr.Markdown("""
**Figure 1b. All-by-all cosine similarity heatmap.** Rows and columns are
ordered by family. A strong embedding produces diagonal blocks: high similarity
within a family and lower similarity between unrelated families. Twin-MF gives a
much stronger within-minus-between contrast than ESM2, but the retrieval tables
below are needed to decide whether that block structure helps annotation.
""")
gr.Image(value="data/benchmark/crispr_aspect_comparison.png", label="Mixed reference model comparison", show_label=True, container=False)
gr.Markdown("""
**Table 1. Quantitative performance on the mixed reference.**
| model | within − between | 5-NN recall | LOO top-1 | LOO MRR | AUC family | AUC Cas-vs-Acr | silhouette |
|---|---:|---:|---:|---:|---:|---:|---:|
| ESM2 baseline | 0.038 | 0.608 | 0.724 | 0.795 | 0.823 | 0.700 | **0.131** |
| Twin-BP | 0.169 | **0.618** | 0.697 | 0.771 | 0.829 | 0.689 | 0.082 |
| Twin-CC | **0.404** | 0.480 | 0.618 | 0.713 | 0.784 | 0.554 | 0.016 |
| Twin-MF | 0.321 | 0.604 | **0.748** | **0.808** | **0.843** | **0.711** | 0.084 |
**Interpretation.** Twin-MF is the strongest aspect for family assignment by
leave-one-out nearest-neighbour tests and for pairwise family or Cas-vs-Acr
separation. Twin-BP is marginally best for 5-neighbour recall. ESM2 has the best
silhouette, meaning its clusters are more compact relative to their nearest
alternative cluster in this reference. This is a mixed result: Twin-MF improves
several annotation-oriented ranking metrics, but ESM2 remains competitive and
should stay as the baseline in any manuscript claim.
""")
gr.Image(value="data/benchmark/crispr_subcluster_comparison.png", label="Mixed reference family retrieval detail", show_label=True, container=False)
gr.Markdown("""
**Figure 1c. Per-family nearest-neighbour behaviour.** The lower panel shows
where the aggregate metrics come from. Twin-MF improves retrieval for several
Cas families, including Cas3 and Cas4, but loses for others such as Cas8 and
some Acr families. This argues against reporting only a global average: the
embedding improvement is family-dependent.
""")
gr.Image(value="data/benchmark/crispr_silhouette.png", label="Mixed reference silhouette", show_label=True, container=False)
gr.Markdown("""
**Figure 1d. Silhouette analysis.** Silhouette scores summarize whether each
protein is closer to its own family than to the nearest other family. ESM2 has
the highest mixed-reference mean silhouette, while Twin-MF has better
leave-one-out family assignment. In manuscript language, this means Twin-MF
improves ranking of the first correct family hit but does not create uniformly
cleaner clusters.
""")
gr.Image(value="data/benchmark/crispr_roc_esm2_vs_twin.png", label="Mixed reference ROC", show_label=True, container=False)
gr.Markdown("""
**Figure 1e. Pairwise ROC curves.** ROC curves evaluate all protein pairs rather
than query-level retrieval. Twin-MF has the best overall family AUC in Table 1,
whereas this ROC panel compares ESM2 and Twin-BP specifically. Pairwise AUC is
useful for global separation, but nearest-neighbour metrics are more directly
linked to annotation workflows.
""")
gr.Markdown("""
---
## Figure 2. Acr-only benchmark
The Acr-only reference asks a narrower question: can the embeddings recover
known anti-CRISPR family labels? This is harder than simple Cas-versus-Acr
separation because many Acr families are short, diverse, and defined by target
system rather than shared fold.
""")
gr.Image(value="data/benchmark/acrdb_v3/crispr_aspect_comparison.png", label="Acr-only model comparison", show_label=True, container=False)
gr.Markdown("""
**Figure 2a and Table 2. Anti-CRISPRdb v3 family benchmark.**
""")
gr.Image(value="data/benchmark/acrdb_v3/crispr_subcluster_comparison.png", label="Acr-only retrieval detail", show_label=True, container=False)
gr.Markdown("""
| model | within − between | 5-NN recall | LOO top-1 | LOO MRR | AUC family | AUC inhibited class | silhouette |
|---|---:|---:|---:|---:|---:|---:|---:|
| ESM2 baseline | 0.008 | **0.306** | **0.506** | **0.631** | 0.658 | 0.612 | **-0.022** |
| Twin-BP | 0.103 | 0.256 | 0.380 | 0.523 | 0.657 | 0.529 | -0.075 |
| Twin-CC | 0.157 | 0.294 | 0.418 | 0.557 | **0.708** | 0.577 | -0.040 |
| Twin-MF | **0.185** | 0.284 | 0.481 | 0.606 | 0.701 | **0.646** | -0.075 |
**Interpretation.** Acr-only retrieval currently favours ESM2. Twin-CC and
Twin-MF improve pairwise AUCs, but they do not improve the practical
nearest-neighbour task. This may reflect real biology: anti-CRISPR family names
often group proteins by inhibited CRISPR system rather than by a single
evolutionary or structural family. For Acr discovery, Twin scores may still be
useful as an additional signal, but the current benchmark does not justify
replacing ESM2 for Acr family lookup.
""")
gr.Markdown("""
---
## Figure 3. Cas-only benchmark
The Cas-only reference removes the easy Cas-versus-Acr split and tests
within-Cas subfamily resolution. Proteins were selected by Pfam cross-reference,
not by broad protein-name search, to reduce label noise.
""")
gr.Image(value="data/benchmark/cas_pfam/crispr_umap_esm2_vs_twin.png", label="Cas-only projection", show_label=True, container=False)
gr.Markdown("""
**Figure 3a. Cas-only two-dimensional projection.** The plot visualizes whether
Cas families occupy separate regions. It should be read together with Table 3,
because projection layout can exaggerate or hide neighbourhood structure.
""")
gr.Image(value="data/benchmark/cas_pfam/crispr_silhouette.png", label="Cas-only silhouette", show_label=True, container=False)
gr.Markdown("""
**Figure 3b. Cas-only silhouette.** ESM2 produces the best cluster compactness
on this reference, while Twin-MF performs best on top-hit family assignment.
This difference is important: compact global clusters and the identity of the
nearest correct family hit are related but not identical objectives.
""")
gr.Image(value="data/benchmark/cas_pfam/crispr_aspect_comparison.png", label="Cas-only model comparison", show_label=True, container=False)
gr.Markdown("""
**Table 3. Cas-only performance.**
| model | within − between | 5-NN recall | LOO top-1 | LOO MRR | AUC family | silhouette |
|---|---:|---:|---:|---:|---:|---:|
| ESM2 baseline | 0.033 | 0.763 | 0.834 | 0.887 | 0.798 | **0.221** |
| Twin-BP | 0.171 | **0.802** | 0.851 | 0.896 | **0.842** | 0.206 |
| Twin-CC | **0.407** | 0.610 | 0.749 | 0.824 | 0.791 | 0.093 |
| Twin-MF | 0.274 | 0.759 | **0.874** | **0.911** | 0.821 | 0.192 |
**Interpretation.** Twin-MF is the best model when the task is leave-one-out
Cas family assignment: the first correct-family neighbour appears earlier and
the nearest neighbour is correct more often. Twin-BP is best for 5-NN recall and
pairwise family AUC. ESM2 is best for silhouette. Twin-CC has high
within-minus-between cosine but poor retrieval, so it should not be used as the
default CRISPR aspect.
### Overall conclusion
For a manuscript, the defensible conclusion is nuanced:
- **Cas annotation:** Twin-MF is useful for top-hit family assignment; Twin-BP is
useful for broader neighbour recall.
- **Acr annotation:** ESM2 remains the stronger nearest-neighbour baseline on
the current Anti-CRISPRdb v3 reference.
- **Mixed CRISPR retrieval:** Twin-MF improves several ranking metrics, but ESM2
has better silhouette. Report both rather than claiming a universal win.
- **Aspect choice:** CC is consistently weak for retrieval; BP and MF are the
only plausible Twin aspects for CRISPR proteins.
The practical recommendation is to expose both ESM2 and Twin-MF/BP in the tool:
ESM2 as the conservative sequence-representation baseline, Twin-MF for Cas-like
top-hit annotation, and Twin-BP when broader same-family neighbour retrieval is
the goal.
""")
gr.Markdown("""
---
## Figure 4. Applying the models to unknown proteins
The benchmark above uses proteins with known Cas or Acr labels. As a more
realistic discovery-style check, we also screen uncharacterised proteins from a
larger phage/bacterial/archaeal candidate FASTA and score each query against the
same mixed Cas/Acr reference. This is not a validated discovery set; it is a
triage view that asks which unknown proteins land closest to known CRISPR
families under ESM2 versus Twin-MF. Because these proteins are unlabeled, the
screen does **not** decide which embedding is correct. The distribution plot
shows all scores, but the candidate list below reports only high-confidence
hits where the nearest-neighbour evidence is strongest.
""")
# Load the precomputed unknown-protein screen. The result is read with .get(),
# so missing "n_queries"/"records" keys fall back to 0 and an empty list.
unknown_data = get_unknown_results()
n_unknown = unknown_data.get("n_queries", 0)
unknown_records = unknown_data.get("records", [])
if n_unknown > 0:
# Returns three counts over unknown_records for one method key:
# verdicts starting with "Predicted", confidences other than "low",
# and confidences equal to "high".
def _unknown_count_by_verdict(method_key):
pred = sum(1 for r in unknown_records if str(r[method_key]["verdict"]).startswith("Predicted"))
signal = sum(1 for r in unknown_records if r[method_key]["confidence"] != "low")
high = sum(1 for r in unknown_records if r[method_key]["confidence"] == "high")
return pred, signal, high
pred_mf, signal_mf, high_mf = _unknown_count_by_verdict("twin-mf")
pred_es, signal_es, high_es = _unknown_count_by_verdict("esm2")
# Summary table comparing the two methods; all denominators are n_unknown.
gr.Markdown(
f"**Precomputed screen.** {n_unknown} uncharacterised proteins were embedded "
f"with ESM2 and Twin-MF, then ranked against the mixed 254-protein CRISPR "
f"reference.\n\n"
f"| method | family-level verdict | any above-threshold CRISPR signal | high-confidence calls |\n"
f"|---|---:|---:|---:|\n"
f"| Twin-MF | **{pred_mf}** / {n_unknown} | **{signal_mf}** / {n_unknown} | **{high_mf}** / {n_unknown} |\n"
f"| ESM2 baseline | {pred_es} / {n_unknown} | {signal_es} / {n_unknown} | {high_es} / {n_unknown} |\n\n"
f"The first two columns are diagnostic counts. Only the high-confidence "
f"candidates are listed in the table, with protein name and sequence, because "
f"medium-confidence calls are useful for exploration but too weak to present "
f"as report-level candidates. High cosine and agreement among top neighbours "
f"make a protein worth follow-up; they do not prove CRISPR function."
)
else:
# No queries packaged: show the command that rebuilds the screen instead.
gr.Markdown(
"**Precomputed screen unavailable.** Rebuild it with "
"`scripts/analyses/crispr/precompute_unknowns.py --ref_dir data/crispr_reference_mixed_v3`."
)
# Static score-distribution image; created whether or not records were loaded.
gr.Image(
value="data/benchmark/crispr_unknowns_distribution.png",
label="Unknown-protein screen score distribution",
show_label=True,
container=False,
)
with gr.Row():
# Method selector and load button drive the plot and table declared below.
unknown_method_radio = gr.Radio(
choices=["genomenet-twin (MF)", "ESM2 baseline"],
value="genomenet-twin (MF)",
label="Unknown-screen view",
info="Controls the candidate table and 3-D placement overview.",
)
unknown_load_btn = gr.Button("load unknown screen", variant="secondary")
# Output components; the table starts with a hint until the screen is loaded.
unknown_plot = gr.Plot(label="3-D reference projection + unknown-protein candidates")
unknown_table = gr.HTML(
"<div style='padding:12px;color:#6b7280;'>Click <b>load unknown screen</b> to render the high-confidence candidate list.</div>"
)
def _make_unknown_screen_plot(method_choice):
    """Build the 3-D plot of reference proteins plus high-confidence unknowns.

    Reference proteins are drawn one trace per family. Each high-confidence
    unknown is placed at the softmax-weighted centroid of the 3-D coordinates
    of those of its top-5 hits that can be matched to a reference row; unknowns
    with fewer than three matched hits are left out.

    Args:
        method_choice: Method radio label. A label starting with
            ``"genomenet"`` selects Twin-MF; anything else selects ESM2.

    Returns:
        A Plotly figure. It is empty when no 3-D projection is available for
        the method or when no unknown records are loaded.
    """
    method_key = "Twin-MF" if method_choice.startswith("genomenet") else "ESM2 baseline"
    result_key = "twin-mf" if method_key == "Twin-MF" else "esm2"
    xyz = get_crispr_umap_3d(method_key)
    if xyz is None or not unknown_records:
        return go.Figure()
    _, ref_meta = get_crispr_reference(method_key)
    fams = list(dict.fromkeys(ref_meta["family"].tolist()))
    palette = [
        "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b",
        "#e377c2", "#7f7f7f", "#bcbd22", "#17becf", "#aec7e8", "#ffbb78",
        "#98df8a", "#ff9896", "#c5b0d5", "#c49c94", "#f7b6d3", "#c7c7c7",
        "#dbdb8d", "#9edae5",
    ]
    # Colours repeat once there are more families than palette entries.
    fam_colour = {f: palette[i % len(palette)] for i, f in enumerate(fams)}

    traces = []
    for fam in fams:
        mask = (ref_meta["family"] == fam).values
        sub = ref_meta[mask]
        hover = [
            f"<b>{html.escape(_crispr_display_id(r))}</b><br>"
            f"{html.escape(_cell_text(r.get('name', '')))}<br>"
            f"{html.escape(_cell_text(r.get('organism', '')))}<br>"
            f"family: {html.escape(str(fam))}"
            for _, r in sub.iterrows()
        ]
        traces.append(go.Scatter3d(
            x=xyz[mask, 0], y=xyz[mask, 1], z=xyz[mask, 2],
            mode="markers",
            marker=dict(size=4, color=fam_colour[fam], line=dict(width=0.5, color="black")),
            name=fam, legendgroup="ref", legendgrouptitle_text="reference families",
            hovertext=hover, hoverinfo="text",
        ))

    ranked_records = [
        r for r in sorted(unknown_records, key=lambda r: -r[result_key]["top_cosine"])
        if r[result_key]["confidence"] == "high"
    ]
    ux, uy, uz, htxt, ucol = [], [], [], [], []
    conf_colour = {"low": "#dc2626", "medium": "#d97706", "high": "#059669"}
    ref_ids = ref_meta["acc"].astype(str).tolist()
    ref_id_to_idx = {rid: i for i, rid in enumerate(ref_ids)}
    # Both lookups must yield row *positions*, because they index xyz directly.
    # Enumerating iterrows() keeps this map positional even if ref_meta does
    # not carry a default 0..n-1 index.
    display_to_idx = {
        _crispr_display_id(row): pos
        for pos, (_, row) in enumerate(ref_meta.iterrows())
    }
    for r in ranked_records:
        result = r[result_key]
        idxs, top_sims = [], []
        for t in result["top_10"][:5]:
            tid = str(t["id"])
            # Try the accession first, then the display id.
            idx = ref_id_to_idx.get(tid, display_to_idx.get(tid))
            if idx is not None:
                idxs.append(idx)
                top_sims.append(float(t["cosine"]))
        if len(idxs) < 3:
            continue
        top_sims = np.array(top_sims, dtype=np.float32)
        # Softmax over cosines (scale 8.0), shifted by the max for stability.
        w = np.exp((top_sims - top_sims.max()) * 8.0)
        w /= w.sum()
        pos = (xyz[idxs] * w[:, None]).sum(axis=0)
        ux.append(pos[0])
        uy.append(pos[1])
        uz.append(pos[2])
        top1 = result["top_10"][0]
        verdict_short = str(result["verdict"]).split(";")[0][:90]
        htxt.append(
            f"<b>{html.escape(str(r['acc']))}</b> ({result['confidence'].upper()})<br>"
            f"{html.escape(str(r.get('organism', '')))}<br>"
            f"<b>verdict:</b> {html.escape(verdict_short)}<br>"
            f"top-1: {html.escape(str(top1['family']))} (cos = {result['top_cosine']:+.3f})"
        )
        ucol.append(conf_colour[result["confidence"]])
    traces.append(go.Scatter3d(
        x=ux, y=uy, z=uz, mode="markers",
        marker=dict(size=6, symbol="diamond", color=ucol, line=dict(width=1.0, color="gold")),
        name=f"high-confidence unknowns ({len(ux)})",
        legendgroup="unknowns",
        legendgrouptitle_text="unknown candidates",
        hovertext=htxt, hoverinfo="text",
    ))
    fig = go.Figure(data=traces)
    fig.update_layout(
        scene=dict(xaxis_title="projection-1", yaxis_title="projection-2", zaxis_title="projection-3"),
        height=620, margin=dict(l=0, r=0, t=35, b=0),
        title=f"Mixed CRISPR reference + high-confidence unknown candidates · {method_choice}",
        legend=dict(itemsizing="constant", groupclick="toggleitem"),
    )
    return fig
def _make_unknown_screen_table(method_choice):
    """Render the high-confidence unknown-protein candidates as an HTML table.

    Only records whose confidence for the selected method is ``"high"`` are
    listed, in descending order of top-1 cosine. Each row can expand to show
    the protein sequence when one is found for the accession.

    Args:
        method_choice: Method radio label. A label starting with
            ``"genomenet"`` selects each record's ``"twin-mf"`` result;
            anything else selects ``"esm2"``.

    Returns:
        An HTML string: the candidate table, or a short notice when no records
        are loaded or none reach high confidence.
    """
    if not unknown_records:
        return "<div style='padding:12px;color:#6b7280;'>No unknown-screen records packaged.</div>"
    result_key = "twin-mf" if method_choice.startswith("genomenet") else "esm2"
    method_html = html.escape(str(method_choice))
    recs_sorted = [
        r for r in sorted(unknown_records, key=lambda r: -r[result_key]["top_cosine"])
        if r[result_key]["confidence"] == "high"
    ]
    if not recs_sorted:
        return (
            f"<div style='padding:12px;color:#6b7280;'>No high-confidence unknown-protein "
            f"candidates for {method_html} under the current thresholds. "
            f"Use the distribution plot above to compare lower-confidence score behaviour.</div>"
        )
    seqs = get_unknown_sequences()
    # Built once instead of once per row.
    conf_colours = {"high": "#059669", "medium": "#d97706", "low": "#dc2626"}
    rows = []
    for r in recs_sorted:
        res = r[result_key]
        conf = res["confidence"]
        col = conf_colours[conf]
        top1 = res["top_10"][0]
        # Every record field is escaped, including the link target and length,
        # so no value can break out of its attribute or table cell.
        verdict_short = html.escape(str(res["verdict"])[:110])
        acc = html.escape(str(r["acc"]))
        href = html.escape(str(r["uniprot"]))
        name = html.escape(str(r.get("name", ""))[:80])
        org = html.escape(str(r.get("organism", ""))[:48])
        length = html.escape(str(r["length"]))
        source = html.escape(str(r.get("source", "")))
        top_family = html.escape(str(top1["family"]))
        seq = html.escape(_wrap_sequence_for_html(seqs.get(str(r["acc"]), ""), width=80))
        # The collapsible sequence block is omitted when no sequence is found.
        sequence_block = (
            f"<details><summary style='cursor:pointer;color:#2563eb;'>sequence</summary>"
            f"<pre style='white-space:pre-wrap;word-break:break-word;font-size:11px;"
            f"line-height:1.35;background:#f9fafb;border:1px solid #e5e7eb;"
            f"border-radius:4px;padding:8px;margin:6px 0 0 0;'>{seq}</pre></details>"
            if seq else ""
        )
        rows.append(
            f"<tr>"
            f"<td style='padding:6px 10px;font-family:ui-monospace,monospace;white-space:nowrap;'>"
            f"<a href='{href}' target='_blank' style='color:#2563eb;text-decoration:none;'>{acc}</a></td>"
            f"<td style='padding:6px 10px;font-size:12px;color:#374151;'>{name}{sequence_block}</td>"
            f"<td style='padding:6px 10px;font-size:12px;color:#374151;'>{org}</td>"
            f"<td style='padding:6px 10px;font-size:12px;color:#6b7280;'>{length}</td>"
            f"<td style='padding:6px 10px;font-size:12px;color:#6b7280;'>{source}</td>"
            f"<td style='padding:6px 10px;'><span style='background:{col};color:white;"
            f"padding:2px 6px;border-radius:3px;font-size:11px;'>{conf.upper()}</span></td>"
            f"<td style='padding:6px 10px;font-size:12px;'>{verdict_short}</td>"
            f"<td style='padding:6px 10px;font-family:ui-monospace,monospace;font-size:12px;'>"
            f"{top_family} · {res['top_cosine']:+.3f}</td>"
            f"</tr>"
        )
    return (
        f"<h4 style='margin:12px 0 6px 0;'>High-confidence unknown-protein candidates "
        f"({len(recs_sorted)} of {len(unknown_records)} screened; {method_html}, sorted by top-1 cosine)</h4>"
        f"<table style='width:100%;border-collapse:collapse;font-size:13px;'>"
        f"<thead><tr style='text-align:left;border-bottom:1px solid #e5e7eb;color:#6b7280;font-size:12px;'>"
        f"<th style='padding:6px 10px;'>UniProt</th>"
        f"<th style='padding:6px 10px;'>name and sequence</th>"
        f"<th style='padding:6px 10px;'>organism</th>"
        f"<th style='padding:6px 10px;'>aa</th>"
        f"<th style='padding:6px 10px;'>source</th>"
        f"<th style='padding:6px 10px;'>confidence</th>"
        f"<th style='padding:6px 10px;'>verdict</th>"
        f"<th style='padding:6px 10px;'>top hit · cosine</th>"
        f"</tr></thead><tbody>{''.join(rows)}</tbody></table>"
    )
# Changing the view radio re-renders both the 3-D plot and the candidate table
# for the newly selected method.
unknown_method_radio.change(
fn=lambda m: (_make_unknown_screen_plot(m), _make_unknown_screen_table(m)),
inputs=[unknown_method_radio],
outputs=[unknown_plot, unknown_table],
show_progress="minimal",
)
# The load button runs the same callback with the currently selected method.
unknown_load_btn.click(
fn=lambda m: (_make_unknown_screen_plot(m), _make_unknown_screen_table(m)),
inputs=[unknown_method_radio],
outputs=[unknown_plot, unknown_table],
show_progress="minimal",
)
# "API" tab: static reference text showing how to call the Space through
# gradio_client (the /compare and /distance endpoints), plus short notes on
# the nearest-neighbor indexes and the two embedding models.
with gr.Tab("API"):
    # The Markdown lines stay at column 0 on purpose: indenting them would
    # change the rendered text.
    _api_reference_md = """
### API
```python
from gradio_client import Client
import numpy as np
client = Client("genomenet/functional-distance")
result = client.predict(
sequence="MALWMRLLPLLALLALWG...", # protein sequence
top_k=10,
twin_aspect="BP", # "BP" | "CC" | "MF"
api_name="/compare"
)
summary, esm2_path, twin_path, *plots, hits = result
esm2_emb = np.load(esm2_path) # (1280,)
twin_emb = np.load(twin_path) # (1024,)
# hits: DataFrame with columns [rank, uniref50_id, cosine, uniprot]
# Twin pairwise distance (the model's native trained task)
distance_md = client.predict(
seq_a="MALWMRLLPLLALLALWG...",
seq_b="MGKISSLPTQLFKCCFCDFL...",
aspect="BP", # "BP" | "CC" | "MF"
api_name="/distance",
)
```
### Nearest-neighbor search
The lookup tab can query FAISS indexes over the GO-annotated UniRef50 subset:
ESM2 uses 1280-dim ESM2 embeddings, and `genomenet-twin (BP)` uses 1024-dim
Twin-BP embeddings. Both indexes use cosine similarity on L2-normalized vectors.
Top-k hits link to the corresponding UniRef50 cluster on UniProt.
### Models
| Model | Dimension | Description |
|-------|-----------|-------------|
| ESM2 | 1280 | `esm2_t33_650M_UR50D` pretrained on UniRef50 |
| Twin | 1024 | Resnik-contrastive fine-tune; one checkpoint per GO aspect (BP/CC/MF) |
### Comparison
The Twin model is trained to produce embeddings where functionally similar proteins
(sharing GO terms) have similar embeddings. ESM2 is the pretrained baseline without
this GO supervision.
"""
    gr.Markdown(_api_reference_md)
# "About" tab: static background text describing the ESM2 baseline, the
# per-aspect Twin network, the three Gene Ontology aspects, and related links.
with gr.Tab("About"):
    # The Markdown lines stay at column 0 on purpose: indenting them would
    # change the rendered text.
    _about_md = """
### ESM2 vs Twin
**ESM2** (`esm2_t33_650M_UR50D`):
- 650M parameter protein language model
- Pretrained on UniRef50 with masked language modeling
- General-purpose protein representation
**Twin Network** (`train_point_{BP,CC,MF}_20251221_std_ft_bs32ga4`):
- Two-tower contrastive encoder: custom AA Transformer + **fine-tuned** ESM2 backbone
- **One checkpoint per GO aspect**: Biological Process (BP), Cellular Component (CC),
Molecular Function (MF). Pick aspect via the Twin GO aspect radio button.
- Trained on Resnik GO-semantic similarity within each aspect
- Output: `concat(custom_proj, esm_proj)` → 1024-dim; L2 distance on L2-normalized
embeddings ≈ functional distance in that aspect
### Gene Ontology
GO has three aspects:
- **MF** (Molecular Function): Biochemical activity
- **BP** (Biological Process): Biological objective
- **CC** (Cellular Component): Subcellular location
### Links
- ESM2: [github.com/facebookresearch/esm](https://github.com/facebookresearch/esm)
- BERT DNA: [genomenet/bert-embedding](https://huggingface.co/spaces/genomenet/bert-embedding)
"""
    gr.Markdown(_about_md)
if __name__ == "__main__":
    # Script entry point: report the device, warm up the models and the ESM2
    # FAISS index on a best-effort basis, then start the Gradio server.
    print("Functional Distance - ESM2 vs Twin")
    print(f"Device: {get_device()}")

    # Each warm-up has its own try/except so that one failed load is only
    # logged and the app still starts; per the messages below, loading is
    # attempted again on the first request.
    #
    # The loaders are called only for their side effect. Their results are
    # deliberately NOT bound to a name: a module-level `_ = get_twin(...)`
    # would hold that object for the whole process lifetime, which (assuming
    # the loader returns the loaded model — confirm against get_twin) would
    # stop an evicted default-aspect Twin from being freed when another
    # aspect is selected.
    print("Loading ESM2...")
    try:
        get_esm2()
        print("ESM2 ready!")
    except Exception as e:
        print(f"ESM2 load failed (will load on first request): {e}")

    print(f"Loading ESM2 FAISS index from {ESM2_FAISS_REPO_ID}...")
    try:
        get_faiss("esm2")
    except Exception as e:
        print(f"ESM2 FAISS load failed (will retry on first request): {e}")

    print(f"Loading default Twin aspect ({TWIN_DEFAULT_ASPECT}) from {TWIN_REPO_ID}...")
    try:
        get_twin(TWIN_DEFAULT_ASPECT)
        print(f"Twin/{TWIN_DEFAULT_ASPECT} ready!")
    except Exception as e:
        print(f"Twin load failed (will retry on first request): {e}")

    # Listen on all interfaces, port 7860. allowed_paths names the
    # data/benchmark directory that sits next to this file.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        allowed_paths=[
            os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "benchmark"),
        ],
        theme=gr.themes.Base(
            primary_hue=gr.themes.colors.zinc,
            neutral_hue=gr.themes.colors.zinc,
        )
    )