Spaces:
Sleeping
Sleeping
| """ | |
| Functional Distance - Compare ESM2 vs Twin protein embeddings | |
| HuggingFace Spaces App | |
| """ | |
| import os | |
| os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' | |
| import html | |
| import tempfile | |
| import gradio as gr | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| import plotly.graph_objects as go | |
| from plotly.subplots import make_subplots | |
# --- Embedding models --------------------------------------------------------
ESM2_MODEL = "esm2_t33_650M_UR50D"  # fair-esm checkpoint: 650M params, 1280-dim output
ESM2_DIM = 1280
TWIN_DIM = 1024  # two-tower concat: 2 * projection_dim (512)

# --- FAISS indexes over UniRef50 GO-annotated protein clusters ---------------
ESM2_FAISS_REPO_ID = os.environ.get(
    "ESM2_FAISS_REPO_ID",
    os.environ.get("FAISS_REPO_ID", "genomenet/esm2-uniref50-faiss"),
)
TWIN_BP_FAISS_REPO_ID = os.environ.get("TWIN_BP_FAISS_REPO_ID", "genomenet/twin-uniref50-faiss")
# Older single-index name, kept so existing logs / env configs still resolve.
FAISS_REPO_ID = ESM2_FAISS_REPO_ID
FAISS_IDS_FILE = "ids.npy"
FAISS_METADATA_FILE = "metadata.json"
FAISS_NPROBE = int(os.environ.get("FAISS_NPROBE", "32"))
# index name -> where to download it from and how to label it in the UI
FAISS_CONFIGS = {
    "esm2": dict(
        repo_id=ESM2_FAISS_REPO_ID,
        index_file="esm2_uniref50.index",
        label="ESM2 baseline",
        dim_info="1280-dim ESM2",
    ),
    "twin-bp": dict(
        repo_id=TWIN_BP_FAISS_REPO_ID,
        index_file="twin_uniref50.index",
        label="genomenet-twin (BP)",
        dim_info="1024-dim Twin-BP",
    ),
}

# --- Twin model: one HF model repo holding three GO-aspect checkpoints -------
TWIN_REPO_ID = os.environ.get("TWIN_REPO_ID", "genomenet/twin-point-1024")
TWIN_CHECKPOINT_FILES = dict(
    BP="bp_cp_best.pt",  # Biological Process
    CC="cc_cp_best.pt",  # Cellular Component
    MF="mf_cp_best.pt",  # Molecular Function
)
TWIN_DEFAULT_ASPECT = os.environ.get("TWIN_DEFAULT_ASPECT", "BP")
TWIN_ESM_BACKBONE = os.environ.get("TWIN_ESM_BACKBONE", "facebook/esm2_t33_650M_UR50D")

# --- Lazily populated model / index caches ------------------------------------
_esm2_model = None
_esm2_alphabet = None
# A single Twin aspect is resident at a time (~2.7 GB each on CPU; all three
# do not fit on a cpu-basic Space).
_twin_cache = dict(aspect=None, model=None, seq_len=None)
_faiss_cache = dict(name=None, index=None, ids=None, metadata=None)
def get_esm2():
    """Lazily load and return the shared ESM2 ``(model, alphabet)`` pair.

    On the first call the fair-esm checkpoint named by ``ESM2_MODEL`` is
    loaded through ``esm.pretrained``, switched to eval mode and moved to
    CUDA when a GPU is available. Every call returns the module-level pair.
    """
    global _esm2_model, _esm2_alphabet
    if _esm2_model is not None:
        return _esm2_model, _esm2_alphabet

    import esm  # fair-esm is imported only when ESM2 is first needed

    print(f"Loading ESM2 model: {ESM2_MODEL}...")
    _esm2_model, _esm2_alphabet = esm.pretrained.load_model_and_alphabet(ESM2_MODEL)
    _esm2_model = _esm2_model.eval()
    if torch.cuda.is_available():
        _esm2_model = _esm2_model.cuda()
    print("ESM2 loaded.")
    return _esm2_model, _esm2_alphabet
def get_twin(aspect=None):
    """Return ``(model, seq_len)`` for the fine-tuned Twin checkpoint of a GO aspect.

    ``aspect`` is case-insensitive and defaults to ``TWIN_DEFAULT_ASPECT``. It
    must name a key of ``TWIN_CHECKPOINT_FILES`` (BP / CC / MF), otherwise
    ``ValueError`` is raised.

    Only one aspect is kept in ``_twin_cache``. Asking for a different aspect
    first drops the resident model, because each is ~2.7 GB and all three do
    not fit on cpu-basic. It then downloads the checkpoint from
    ``TWIN_REPO_ID`` and loads it with ``twin_model.load_twin_model`` on CUDA
    when available, otherwise on CPU.
    """
    global _twin_cache
    requested = (aspect or TWIN_DEFAULT_ASPECT).upper()
    if requested not in TWIN_CHECKPOINT_FILES:
        raise ValueError(f"Unknown aspect {requested!r}; expected one of {list(TWIN_CHECKPOINT_FILES)}")

    cache = _twin_cache
    if cache["model"] is not None and cache["aspect"] == requested:
        return cache["model"], cache["seq_len"]

    import gc
    import torch
    from huggingface_hub import hf_hub_download
    from twin_model import load_twin_model

    if cache["model"] is not None:
        # Release the resident aspect (~2.7 GB) before allocating the next one.
        print(f"Evicting Twin/{cache['aspect']} to load Twin/{requested}...")
        cache.update(model=None, seq_len=None, aspect=None)
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    ckpt_name = TWIN_CHECKPOINT_FILES[requested]
    print(f"Downloading Twin/{requested} checkpoint ({ckpt_name}) from {TWIN_REPO_ID}...")
    ckpt_path = hf_hub_download(repo_id=TWIN_REPO_ID, filename=ckpt_name, repo_type="model")
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, seq_len, emb_dim = load_twin_model(ckpt_path, target, TWIN_ESM_BACKBONE)
    if emb_dim != TWIN_DIM:
        print(f" WARN Twin/{requested} output dim is {emb_dim}, expected {TWIN_DIM}")
    cache.update(aspect=requested, model=model, seq_len=seq_len)
    return model, seq_len
def get_faiss(index_name="esm2"):
    """Return ``(index, ids, metadata)`` for a FAISS index listed in ``FAISS_CONFIGS``.

    The index file, ``FAISS_IDS_FILE`` and ``FAISS_METADATA_FILE`` are
    fetched from the configured HF dataset repo with ``snapshot_download``.
    Only one index is kept in ``_faiss_cache``; loading a different one drops
    the previous first. For IVF indexes ``nprobe`` is set to ``FAISS_NPROBE``.
    ``metadata`` is the parsed JSON file, or ``None`` when it is absent.

    Raises:
        ValueError: *index_name* is not a key of ``FAISS_CONFIGS``.
    """
    global _faiss_cache
    if _faiss_cache["index"] is not None and _faiss_cache["name"] == index_name:
        return _faiss_cache["index"], _faiss_cache["ids"], _faiss_cache["metadata"]
    if index_name not in FAISS_CONFIGS:
        raise ValueError(f"Unknown FAISS index {index_name!r}; expected one of {list(FAISS_CONFIGS)}")

    import gc
    import faiss
    import json
    from huggingface_hub import snapshot_download

    if _faiss_cache["index"] is not None:
        print(f"Evicting FAISS/{_faiss_cache['name']} to load FAISS/{index_name}...")
        # Rebind (no local alias kept) so the old index can actually be collected.
        _faiss_cache = {"name": None, "index": None, "ids": None, "metadata": None}
        gc.collect()

    cfg = FAISS_CONFIGS[index_name]
    print(f"Downloading FAISS/{index_name} index from {cfg['repo_id']}...")
    local_dir = snapshot_download(
        repo_id=cfg["repo_id"],
        repo_type="dataset",
        allow_patterns=[cfg["index_file"], FAISS_IDS_FILE, FAISS_METADATA_FILE],
    )
    print(f"Loading FAISS index from {local_dir}...")

    def in_snapshot(fname):
        return os.path.join(local_dir, fname)

    index = faiss.read_index(in_snapshot(cfg["index_file"]))
    try:
        ivf = faiss.extract_index_ivf(index)
        ivf.nprobe = FAISS_NPROBE
        nprobe = ivf.nprobe
    except Exception:
        # Not an IVF index: nprobe does not apply.
        nprobe = "n/a"

    ids = np.load(in_snapshot(FAISS_IDS_FILE))
    meta_path = in_snapshot(FAISS_METADATA_FILE)
    metadata = None
    if os.path.exists(meta_path):
        with open(meta_path) as fh:
            metadata = json.load(fh)

    _faiss_cache = {"name": index_name, "index": index, "ids": ids, "metadata": metadata}
    print(f"FAISS/{index_name} ready: {index.ntotal:,} vectors, nprobe={nprobe}")
    return index, ids, metadata
def get_device():
    """Return ``"cuda"`` when torch sees a GPU, otherwise ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
# Valid amino acids: the 20 standard residues.
AMINO_ACIDS = set("ACDEFGHIKLMNPQRSTVWY")
# Also accept the IUPAC ambiguity / rare-residue letters (X, B, Z, J, O, U).
AMINO_ACIDS_EXTENDED = AMINO_ACIDS | set("XBZJOU")
ESM2_MAX_LEN = 1022  # ESM2 position embedding limit; longer sequences are truncated


def validate_protein(sequence, min_len=10):
    """Check that *sequence* is a usable protein sequence.

    Any whitespace is ignored before checking: spaces, tabs, and the
    ``\\r``/``\\n`` of pasted or Windows-style FASTA. Letters are compared
    upper-cased. Long sequences are NOT rejected: both embedders truncate
    internally and the caller surfaces a note instead.

    Args:
        sequence: raw residue string, with any FASTA header already removed.
        min_len: minimum accepted length in residues (default 10).

    Returns:
        ``(True, "")`` when valid, otherwise ``(False, reason)``.
    """
    if not sequence or not sequence.strip():
        return False, "Sequence is empty"
    # str.split() with no argument splits on every kind of whitespace.
    cleaned = "".join(sequence.split()).upper()
    invalid = set(cleaned) - AMINO_ACIDS_EXTENDED
    if invalid:
        return False, f"Invalid characters: {invalid}"
    if len(cleaned) < min_len:
        return False, f"Sequence too short: {len(cleaned)} < {min_len} aa"
    return True, ""
def strip_fasta_header(text):
    """Return the residues of FASTA-formatted *text* as one upper-case string.

    Header lines (starting with ``>``, even if indented) and blank lines are
    dropped. All whitespace inside sequence lines is removed, and the
    remaining lines are concatenated. A plain, header-less sequence passes
    through upper-cased. If several records are present, their residues are
    joined, as before.
    """
    residues = []
    for line in (text or "").splitlines():
        stripped = line.strip()
        # Test the stripped line so indented headers are recognised too.
        if not stripped or stripped.startswith(">"):
            continue
        residues.append("".join(stripped.split()))
    return "".join(residues).upper()
def embed_esm2(sequence):
    """Return the mean-pooled last-layer ESM2 embedding of *sequence*.

    Sequences longer than ``ESM2_MAX_LEN`` residues are truncated, because
    ESM2 position embeddings cap at 1022. Pooling averages the per-residue
    representations, excluding the BOS token at position 0 and the EOS token
    after the last residue.

    The forward pass runs under ``torch.no_grad()``. No autograd graph is
    built, which saves memory, and the pooled tensor can be converted with
    ``.numpy()``; numpy conversion raises on tensors that require grad.

    Returns:
        1-D numpy array (presumably ``ESM2_DIM`` values for the configured checkpoint).
    """
    model, alphabet = get_esm2()
    batch_converter = alphabet.get_batch_converter()
    device = get_device()
    sequence = sequence[:ESM2_MAX_LEN]
    _, _, batch_tokens = batch_converter([("protein", sequence)])
    batch_tokens = batch_tokens.to(device)
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[model.num_layers], return_contacts=False)
    representations = results["representations"][model.num_layers]
    n_res = len(sequence)
    # Row 0 is BOS; residues occupy 1..n_res; EOS follows.
    return representations[0, 1:n_res + 1, :].mean(dim=0).cpu().numpy()
def embed_twin(sequence, aspect=None):
    """Return the Twin embedding of *sequence* for GO aspect *aspect* (BP/CC/MF).

    If ``aspect`` is ``None``, ``get_twin`` picks the default aspect. The
    sequence is cleaned with ``twin_model.ensure_aa_sequence`` and tokenized
    to the checkpoint's ``seq_len``. The forward pass runs under
    ``torch.no_grad()``, so no autograd graph is kept and ``.numpy()`` works
    on the output.

    Returns:
        1-D float32 numpy array (expected length ``TWIN_DIM``).
    """
    from twin_model import ensure_aa_sequence, preprocess_sequences_batch

    model, seq_len = get_twin(aspect)
    device = next(model.parameters()).device
    input_ids = preprocess_sequences_batch([ensure_aa_sequence(sequence)], seq_len, device)
    with torch.no_grad():
        combined = model(input_ids)  # (1, TWIN_DIM)
    return combined[0].float().cpu().numpy()
def compute_distance(seq_a, seq_b, aspect):
    """Twin-model pairwise distance between two sequences (its native trained task).

    Both sequences are embedded in one batch with the Twin checkpoint for
    *aspect* and L2-normalized, following the training convention. The
    computation runs under ``torch.no_grad()`` because only scalar results
    are needed.

    Returns:
        dict with ``l2`` (Euclidean distance between the normalized
        embeddings), ``cos_sim`` (cosine similarity) and ``cos_dist``
        (``1 - cos_sim``).
    """
    import torch.nn.functional as F
    from twin_model import ensure_aa_sequence, preprocess_sequences_batch

    model, seq_len = get_twin(aspect)
    device = next(model.parameters()).device
    seqs = [ensure_aa_sequence(seq_a), ensure_aa_sequence(seq_b)]
    input_ids = preprocess_sequences_batch(seqs, seq_len, device)
    with torch.no_grad():
        emb = model(input_ids)  # (2, TWIN_DIM)
        a = F.normalize(emb[0:1], p=2, dim=-1)
        b = F.normalize(emb[1:2], p=2, dim=-1)
        cos_sim = float((a * b).sum().item())
        l2 = float(torch.norm(a - b, p=2, dim=-1).item())
    return {"l2": l2, "cos_sim": cos_sim, "cos_dist": 1.0 - cos_sim}
_CRISPR_REF_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "crispr_reference")
# method label -> embedding matrix file; all methods share metadata.csv for row annotations
_CRISPR_FILES = {
    "Twin-BP": "embeddings.npy",  # 1024-dim, from Step 6 Twin-BP build
    "Twin-MF": "embeddings_mf.npy",  # 1024-dim, best-fit aspect for CRISPR
    "ESM2 baseline": "embeddings_esm2.npy",  # 1280-dim, ESM2 mean-pool
}
_crispr_ref = {m: {"emb": None, "meta": None} for m in _CRISPR_FILES}


def get_crispr_reference(method="Twin-BP"):
    """Load the curated CRISPR reference embeddings and metadata for *method*.

    Results are cached per method for the life of the process. The embedding
    rows are L2-normalized once, so cosine similarity is a dot product.
    Nothing is cached until both files load successfully and agree in row
    count, so a failed load never leaves a half-filled cache entry.

    Returns:
        ``(emb, meta)``: float32 ``(N, D)`` normalized matrix and the
        ``metadata.csv`` DataFrame with ``N`` rows.

    Raises:
        KeyError: *method* is not a key of ``_CRISPR_FILES``.
        FileNotFoundError: the embedding or metadata file is not packaged.
        ValueError: the embedding matrix is not 2-D or its row count differs
            from the number of metadata rows.
    """
    if method not in _CRISPR_FILES:
        raise KeyError(f"Unknown CRISPR reference method {method!r}; expected one of {list(_CRISPR_FILES)}")
    cached = _crispr_ref[method]
    if cached["emb"] is not None and cached["meta"] is not None:
        return cached["emb"], cached["meta"]
    emb_path = os.path.join(_CRISPR_REF_DIR, _CRISPR_FILES[method])
    meta_path = os.path.join(_CRISPR_REF_DIR, "metadata.csv")
    if not (os.path.exists(emb_path) and os.path.exists(meta_path)):
        raise FileNotFoundError(f"CRISPR reference ({method}) not packaged at {_CRISPR_REF_DIR}")
    # Load both into locals first; only populate the cache once both succeeded.
    emb = np.load(emb_path).astype(np.float32)
    meta = pd.read_csv(meta_path)
    if emb.ndim != 2 or emb.shape[0] != len(meta):
        raise ValueError(
            f"CRISPR reference ({method}) is inconsistent: embeddings shape {emb.shape} "
            f"vs {len(meta)} metadata rows"
        )
    # L2-normalize once so cosine similarity = dot product
    norms = np.linalg.norm(emb, axis=1, keepdims=True) + 1e-9
    cached["emb"] = emb / norms
    cached["meta"] = meta
    print(f"CRISPR reference ({method}) loaded: {emb.shape[0]} proteins, dim={emb.shape[1]}")
    return cached["emb"], cached["meta"]
def _cell_text(value):
    """Return *value* as ``str``, mapping ``None`` and pandas-missing (NaN/NA) to ``""``."""
    if value is None:
        return ""
    try:
        if pd.isna(value):
            return ""
    except (TypeError, ValueError):
        # pd.isna on list-like input yields an array whose truth value is ambiguous.
        pass
    return str(value)


def _strip_group_prefix(text):
    """Remove the ``cas__`` / ``acr__`` group markers used in internal reference ids."""
    return text.replace("cas__", "").replace("acr__", "")


# Accession prefixes that identify NCBI protein records in the reference set.
_NCBI_PREFIXES = ("WP_", "YP_", "NP_", "XP_", "AKG", "ERJ")


def _crispr_display_id(row):
    """Readable ID for mixed Cas/Acr rows, whose internal ids are prefixed.

    Prefers ``source_accession`` verbatim, then ``original_acc``, then
    ``acc``. The latter two have their ``cas__`` / ``acr__`` markers removed.
    *row* is anything with a dict-style ``.get`` (a pandas row or a dict).
    """
    source_acc = _cell_text(row.get("source_accession", ""))
    if source_acc:
        return source_acc
    original = _cell_text(row.get("original_acc", ""))
    return _strip_group_prefix(original or _cell_text(row.get("acc", "")))


def _crispr_protein_link(row):
    """Return an NCBI or UniProt URL for a CRISPR reference row, or ``""``.

    Uses the first non-empty field among ``source_accession``,
    ``original_acc`` and ``acc``, with ``cas__`` / ``acr__`` markers removed.
    Locally generated ids (``local__…``) have no public record and yield
    ``""``; any other ``namespace__`` qualifier is dropped. Accessions with a
    known NCBI prefix, or with a numeric version suffix such as ``.1`` or
    ``.2``, link to NCBI protein. Everything else is treated as a UniProtKB
    accession.
    """
    clean = _strip_group_prefix(
        _cell_text(row.get("source_accession", ""))
        or _cell_text(row.get("original_acc", ""))
        or _cell_text(row.get("acc", ""))
    )
    # Must run before the "__" split below, which would discard the local marker.
    if not clean or clean.startswith("local__"):
        return ""
    if "__" in clean:
        clean = clean.split("__")[-1]
    if not clean:
        return ""
    _, dot, version = clean.rpartition(".")
    if clean.startswith(_NCBI_PREFIXES) or (dot and version.isdigit()):
        return f"https://www.ncbi.nlm.nih.gov/protein/{clean}"
    return f"https://www.uniprot.org/uniprotkb/{clean}"
def search_crispr(query_emb, k=25, negate=False, method="Twin-BP"):
    """Rank the curated CRISPR reference proteins by cosine similarity to *query_emb*.

    Args:
        query_emb: query embedding. It is flattened and L2-normalized here,
            and its length must match the reference matrix of *method*.
        k: number of hits. Coerced to ``int`` because UI sliders may deliver
            floats, and clamped at 0.
        negate: if True, return the *least* similar proteins instead.
        method: key of ``_CRISPR_FILES`` selecting the embedding space.

    Returns:
        DataFrame with columns ``rank, uniref50_id, ref_id, cosine,
        description, uniprot``, with the columns present even when empty.
        ``uniref50_id`` holds the readable display id, under the same column
        name that ``search_faiss`` output uses.
    """
    ref_emb, meta = get_crispr_reference(method)
    k = max(int(k), 0)
    q = np.asarray(query_emb, dtype=np.float32).reshape(-1)
    q = q / (np.linalg.norm(q) + 1e-9)
    sims = ref_emb @ q  # (N,) cosine similarities; reference rows are pre-normalized
    order = np.argsort(sims if negate else -sims)[:k]
    rows = []
    for rank, idx in enumerate(order, 1):
        m = meta.iloc[idx]
        family = _cell_text(m.get("family", ""))
        typ = _cell_text(m.get("type", ""))
        organism = _cell_text(m.get("organism", ""))
        rows.append({
            "rank": rank,
            "uniref50_id": _crispr_display_id(m),
            "ref_id": m["acc"],
            "cosine": round(float(sims[idx]), 4),
            "description": f"{family} · {typ} · {organism}",
            "uniprot": _crispr_protein_link(m),
        })
    return pd.DataFrame(
        rows, columns=["rank", "uniref50_id", "ref_id", "cosine", "description", "uniprot"]
    )
# method label -> pre-computed UMAP coordinates of the CRISPR reference set
_CRISPR_UMAP_FILES = {
    "ESM2 baseline": "umap_coords_esm2.npy",
    "Twin-MF": "umap_coords_mf.npy",
    "Twin-BP": "umap_coords_bp.npy",
}
_CRISPR_UMAP3D_FILES = {
    "ESM2 baseline": "umap_coords_3d_esm2.npy",
    "Twin-MF": "umap_coords_3d_mf.npy",
    "Twin-BP": "umap_coords_3d_bp.npy",
}
_crispr_umap_cache = {}
_crispr_umap3d_cache = {}


def _load_umap_coords(method, files, cache):
    """Load and memoize the float32 coordinate array for *method* from ``_CRISPR_REF_DIR``.

    Returns ``None`` when *method* has no entry in *files* or the file is not
    packaged. Misses are not cached, so a file added later is still found.
    """
    if method in cache:
        return cache[method]
    fname = files.get(method)
    if fname is None:
        return None
    path = os.path.join(_CRISPR_REF_DIR, fname)
    if not os.path.exists(path):
        return None
    coords = np.load(path).astype(np.float32)
    cache[method] = coords
    return coords


def get_crispr_umap(method):
    """Load pre-computed UMAP 2D coordinates for the CRISPR reference set (or ``None``)."""
    return _load_umap_coords(method, _CRISPR_UMAP_FILES, _crispr_umap_cache)


def get_crispr_umap_3d(method):
    """Load pre-computed UMAP 3D coordinates for the CRISPR reference set (or ``None``)."""
    return _load_umap_coords(method, _CRISPR_UMAP3D_FILES, _crispr_umap3d_cache)
_UNKNOWN_RESULTS_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "data", "unknown_proteins", "results.json"
)
_UNKNOWN_FASTA_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "data", "unknown_proteins", "sequences.fasta"
)
_unknown_results = None
_unknown_sequences = None


def get_unknown_results():
    """Load the pre-computed characterization results for uncharacterized proteins.

    The results are produced by scripts/analyses/crispr/precompute_unknowns.py.
    If the JSON file is absent, an empty placeholder
    ``{"n_queries": 0, "records": [], "generated_at": None}`` is returned.
    The result is cached for the life of the process.
    """
    global _unknown_results
    if _unknown_results is not None:
        return _unknown_results
    import json
    if not os.path.exists(_UNKNOWN_RESULTS_PATH):
        _unknown_results = {"n_queries": 0, "records": [], "generated_at": None}
    else:
        with open(_UNKNOWN_RESULTS_PATH, encoding="utf-8") as f:
            _unknown_results = json.load(f)
    return _unknown_results


def get_unknown_sequences():
    """Load unknown-protein query sequences keyed by FASTA id (first header token).

    Returns ``{}`` when the FASTA file is absent. Multi-line records are
    joined and blank lines ignored. Records whose header carries no id (a
    bare ``>``) and lines before the first header are skipped. For a repeated
    id the last record wins. The result is cached for the life of the process.
    """
    global _unknown_sequences
    if _unknown_sequences is not None:
        return _unknown_sequences
    seqs = {}
    if os.path.exists(_UNKNOWN_FASTA_PATH):
        current = None
        chunks = []
        with open(_UNKNOWN_FASTA_PATH, encoding="utf-8") as f:
            for raw in f:
                line = raw.strip()
                if not line:
                    continue
                if line.startswith(">"):
                    if current is not None:
                        seqs[current] = "".join(chunks)
                    header = line[1:].split(maxsplit=1)
                    # A bare ">" has no id: skip that record rather than raise IndexError.
                    current = header[0] if header else None
                    chunks = []
                else:
                    chunks.append(line)
        if current is not None:
            seqs[current] = "".join(chunks)
    _unknown_sequences = seqs
    return _unknown_sequences
def _wrap_sequence_for_html(seq, width=80):
    """Split *seq* into newline-separated chunks of at most *width* characters.

    ``None`` or an empty sequence gives ``""``.
    """
    text = seq or ""
    pieces = []
    for start in range(0, len(text), width):
        pieces.append(text[start:start + width])
    return "\n".join(pieces)
def _crispr_verdict(top_k_rows, top_cos, threshold):
    """Summarize ranked CRISPR reference hits as ``(label, confidence)``.

    Args:
        top_k_rows: ranked hit dicts with a ``family`` key. A ``group`` key is
            also read when no family reaches a 60% majority.
        top_cos: cosine similarity of the best hit.
        threshold: minimum ``top_cos`` for any CRISPR-related call.

    Returns:
        A short markdown label and ``"low"``, ``"medium"`` or ``"high"``.
        ``"high"`` requires at least 80% family agreement and a ``top_cos``
        more than 0.15 above *threshold*.
    """
    from collections import Counter

    if not top_k_rows:
        return "Unknown — no hits", "low"
    if top_cos < threshold:
        label = (f"Unlikely CRISPR-related (top cosine {top_cos:+.3f} < threshold "
                 f"{threshold:.2f})")
        return label, "low"

    families = [row["family"] for row in top_k_rows]
    n_hits = len(families)
    family_votes = Counter(families)
    best_family, best_count = family_votes.most_common(1)[0]
    share = best_count / n_hits
    if share >= 0.6:
        strong = share >= 0.8 and top_cos > threshold + 0.15
        label = f"Predicted: **{best_family}** ({best_count}/{n_hits} of top hits)"
        return label, "high" if strong else "medium"

    top_mix = dict(family_votes.most_common(3))
    groups = [row["group"] for row in top_k_rows]
    best_group, group_count = Counter(groups).most_common(1)[0]
    if group_count / len(groups) >= 0.7:
        tag = "Cas protein" if best_group == "cas" else "anti-CRISPR protein"
        label = (f"CRISPR-related: **{tag}**; specific family unclear "
                 f"(top mix: {top_mix})")
        return label, "medium"
    label = (f"CRISPR-related, but signal is mixed "
             f"(top family votes: {top_mix})")
    return label, "medium"
# uniref id -> description string; only definitive answers are stored (see below)
_uniref_meta_cache = {}


def _fetch_uniref_description(uid):
    """Query the UniProt REST API for one UniRef50 id.

    UniParc-represented clusters (``UniRef50_UPI…``) that 404 on the UniRef
    endpoint fall back to the UniParc entry's cross-references for a name and
    organism.

    Returns:
        ``(description, cacheable)``. ``cacheable`` is True only for
        definitive answers (HTTP 200 or 404). Network errors and other HTTP
        statuses may succeed on a later attempt, so they are not cached.
    """
    import requests

    try:
        r = requests.get(
            f"https://rest.uniprot.org/uniref/{uid}.json",
            params={"fields": "name,organism,count"},
            timeout=5,
        )
        if r.status_code == 200:
            data = r.json()
            name = (data.get("name") or "").replace("Cluster: ", "")
            rep = data.get("representativeMember", {}) or {}
            org = (rep.get("organismName") or "").strip()
            count = data.get("memberCount") or ""
            parts = [p for p in (name, org, (f"{count} members" if count else "")) if p]
            return " — ".join(parts) or "(no metadata)", True
        if r.status_code == 404 and uid.startswith("UniRef50_UPI"):
            # UniParc representative — no UniProtKB entry, but UniParc has xrefs with
            # protein name + organism from RefSeq/EMBL/etc.
            upi = uid[len("UniRef50_"):]
            r2 = requests.get(f"https://rest.uniprot.org/uniparc/{upi}.json", timeout=5)
            if r2.status_code != 200:
                return "(UniParc fetch failed)", r2.status_code == 404
            xrefs = (r2.json() or {}).get("uniParcCrossReferences", []) or []
            name = org = ""
            for cr in xrefs:
                if not name and cr.get("proteinName"):
                    name = cr["proteinName"]
                if not org and (cr.get("organism") or {}).get("scientificName"):
                    org = cr["organism"]["scientificName"]
                if name and org:
                    break
            if name or org:
                return " — ".join(p for p in (name, org, "(UniParc)") if p), True
            return "(UniParc entry — no annotation)", True
        if r.status_code == 404:
            return "(not in UniRef — cluster may have been updated)", True
        return f"(UniProt HTTP {r.status_code})", False
    except Exception as e:  # best-effort lookup: any failure becomes an inline note
        # str(e) can be empty; fall back to the exception class name.
        lines = str(e).splitlines()
        detail = lines[0][:60] if lines else type(e).__name__
        return f"(fetch error: {detail})", False


def fetch_uniref_metadata(uniref_ids):
    """Fetch cluster name + representative organism for a list of UniRef50 IDs.

    Uses the UniProt uniref endpoint, with a UniParc fallback for ``UPI``
    representatives. Definitive results are cached in memory for the life of
    the Space process. Transient failures (network errors, 5xx/429) are not
    cached, so the next call retries them. Duplicate ids are fetched once.

    Returns:
        dict ``{id -> "protein name — organism — N members"}``. Ids that
        cannot be described map to a parenthesized note such as
        ``"(no metadata)"`` or ``"(fetch error: …)"``.
    """
    out = {}
    for uid in uniref_ids:
        if uid in out:
            continue
        if uid in _uniref_meta_cache:
            out[uid] = _uniref_meta_cache[uid]
            continue
        desc, cacheable = _fetch_uniref_description(uid)
        if cacheable:
            _uniref_meta_cache[uid] = desc
        out[uid] = desc
    return out
def search_faiss(query_embedding, k=10, negate=False, fetch_metadata=True, index_name="esm2"):
    """Search a FAISS index for the top-k UniRef50 neighbours of an embedding.

    If `negate` is True, searches with -q: returns the most anti-correlated
    (least similar) proteins. Returned cosine values are still reported with
    respect to the original q (i.e., we flip the sign after search).
    When `fetch_metadata` is False, skips the per-hit UniProt API call (useful
    when you need a larger pool of cosine values for a distribution plot).

    Args:
        query_embedding: 1-D (or ``(1, d)``) embedding. It is never modified.
        k: number of neighbours, coerced to ``int`` because UI sliders may
            deliver floats. ``k <= 0`` returns an empty frame.
        index_name: key of ``FAISS_CONFIGS``.

    Returns:
        DataFrame with columns ``rank, uniref50_id, cosine, description,
        uniprot``. The columns are present even when there are no hits.
    """
    import faiss

    columns = ["rank", "uniref50_id", "cosine", "description", "uniprot"]
    k = int(k)
    if k <= 0:
        return pd.DataFrame(columns=columns)
    index, ids, _ = get_faiss(index_name)
    # np.array copies: faiss.normalize_L2 works in place and would otherwise
    # normalize the caller's float32 embedding through an asarray view.
    q = np.array(query_embedding, dtype=np.float32).reshape(1, -1)
    faiss.normalize_L2(q)
    if negate:
        q = -q
    scores, idxs = index.search(q, k)
    hit_rows = []
    for rank, (score, i) in enumerate(zip(scores[0], idxs[0]), 1):
        if i < 0:  # FAISS pads with -1 when fewer than k results exist
            continue
        uid = ids[i].decode() if isinstance(ids[i], (bytes, np.bytes_)) else str(ids[i])
        cos_real = -float(score) if negate else float(score)
        hit_rows.append((rank, uid, cos_real))
    meta = fetch_uniref_metadata([uid for _, uid, _ in hit_rows]) if fetch_metadata else {}
    rows = [
        {
            "rank": rank,
            "uniref50_id": uid,
            "cosine": round(cos, 4),
            "description": meta.get(uid, ""),
            "uniprot": f"https://www.uniprot.org/uniref/{uid}",
        }
        for rank, uid, cos in hit_rows
    ]
    return pd.DataFrame(rows, columns=columns)
def compute_stats(embedding, sparsity_threshold=0.1, bins=50):
    """Compute summary statistics of an embedding vector.

    Args:
        embedding: array-like of activation values.
        sparsity_threshold: values with ``|x| <`` this count as near-zero.
        bins: number of histogram bins used for the entropy estimate.

    Returns:
        dict with these keys:

        - ``l2_norm``, ``mean``, ``std``;
        - ``sparsity``: fraction of near-zero values;
        - ``entropy``: Shannon entropy in nats of the binned value
          distribution, between 0 and ``ln(bins)``, so it is comparable
          between ESM2 and Twin embeddings;
        - ``dim``: ``len(embedding)``.
    """
    emb = np.asarray(embedding, dtype=np.float64)
    counts, _ = np.histogram(emb, bins=bins)
    total = counts.sum()
    if total > 0:
        # Use per-bin probability mass, not density. sum(density * log(density))
        # depends on bin width, can be negative, and is not an entropy.
        p = counts[counts > 0] / total
        entropy = float(-np.sum(p * np.log(p)))
    else:
        entropy = 0.0
    return {
        'l2_norm': float(np.linalg.norm(emb)),
        'mean': float(np.mean(emb)),
        'std': float(np.std(emb)),
        'sparsity': float(np.mean(np.abs(emb) < sparsity_threshold)),
        'entropy': entropy,
        'dim': len(emb),
    }
def create_embedding_heatmap(embedding, title, cols=32):
    """Render *embedding* as a diverging red/blue matplotlib heatmap.

    Values are laid out row by row in a grid *cols* cells wide. Cells past the
    last dimension are NaN and stay blank. The colour scale is symmetric
    around zero, using the largest absolute value (at least 0.01).

    Returns:
        The matplotlib ``Figure``, left open for the caller.
    """
    n_dims = len(embedding)
    n_rows = int(np.ceil(n_dims / cols))
    cells = np.full(n_rows * cols, np.nan)
    cells[:n_dims] = embedding
    grid = cells.reshape(n_rows, cols)

    lo, hi = np.nanmin(embedding), np.nanmax(embedding)
    limit = max(abs(lo), abs(hi), 0.01)
    fig_height = max(3, n_rows * 0.3)

    fig, ax = plt.subplots(figsize=(10, fig_height))
    image = ax.imshow(grid, cmap='RdBu_r', vmin=-limit, vmax=limit, aspect='auto')
    plt.colorbar(image, ax=ax, shrink=0.8)
    ax.set_title(f'{title} ({n_dims} dims)')
    ax.set_xlabel('Dimension')
    plt.tight_layout()
    return fig
def create_comparison_plot(esm2_stats, twin_stats):
    """Build a grouped plotly bar chart comparing ESM2 and Twin embedding statistics.

    Both dicts must provide ``l2_norm``, ``mean``, ``std``, ``sparsity`` and
    ``entropy``, as produced by ``compute_stats``. Each bar is labelled with
    its value to three decimals.
    """
    metric_keys = [
        ('L2 Norm', 'l2_norm'),
        ('Mean', 'mean'),
        ('Std', 'std'),
        ('Sparsity', 'sparsity'),
        ('Entropy', 'entropy'),
    ]
    labels = [label for label, _ in metric_keys]
    series = (
        ('ESM2', esm2_stats, '#3b82f6'),
        ('Twin', twin_stats, '#10b981'),
    )

    fig = make_subplots(rows=1, cols=1)
    for series_name, stats, colour in series:
        values = [stats[key] for _, key in metric_keys]
        fig.add_trace(go.Bar(
            name=series_name,
            x=labels,
            y=values,
            marker_color=colour,
            text=[f'{v:.3f}' for v in values],
            textposition='outside',
        ))
    fig.update_layout(
        barmode='group',
        height=350,
        legend=dict(orientation='h', y=1.1),
        margin=dict(l=40, r=20, t=50, b=40),
    )
    return fig
def create_distribution_plot(esm2_emb, twin_emb):
    """Overlay 50-bin plotly histograms of ESM2 and Twin activation values."""
    fig = go.Figure()
    for label, values, colour in (('ESM2', esm2_emb, '#3b82f6'), ('Twin', twin_emb, '#10b981')):
        fig.add_trace(go.Histogram(
            x=values,
            name=label,
            opacity=0.7,
            marker_color=colour,
            nbinsx=50,
        ))
    fig.update_layout(
        barmode='overlay',
        xaxis_title='Activation value',
        yaxis_title='Count',
        height=300,
        legend=dict(orientation='h', y=1.1),
        margin=dict(l=40, r=20, t=40, b=40),
    )
    return fig
# CRISPR-associated + anti-CRISPR sequences (UniProt, reviewed/representative).
# This app is framed around CRISPR biology: Twin is a general GO-contrastive model,
# but the intended use case is annotation of Cas and anti-CRISPR (Acr) genes.
# Sequences longer than 1022 aa (SpCas9, SaCas9, FnCas12a, LshCas13a) are truncated
# by the model's seq_len=1024; the N-terminal region still carries most of the signal.
# NOTE(review): on the ESM2 path the cut-off is ESM2_MAX_LEN (1022, in embed_esm2);
# the Twin cut-off comes from the checkpoint's seq_len — 1024 is assumed here, confirm.
# Each constant is a parenthesized run of 80-residue string literals that Python
# concatenates into one sequence string.
_CAS1_ECOLI = (  # Q46901, E. coli, 502 aa — adaptation, universal
    # NOTE(review): Q46901 / 502 aa appear to be E. coli CasA (Cse1, Type I-E Cascade
    # large subunit) rather than Cas1 (~305 aa); confirm the accession and label
    # before presenting this example as a "Cas1" query.
    "MNLLIDNWIPVRPRNGGKVQIINLQSLYCSRDQWRLSLPRDDMELAALALLVCIGQIIAPAKDDVEFRHRIMNPLTEDEF"
    "QQLIAPWIDMFYLNHAEHPFMQTKGVKANDVTPMEKLLAGVSGATNCAFVNQPGQGEALCGGCTAIALFNQANQAPGFGG"
    "GFKSGLRGGTPVTTFVRGIDLRSTVLLNVLTLPRLQKQFPNESHTENQPTWIKPIKSNESIPASSIGFVRGLFWQPAHIE"
    "LCDPIGIGKCSCCGQESNLRYTGFLKEKFTFTVNGLWPHPHSPCLVTVKKGEVEEKFLAFTTSAPSWTQISRVVVDKIIQ"
    "NENGNRVAAVVNQFRNIAPQSPLELIMGGYRNNQASILERRHDVLMFNQGWQQYGNVINEIVTVGLGYKTALRKALYTFA"
    "EGFKNKDFKGAGVSVHETAERHFYRQSELLIPDVLANVNFSQADEVIADLRDKLHQLCEMLFNQSVAPYAHHPKLISTLA"
    "LARATLYKHLRELKPQGGPSNG"
)
_CAS2_ECOLI = (  # P45956, E. coli, 94 aa — adaptation, universal
    "MSMLVVVTENVPPRLRGRLAIWLLEVRAGVYVGDVSAKIREMIWEQIAGLAEEGNVVMAWATNTETGFEFQTFGLNRRTP"
    "VDLDGLRLVSFLPV"
)
_CAS3_ECOLI = (  # P38036, E. coli, 888 aa — Type I interference nuclease/helicase
    "MEPFKYICHYWGKSSKSLTKGNDIHLLIYHCLDVAAVADCWWDQSVVLQNTFCRNEMLSKQRVKAWLLFFIALHDIGKFD"
    "IRFQYKSAESWLKLNPATPSLNGPSTQMCRKFNHGAAGLYWFNQDSLSEQSLGDFFSFFDAAPHPYESWFPWVEAVTGHH"
    "GFILHSQDQDKSRWEMPASLASYAAQDKQAREEWISVLEALFLTPAGLSINDIPPDCSSLLAGFCSLADWLGSWTTTNTF"
    "LFNEDAPSDINALRTYFQDRQQDASRVLELSGLVSNKRCYEGVHALLDNGYQPRQLQVLVDALPVAPGLTVIEAPTGSGK"
    "TETALAYAWKLIDQQIADSVIFALPTQATANAMLTRMEASASHLFSSPNLILAHGNSRFNHLFQSIKSRAITEQGQEEAW"
    "VQCCQWLSQSNKKVFLGQIGVCTIDQVLISVLPVKHRFIRGLGIGRSVLIVDEVHAYDTYMNGLLEAVLKAQADVGGSVI"
    "LLSATLPMKQKQKLLDTYGLHTDPVENNSAYPLINWRGVNGAQRFDLLAHPEQLPPRFSIQPEPICLADMLPDLTMLERM"
    "IAAANAGAQVCLICNLVDVAQVCYQRLKELNNTQVDIDLFHARFTLNDRREKENRVISNFGKNGKRNVGRILVATQVVEQ"
    "SLDVDFDWLITQHCPADLLFQRLGRLHRHHRKYRPAGFEIPVATILLPDGEGYGRHEHIYSNVRVMWRTQQHIEELNGAS"
    "LFFPDAYRQWLDSIYDDAEMDEPEWVGNGMDKFESAECEKRFKARKVLQWAEEYSLQDNDETILAVTRDGEMSLPLLPYV"
    "QTSSGKQLLDGQVYEDLSHEQQYEALALNRVNVPFTWKRSFSEVVDEDGLLWLEGKQNLDGWVWQGNSIVITYTGDEGMT"
    "RVIPANPK"
)
_SPCAS9 = (  # Q99ZW2, S. pyogenes, 1368 aa — Type II-A effector (truncated by model to 1022 aa)
    "MDKKYSIGLDIGTNSVGWAVITDEYKVPSKKFKVLGNTDRHSIKKNLIGALLFDSGETAEATRLKRTARRRYTRRKNRIC"
    "YLQEIFSNEMAKVDDSFFHRLEESFLVEEDKKHERHPIFGNIVDEVAYHEKYPTIYHLRKKLVDSTDKADLRLIYLALAH"
    "MIKFRGHFLIEGDLNPDNSDVDKLFIQLVQTYNQLFEENPINASGVDAKAILSARLSKSRRLENLIAQLPGEKKNGLFGN"
    "LIALSLGLTPNFKSNFDLAEDAKLQLSKDTYDDDLDNLLAQIGDQYADLFLAAKNLSDAILLSDILRVNTEITKAPLSAS"
    "MIKRYDEHHQDLTLLKALVRQQLPEKYKEIFFDQSKNGYAGYIDGGASQEEFYKFIKPILEKMDGTEELLVKLNREDLLR"
    "KQRTFDNGSIPHQIHLGELHAILRRQEDFYPFLKDNREKIEKILTFRIPYYVGPLARGNSRFAWMTRKSEETITPWNFEE"
    "VVDKGASAQSFIERMTNFDKNLPNEKVLPKHSLLYEYFTVYNELTKVKYVTEGMRKPAFLSGEQKKAIVDLLFKTNRKVT"
    "VKQLKEDYFKKIECFDSVEISGVEDRFNASLGTYHDLLKIIKDKDFLDNEENEDILEDIVLTLTLFEDREMIEERLKTYA"
    "HLFDDKVMKQLKRRRYTGWGRLSRKLINGIRDKQSGKTILDFLKSDGFANRNFMQLIHDDSLTFKEDIQKAQVSGQGDSL"
    "HEHIANLAGSPAIKKGILQTVKVVDELVKVMGRHKPENIVIEMARENQTTQKGQKNSRERMKRIEEGIKELGSQILKEHP"
    "VENTQLQNEKLYLYYLQNGRDMYVDQELDINRLSDYDVDHIVPQSFLKDDSIDNKVLTRSDKNRGKSDNVPSEEVVKKMK"
    "NYWRQLLNAKLITQRKFDNLTKAERGGLSELDKAGFIKRQLVETRQITKHVAQILDSRMNTKYDENDKLIREVKVITLKS"
    "KLVSDFRKDFQFYKVREINNYHHAHDAYLNAVVGTALIKKYPKLESEFVYGDYKVYDVRKMIAKSEQEIGKATAKYFFYS"
    "NIMNFFKTEITLANGEIRKRPLIETNGETGEIVWDKGRDFATVRKVLSMPQVNIVKKTEVQTGGFSKESILPKRNSDKLI"
    "ARKKDWDPKKYGGFDSPTVAYSVLVVAKVEKGKSKKLKSVKELLGITIMERSSFEKNPIDFLEAKGYKEVKKDLIIKLPK"
    "YSLFELENGRKRMLASAGELQKGNELALPSKYVNFLYLASHYEKLKGSPEDNEQKQLFVEQHKHYLDEIIEQISEFSKRV"
    "ILADANLDKVLSAYNKHRDKPIREQAENIIHLFTLTNLGAPAAFKYFDTTIDRKRYTSTKEVLDATLIHQSITGLYETRI"
    "DLSQLGGD"
)
_SACAS9 = (  # J7RUA5, S. aureus, 1053 aa — Type II-A effector (smaller Cas9 ortholog)
    "MKRNYILGLDIGITSVGYGIIDYETRDVIDAGVRLFKEANVENNEGRRSKRGARRLKRRRRHRIQRVKKLLFDYNLLTDH"
    "SELSGINPYEARVKGLSQKLSEEEFSAALLHLAKRRGVHNVNEVEEDTGNELSTKEQISRNSKALEEKYVAELQLERLKK"
    "DGEVRGSINRFKTSDYVKEAKQLLKVQKAYHQLDQSFIDTYIDLLETRRTYYEGPGEGSPFGWKDIKEWYEMLMGHCTYF"
    "PEELRSVKYAYNADLYNALNDLNNLVITRDENEKLEYYEKFQIIENVFKQKKKPTLKQIAKEILVNEEDIKGYRVTSTGK"
    "PEFTNLKVYHDIKDITARKEIIENAELLDQIAKILTIYQSSEDIQEELTNLNSELTQEEIEQISNLKGYTGTHNLSLKAI"
    "NLILDELWHTNDNQIAIFNRLKLVPKKVDLSQQKEIPTTLVDDFILSPVVKRSFIQSIKVINAIIKKYGLPNDIIIELAR"
    "EKNSKDAQKMINEMQKRNRQTNERIEEIIRTTGKENAKYLIEKIKLHDMQEGKCLYSLEAIPLEDLLNNPFNYEVDHIIP"
    "RSVSFDNSFNNKVLVKQEENSKKGNRTPFQYLSSSDSKISYETFKKHILNLAKGKGRISKTKKEYLLEERDINRFSVQKD"
    "FINRNLVDTRYATRGLMNLLRSYFRVNNLDVKVKSINGGFTSFLRRKWKFKKERNKGYKHHAEDALIIANADFIFKEWKK"
    "LDKAKKVMENQMFEEKQAESMPEIETEQEYKEIFITPHQIKHIKDFKDYKYSHRVDKKPNRELINDTLYSTRKDDKGNTL"
    "IVNNLNGLYDKDNDKLKKLINKSPEKLLMYHHDPQTYQKLKLIMEQYGDEKNPLYKYYEETGNYLTKYSKKDNGPVIKKI"
    "KYYGNKLNAHLDITDDYPNSRNKVVKLSLKPYRFDVYLDNGVYKFVTVKNLDVIKKENYYEVNSKCYEEAKKLKKISNQA"
    "EFIASFYNNDLIKINGELYRVIGVNNDLLNRIEVNMIDITYREYLENMNDKRPPRIIKTIASKTQSIKKYSTDILGNLYE"
    "VKSKKHPQIIKKG"
)
_FNCAS12A = (  # A0Q7Q2, F. novicida, 1300 aa — Type V-A effector (Cpf1, truncated by model)
    "MSIYQEFVNKYSLSKTLRFELIPQGKTLENIKARGLILDDEKRAKDYKKAKQIIDKYHQFFIEEILSSVCISEDLLQNYS"
    "DVYFKLKKSDDDNLQKDFKSAKDTIKKQISEYIKDSEKFKNLFNQNLIDAKKGQESDLILWLKQSKDNGIELFKANSDIT"
    "DIDEALEIIKSFKGWTTYFKGFHENRKNVYSSNDIPTSIIYRIVDDNLPKFLENKAKYESLKDKAPEAINYEQIKKDLAE"
    "ELTFDIDYKTSEVNQRVFSLDEVFEIANFNNYLNQSGITKFNTIIGGKFVNGENTKRKGINEYINLYSQQINDKTLKKYK"
    "MSVLFKQILSDTESKSFVIDKLEDDSDVVTTMQSFYEQIAAFKTVEEKSIKETLSLLFDDLKAQKLDLSKIYFKNDKSLT"
    "DLSQQVFDDYSVIGTAVLEYITQQIAPKNLDNPSKKEQELIAKKTEKAKYLSLETIKLALEEFNKHRDIDKQCRFEEILA"
    "NFAAIPMIFDEIAQNKDNLAQISIKYQNQGKKDLLQASAEDDVKAIKDLLDQTNNLLHKLKIFHISQSEDKANILDKDEH"
    "FYLVFEECYFELANIVPLYNKIRNYITQKPYSDEKFKLNFENSTLANGWDKNKEPDNTAILFIKDDKYYLGVMNKKNNKI"
    "FDDKAIKENKGEGYKKIVYKLLPGANKMLPKVFFSAKSIKFYNPSEDILRIRNHSTHTKNGSPQKGYEKFEFNIEDCRKF"
    "IDFYKQSISKHPEWKDFGFRFSDTQRYNSIDEFYREVENQGYKLTFENISESYIDSVVNQGKLYLFQIYNKDFSAYSKGR"
    "PNLHTLYWKALFDERNLQDVVYKLNGEAELFYRKQSIPKKITHPAKEAIANKNKDNPKKESVFEYDLIKDKRFTEDKFFF"
    "HCPITINFKSSGANKFNDEINLLLKEKANDVHILSIDRGERHLAYYTLVDGKGNIIKQDTFNIIGNDRMKTNYHDKLAAI"
    "EKDRDSARKDWKKINNIKEMKEGYLSQVVHEIAKLVIEYNAIVVFEDLNFGFKRGRFKVEKQVYQKLEKMLIEKLNYLVF"
    "KDNEFDKTGGVLRAYQLTAPFETFKKMGKQTGIIYYVPAGFTSKICPVTGFVNQLYPKYESVSKSQEFFSKFDKICYNLD"
    "KGYFEFSFDYKNFGDKAAKGKWTIASFGSRLINFRNSDKNHNWDTREVYPTKELEKLLKDYSIEYGHGECIKAAICGESD"
    "KKFFAKLTSVLNTILQMRNSKTGTELDYLISPVADVNGNFFDSRQAPKNMPQDADANGAYHIGLKGLMLLGRIKNNQEGK"
    "KLNLVIKNEEYFEFVQNRNN"
)
_LSHCAS13A = (  # P0DOC6, L. shahii, 1389 aa — Type VI-A effector (RNA-targeting, truncated)
    "MGNLFGHKRWYEVRDKKDFKIKRKVKVKRNYDGNKYILNINENNNKEKIDNNKFIRKYINYKKNDNILKEFTRKFHAGNI"
    "LFKLKGKEGIIRIENNDDFLETEEVVLYIEAYGKSEKLKALGITKKKIIDEAIRQGITKDDKKIEIKRQENEEEIEIDIR"
    "DEYTNKTLNDCSIILRIIENDELETKKSIYEIFKNINMSLYKIIEKIIENETEKVFENRYYEEHLREKLLKDDKIDVILT"
    "NFMEIREKIKSNLEILGFVKFYLNVGGDKKKSKNKKMLVEKILNINVDLTVEDIADFVIKELEFWNITKRIEKVKKVNNE"
    "FLEKRRNRTYIKSYVLLDKHEKFKIERENKKDKIVKFFVENIKNNSIKEKIEKILAEFKIDELIKKLEKELKKGNCDTEI"
    "FGIFKKHYKVNFDSKKFSKKSDEEKELYKIIYRYLKGRIEKILVNEQKVRLKKMEKIEIEKILNESILSEKILKRVKQYT"
    "LEHIMYLGKLRHNDIDMTTVNTDDFSRLHAKEELDLELITFFASTNMELNKIFSRENINNDENIDFFGGDREKNYVLDKK"
    "ILNSKIKIIRDLDFIDNKNNITNNFIRKFTKIGTNERNRILHAISKERDLQGTQDDYNKVINIIQNLKISDEEVSKALNL"
    "DVVFKDKKNIITKINDIKISEENNNDIKYLPSFSKVLPEILNLYRNNPKNEPFDTIETEKIVLNALIYVNKELYKKLILE"
    "DDLEENESKNIFLQELKKTLGNIDEIDENIIENYYKNAQISASKGNNKAIKKYQKKVIECYIGYLRKNYEELFDFSDFKM"
    "NIQEIKKQIKDINDNKTYERITVKTSDKTIVINDDFEYIISIFALLNSNAVINKIRNRFFATSVWLNTSEYQNIIDILDE"
    "IMQLNTLRNECITENWNLNLEEFIQKMKEIEKDFDDFKIQTKKEIFNNYYEDIKNNILTEFKDDINGCDVLEKKLEKIVI"
    "FDDETKFEIDKKSNILQDEQRKLSNINKKDLKKKVDQYIKDKDQEIKSKILCRIIFNSDFLKKYKKEIDNLIEDMESENE"
    "NKFQEIYYPKERKNELYIYKKNLFLNIGNPNFDKIYGLISNDIKMADAKFLFNIDGKNIRKNKISEIDAILKNLNDKLNG"
    "YSKEYKEKYIKKLKENDDFFAKNIQNKNYKSFEKDYNRVSEYKKIRDLVEFNYLNKIESYLIDINWKLAIQMARFERDMH"
    "YIVNGLRELGIIKLSGYNTGISRAYPKRNGSDGFYTTTAYYKFFDEESYKKFEKICYGFGIDLSENSEINKPENESIRNY"
    "ISHFYIVRNPFADYSIAEQIDRVSNLLSYSTRYNNSTYASVFEVFKKDVNLDYDELKKKFKLIGNNDILERLMKPKKVSV"
    "LELESYNSDYIKNLIIELLTKIENTNDTL"
)
_ACRIIA4 = (  # A0A0E0UT28, Listeria monocytogenes phage, 87 aa — anti-Cas9 (inhibits SpCas9)
    "MNISELIREIKNKDYAVRLEGTDDNSITKLIIDVDNDGNEYVISESKNESIAEKFASTFKNGWNKEYEDEEEFYNDMQSI"
    "ILKSELN"
)
# Default sequence shown in the Compare tab and as Protein A in the Distance tab
# (SpCas9, 1368 aa, so the embedders see only its N-terminal window).
EXAMPLE_PROTEIN = _SPCAS9
def _save_embedding(embedding, prefix):
    """Save *embedding* to a new, uniquely named ``.npy`` file and return its path.

    A fixed filename in the shared temp dir would be reused by every request, so
    two users embedding at the same time could overwrite (and then download) each
    other's results. ``mkstemp`` creates a fresh file per call.
    NOTE(review): files are not deleted here; rely on temp-dir / Gradio cache
    cleanup, or add explicit cleanup if disk usage becomes a concern.
    """
    fd, path = tempfile.mkstemp(prefix=prefix, suffix=".npy")
    with os.fdopen(fd, "wb") as fh:
        np.save(fh, embedding)
    return path


def process(sequence: str, top_k: int = 10, twin_aspect: str = "BP"):
    """Embed a protein with ESM2 and Twin, compare the embeddings, and search FAISS.

    Args:
        sequence: Raw or FASTA-formatted protein sequence.
        top_k: Number of UniRef50 nearest neighbours to retrieve (ESM2 index).
        twin_aspect: GO aspect of the Twin checkpoint to use ("BP", "CC" or "MF").

    Returns:
        An 8-tuple ``(summary_md, esm2_path, twin_path, esm2_heatmap,
        twin_heatmap, comparison_plot, distribution_plot, hits_df)``. On invalid
        input only the summary (an error message) and an empty hits table are set.
    """
    sequence = strip_fasta_header(sequence.strip())
    empty_df = pd.DataFrame(columns=["rank", "uniref50_id", "cosine", "uniprot"])
    valid, error = validate_protein(sequence)
    if not valid:
        return f"**Error**: {error}", None, None, None, None, None, None, empty_df

    esm2_emb = embed_esm2(sequence)
    twin_emb = embed_twin(sequence, aspect=twin_aspect)
    esm2_stats = compute_stats(esm2_emb)
    twin_stats = compute_stats(twin_emb)

    # Per-request files so concurrent sessions never share a download path.
    esm2_path = _save_embedding(esm2_emb, "esm2_embedding_")
    twin_path = _save_embedding(twin_emb, "twin_embedding_")

    # Nearest neighbours in UniRef50 (GO-annotated subset). The FAISS index may
    # not be uploaded yet; in that case surface the error in the hits table but
    # keep all embeddings/plots.
    try:
        hits_df = search_faiss(esm2_emb, k=int(top_k))
    except Exception as e:
        print(f"FAISS search failed: {e}")
        # str(e) can be empty (e.g. a bare AssertionError); fall back to the type name
        # instead of raising IndexError from inside the error handler.
        message_lines = str(e).splitlines()
        first_line = message_lines[0] if message_lines else type(e).__name__
        hits_df = pd.DataFrame([{
            "rank": 0,
            "uniref50_id": "(FAISS index not available)",
            "cosine": 0.0,
            "uniprot": first_line[:200],
        }])

    trunc_note = (f"\n> ⚠️ Sequence truncated from {len(sequence)} to {ESM2_MAX_LEN} aa "
                  f"(ESM2 position-embedding limit). Twin also truncates internally.\n"
                  if len(sequence) > ESM2_MAX_LEN else "")
    # The blank line before "Sequence:" ends the Markdown table; without it GFM
    # renders that line as an extra table row.
    summary = (
        "### Results\n"
        f"{trunc_note}\n"
        f"| | ESM2 | Twin ({twin_aspect}) |\n"
        "|---|---|---|\n"
        f"| Dimension | {esm2_stats['dim']} | {twin_stats['dim']} |\n"
        f"| L2 Norm | {esm2_stats['l2_norm']:.2f} | {twin_stats['l2_norm']:.2f} |\n"
        f"| Entropy | {esm2_stats['entropy']:.2f} | {twin_stats['entropy']:.2f} |\n"
        f"| Sparsity | {esm2_stats['sparsity']:.1%} | {twin_stats['sparsity']:.1%} |\n"
        "\n"
        f"Sequence: {len(sequence)} aa\n"
    )

    esm2_heatmap = create_embedding_heatmap(esm2_emb, "ESM2 Embedding")
    twin_heatmap = create_embedding_heatmap(twin_emb, f"Twin Embedding ({twin_aspect})")
    comparison_plot = create_comparison_plot(esm2_stats, twin_stats)
    distribution_plot = create_distribution_plot(esm2_emb, twin_emb)
    return summary, esm2_path, twin_path, esm2_heatmap, twin_heatmap, comparison_plot, distribution_plot, hits_df
| # Build interface | |
# When the URL contains a hash like #benchmark or ?tab=benchmark, auto-click
# the matching tab button after the app loads. Allows deep-linking to any tab.
# Passed to gr.Blocks(js=...), so it runs once on page load.
# - Matching: case-insensitive *prefix* of the tab button's trimmed text, so
#   "#characterize" selects "Characterize Unknown Proteins". The hash is not
#   URL-decoded, so use a prefix without spaces.
# - Precedence: the #hash wins; ?tab= is used only when there is no hash.
# - Tabs may not be rendered yet when this runs, so it retries every 200 ms and
#   gives up after 21 attempts (~4 s).
_TAB_HASH_JS = """
function() {
function pickTab(name) {
if (!name) return;
const want = name.toLowerCase();
const tryOnce = () => {
const btns = document.querySelectorAll('button[role="tab"]');
for (const b of btns) {
if ((b.textContent || "").trim().toLowerCase().startsWith(want)) {
b.click();
return true;
}
}
return false;
};
let attempts = 0;
const t = setInterval(() => {
if (tryOnce() || ++attempts > 20) clearInterval(t);
}, 200);
}
const hash = (window.location.hash || "").replace(/^#/, "");
const params = new URLSearchParams(window.location.search || "");
pickTab(hash || params.get("tab"));
}
"""
| with gr.Blocks( | |
| title="Functional Distance", | |
| css=".gradio-container { max-width: 100% !important; }", | |
| js=_TAB_HASH_JS, | |
| ) as demo: | |
| gr.Markdown( | |
| "# functional-distance\n" | |
| "Functional distance prediction for proteins — pairwise distance (Twin, " | |
| "GO-contrastive fine-tune of ESM2) and nearest-neighbor lookup (ESM2 or " | |
| "Twin-BP against UniRef50, plus CRISPR-focused lookup against a curated " | |
| "Cas/Acr reference set)." | |
| ) | |
| with gr.Tab("Distance"): | |
| gr.Markdown( | |
| "### Twin pairwise distance\n" | |
| "Compute the Twin model's **native trained distance** between two protein " | |
| "sequences under the selected GO aspect. " | |
| "**L2 distance** is on L2-normalized 1024-dim embeddings (the training convention, " | |
| "∈ [0, 2]); equivalent to `sqrt(2 - 2·cos_sim)`. Interpret relative to other pairs, " | |
| "not as an absolute threshold." | |
| ) | |
| with gr.Row(): | |
| with gr.Column(): | |
| dist_seq_a = gr.Textbox( | |
| label="Protein A", | |
| placeholder="Paste protein sequence (amino acids)...", | |
| lines=5, | |
| value=EXAMPLE_PROTEIN, | |
| ) | |
| dist_seq_b = gr.Textbox( | |
| label="Protein B", | |
| placeholder="Paste protein sequence (amino acids)...", | |
| lines=5, | |
| value=_SACAS9, | |
| info="Default: SaCas9 (S. aureus). Pair with the SpCas9 default → ortholog test." | |
| ) | |
| dist_aspect_radio = gr.Radio( | |
| choices=["BP", "CC", "MF"], | |
| value=TWIN_DEFAULT_ASPECT, | |
| label="Twin GO aspect", | |
| ) | |
| dist_btn = gr.Button("compute distance", variant="primary") | |
| with gr.Column(): | |
| dist_output = gr.HTML() | |
| def _distance_bar(value, v_min, v_max, labels, gradient_stops, fmt="{:+.4f}"): | |
| """Horizontal bar with a position marker, labels below.""" | |
| pct = max(0.0, min(100.0, (value - v_min) / (v_max - v_min) * 100.0)) | |
| return f""" | |
| <div style='margin:18px 0;'> | |
| <div style='display:flex;justify-content:space-between;margin-bottom:6px;font-size:14px;'> | |
| <span style='font-weight:600;'>{labels['title']}</span> | |
| <span style='font-family:ui-monospace,monospace;font-weight:600;'>{fmt.format(value)}</span> | |
| </div> | |
| <div style='position:relative;height:22px;background:linear-gradient(to right,{gradient_stops});border-radius:4px;'> | |
| <div style='position:absolute;left:{pct:.2f}%;top:-5px;width:3px;height:32px;background:#111;transform:translateX(-50%);box-shadow:0 0 0 2px #fff;'></div> | |
| </div> | |
| <div style='display:flex;justify-content:space-between;font-size:11px;color:#888;margin-top:4px;'> | |
| <span>{labels['left']}</span> | |
| <span>{labels['mid']}</span> | |
| <span>{labels['right']}</span> | |
| </div> | |
| </div>""" | |
def _distance_handler(seq_a, seq_b, aspect):
    """Gradio handler: Twin distance between two sequences, rendered as HTML.

    Args:
        seq_a, seq_b: Raw or FASTA-formatted protein sequences.
        aspect: Twin GO aspect ("BP", "CC" or "MF").

    Returns:
        An HTML string: an error box if either sequence fails validation,
        otherwise L2 and cosine distance bars plus a truncation warning when a
        sequence exceeds ESM2_MAX_LEN.
    """
    seq_a = strip_fasta_header(seq_a.strip())
    seq_b = strip_fasta_header(seq_b.strip())
    named = (("A", seq_a), ("B", seq_b))
    for name, seq in named:
        valid, err = validate_protein(seq)
        if not valid:
            # The validator message may echo user input; escape it before it
            # lands in the HTML output.
            return ("<div style='color:#dc2626;font-weight:600;'>"
                    f"Error in sequence {name}: {html.escape(str(err))}</div>")

    trunc_notes = [f"Protein {name} truncated from {len(seq)} → {ESM2_MAX_LEN} aa"
                   for name, seq in named if len(seq) > ESM2_MAX_LEN]
    trunc_html = (f"<div style='background:#fff7ed;border-left:3px solid #f97316;"
                  f"padding:8px 12px;margin:8px 0;font-size:12px;color:#9a3412;'>"
                  f"⚠️ {'; '.join(trunc_notes)} (ESM2 position-embedding limit).</div>"
                  if trunc_notes else "")

    d = compute_distance(seq_a, seq_b, aspect)
    # Green = similar, red = dissimilar. Both metrics span [0, 2].
    l2_bar = _distance_bar(
        d["l2"], 0.0, 2.0,
        labels={"title": "L2 distance (L2-normalized)",
                "left": "0 · identical", "mid": "√2 ≈ 1.41 · orthogonal", "right": "2 · opposite"},
        gradient_stops="#4ade80 0%,#facc15 50%,#f87171 100%",
        fmt="{:.4f}",
    )
    cos_bar = _distance_bar(
        d["cos_dist"], 0.0, 2.0,
        labels={"title": "cosine distance (1 − cos_sim)",
                "left": "0 · identical", "mid": "1 · orthogonal", "right": "2 · opposite"},
        gradient_stops="#4ade80 0%,#facc15 50%,#f87171 100%",
        fmt="{:.4f}",
    )
    return (
        f"<h3 style='margin-top:0;'>Twin/{aspect} distance</h3>"
        f"{trunc_html}"
        f"{l2_bar}{cos_bar}"
        f"<p style='font-size:12px;color:#666;margin-top:16px;'>"
        f"cosine similarity = {d['cos_sim']:+.4f} · "
        f"sequences: A = {len(seq_a)} aa, B = {len(seq_b)} aa</p>"
    )
# Two-step click: first swap in a lightweight "computing…" placeholder (the lambda
# ignores its aspect input), then run the real handler. Only the second step is
# exposed on the API, under api_name="distance".
dist_btn.click(
    lambda a: "<div style='padding:16px;color:#666;'>⏳ Computing Twin distance…"
    "<br><span style='font-size:12px;'>First run of an aspect: ~15 s to load the checkpoint. Subsequent calls: <1 s.</span></div>",
    inputs=[dist_aspect_radio],
    outputs=[dist_output],
    show_progress="hidden",
).then(
    _distance_handler,
    inputs=[dist_seq_a, dist_seq_b, dist_aspect_radio],
    outputs=[dist_output],
    api_name="distance",
    show_progress="minimal",
)
# CRISPR-themed example pairs. Click to populate the two sequence boxes and the
# aspect radio. Sequences are module-level constants (_SPCAS9, _CAS1_ECOLI, ...)
# because they are also used as textbox defaults elsewhere in the UI (e.g. the
# Protein B default is _SACAS9). example_labels align positionally with the rows.
gr.Examples(
    examples=[
        # [seq_a, seq_b, aspect]
        [_SPCAS9, _SPCAS9, "BP"],  # sanity: identical -> L2 ~ 0
        [_SPCAS9, _SACAS9, "BP"],  # orthologs (both Type II-A Cas9)
        [_CAS1_ECOLI, _CAS2_ECOLI, "BP"],  # adaptation complex partners
        [_SPCAS9, _FNCAS12A, "BP"],  # Type II vs Type V (both DNA-targeting)
        [_CAS1_ECOLI, _CAS3_ECOLI, "BP"],  # Type I adaptation vs interference
        [_SPCAS9, _LSHCAS13A, "BP"],  # different substrate: DNA vs RNA
        [_ACRIIA4, _SPCAS9, "BP"],  # anti-CRISPR vs its target
    ],
    inputs=[dist_seq_a, dist_seq_b, dist_aspect_radio],
    example_labels=[
        "Identical (sanity) — SpCas9 / SpCas9",
        "Orthologs — SpCas9 / SaCas9 (both Type II-A)",
        "Adaptation partners — Cas1 / Cas2 (E. coli)",
        "Different CRISPR types — SpCas9 (II) / FnCas12a (V) — DNA-targeting",
        "Same pathway, different role — Cas1 (adaptation) / Cas3 (interference), Type I",
        "Different substrate — SpCas9 (DNA) / LshCas13a (RNA)",
        "Anti-CRISPR vs target — AcrIIA4 / SpCas9",
    ],
    label="Example pairs (CRISPR / anti-CRISPR) — click to load",
)
| with gr.Tab("Characterize Unknown Proteins"): | |
| gr.Markdown( | |
| "### Characterize unknown proteins\n" | |
| "Paste a protein sequence, choose an embedding, then compare it either to the " | |
| "large UniRef50 FAISS background or to the curated CRISPR protein reference. " | |
| "For the CRISPR reference this page now also reports a heuristic family verdict " | |
| "and places the query into the same Cas/Acr reference projection used by the benchmark.\n" | |
| "- **UniRef50** — broad nearest-neighbour lookup across the GO-annotated UniRef50 subset " | |
| "using either ESM2 or the Twin-BP index.\n" | |
| "- **CRISPR curated** — mixed **254-protein** reference: 175 Pfam-verified Cas proteins " | |
| "plus 79 Anti-CRISPRdb v3/local Acr proteins." | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1, min_width=320): | |
| lookup_seq_input = gr.Textbox( | |
| label="Protein Sequence", | |
| placeholder="Paste protein sequence (amino acids)...", | |
| lines=6, | |
| value=EXAMPLE_PROTEIN, | |
| info="FASTA or raw sequence; > 1022 aa is truncated" | |
| ) | |
| lookup_method_radio = gr.Radio( | |
| choices=["ESM2 baseline", "genomenet-twin (MF)", "genomenet-twin (BP)"], | |
| value="ESM2 baseline", | |
| label="Method", | |
| ) | |
| lookup_index_radio = gr.Radio( | |
| choices=["UniRef50", "CRISPR curated"], | |
| value="UniRef50", | |
| label="Reference set", | |
| ) | |
| lookup_btn = gr.Button("search", variant="primary") | |
| with gr.Column(scale=2, min_width=400): | |
| lookup_info = gr.Markdown() | |
| lookup_verdict = gr.Markdown() | |
| lookup_plot = gr.Plot(label="CRISPR reference placement") | |
| lookup_hits = gr.HTML() | |
| _lookup_empty_html = "" | |
| def _hits_bar_cell(cos, width=180): | |
| """Red -> yellow -> green gradient; black marker at the cosine position (0..1).""" | |
| pct = max(0.0, min(100.0, float(cos) * 100.0)) | |
| return ( | |
| f"<div style='display:flex;align-items:center;gap:10px;'>" | |
| f"<div style='position:relative;width:{width}px;height:12px;" | |
| f"background:linear-gradient(to right,#f87171 0%,#facc15 50%,#4ade80 100%);" | |
| f"border-radius:3px;'>" | |
| f"<div style='position:absolute;left:{pct:.2f}%;top:-3px;width:2px;height:18px;" | |
| f"background:#111;box-shadow:0 0 0 2px #fff;transform:translateX(-50%);'></div>" | |
| f"</div>" | |
| f"<span style='font-family:ui-monospace,monospace;font-size:12px;min-width:52px;'>{cos:.4f}</span>" | |
| f"</div>" | |
| ) | |
| def _bar_scale_normalized(cos, lo, hi): | |
| """Position cos on a scale with explicit lo/hi endpoints (for the per-row marker).""" | |
| if hi <= lo: | |
| return 50.0 | |
| return max(0.0, min(100.0, (float(cos) - lo) / (hi - lo) * 100.0)) | |
| def _hits_table_section(df, title, lo, hi): | |
| """Render a hits DataFrame as an HTML table. Bars use a shared [lo, hi] scale.""" | |
| if df.empty: | |
| return f"<h4>{title}</h4><p style='color:#888;'>(empty)</p>" | |
| rows = [] | |
| for _, r in df.iterrows(): | |
| pct = _bar_scale_normalized(r["cosine"], lo, hi) | |
| bar = ( | |
| f"<div style='display:flex;align-items:center;gap:10px;'>" | |
| f"<div style='position:relative;width:180px;height:12px;" | |
| f"background:linear-gradient(to right,#f87171 0%,#facc15 50%,#4ade80 100%);border-radius:3px;'>" | |
| f"<div style='position:absolute;left:{pct:.2f}%;top:-3px;width:2px;height:18px;" | |
| f"background:#111;box-shadow:0 0 0 2px #fff;transform:translateX(-50%);'></div>" | |
| f"</div>" | |
| f"<span style='font-family:ui-monospace,monospace;font-size:12px;min-width:60px;'>{r['cosine']:+.4f}</span>" | |
| f"</div>" | |
| ) | |
| link = (f"<a href='{r['link']}' target='_blank' rel='noopener noreferrer' " | |
| f"style='color:#2563eb;text-decoration:none;'>↗</a>" | |
| if str(r.get("link", "")).startswith("http") else "") | |
| desc = str(r.get("description", "") or "") | |
| if len(desc) > 140: | |
| desc = desc[:140] + "…" | |
| rows.append( | |
| f"<tr>" | |
| f"<td style='padding:5px 10px;color:#555;'>{r['rank']}</td>" | |
| f"<td style='padding:5px 10px;font-family:ui-monospace,monospace;white-space:nowrap;'>{r['id']}</td>" | |
| f"<td style='padding:5px 10px;'>{bar}</td>" | |
| f"<td style='padding:5px 10px;font-size:13px;color:#374151;'>{desc}</td>" | |
| f"<td style='padding:5px 10px;'>{link}</td>" | |
| f"</tr>" | |
| ) | |
| return ( | |
| f"<h4 style='margin:18px 0 6px 0;'>{title}</h4>" | |
| f"<table style='width:100%;border-collapse:collapse;font-size:13px;'>" | |
| f"<thead><tr style='text-align:left;border-bottom:1px solid #e5e7eb;color:#6b7280;font-size:12px;'>" | |
| f"<th style='padding:5px 10px;'>#</th><th style='padding:5px 10px;'>ID</th>" | |
| f"<th style='padding:5px 10px;'>similarity</th><th style='padding:5px 10px;'>description</th>" | |
| f"<th style='padding:5px 10px;'></th></tr></thead>" | |
| f"<tbody>{''.join(rows)}</tbody></table>" | |
| ) | |
| def _distribution_svg(top_display, bot_display, top_context, bot_context, | |
| width=760, height=150, n_bins=64): | |
| """Stacked histogram of every retrieved top/bottom FAISS hit. | |
| Metadata is fetched only for the displayed table rows, but the | |
| histogram includes the full retrieved extreme pool. | |
| """ | |
| all_v = list(top_display) + list(bot_display) + list(top_context) + list(bot_context) | |
| if not all_v: | |
| return "" | |
| lo, hi = min(all_v), max(all_v) | |
| pad = (hi - lo) * 0.05 if hi > lo else 0.01 | |
| x_lo, x_hi = lo - pad, hi + pad | |
| def _bin(v): | |
| return min(n_bins - 1, max(0, int((v - x_lo) / (x_hi - x_lo) * n_bins))) | |
| top_counts = [0] * n_bins | |
| top_context_counts = [0] * n_bins | |
| bot_counts = [0] * n_bins | |
| bot_context_counts = [0] * n_bins | |
| for v in top_display: | |
| top_counts[_bin(v)] += 1 | |
| for v in top_context: | |
| top_context_counts[_bin(v)] += 1 | |
| for v in bot_display: | |
| bot_counts[_bin(v)] += 1 | |
| for v in bot_context: | |
| bot_context_counts[_bin(v)] += 1 | |
| max_c = max( | |
| t + tc + b + bc | |
| for t, tc, b, bc in zip(top_counts, top_context_counts, bot_counts, bot_context_counts) | |
| ) or 1 | |
| bar_w = width / n_bins | |
| usable_h = height - 42 | |
| bars = [] | |
| for i in range(n_bins): | |
| y = height - 28 | |
| segments = [ | |
| (bot_context_counts[i] / max_c * usable_h, "#fecaca"), | |
| (bot_counts[i] / max_c * usable_h, "#ef4444"), | |
| (top_context_counts[i] / max_c * usable_h, "#bbf7d0"), | |
| (top_counts[i] / max_c * usable_h, "#16a34a"), | |
| ] | |
| for h, colour in segments: | |
| if h > 0: | |
| bars.append( | |
| f"<rect x='{i*bar_w:.2f}' y='{y-h:.2f}' width='{max(bar_w-1, 1):.2f}' " | |
| f"height='{h:.2f}' fill='{colour}' />" | |
| ) | |
| y -= h | |
| total_n = len(all_v) | |
| axis_y = height - 27 | |
| labels = ( | |
| f"<line x1='0' y1='{axis_y}' x2='{width}' y2='{axis_y}' stroke='#9ca3af' stroke-width='1' />" | |
| f"<text x='0' y='{height-9}' fill='#6b7280' font-size='10' font-family='sans-serif'>{x_lo:+.2f}</text>" | |
| f"<text x='{width/2}' y='{height-9}' text-anchor='middle' fill='#6b7280' font-size='10' font-family='sans-serif'>cosine similarity across retrieved extreme hits (n={total_n})</text>" | |
| f"<text x='{width}' y='{height-9}' text-anchor='end' fill='#6b7280' font-size='10' font-family='sans-serif'>{x_hi:+.2f}</text>" | |
| ) | |
| legend = ( | |
| f"<div style='display:flex;align-items:center;gap:14px;flex-wrap:wrap;font-size:11px;color:#6b7280;margin-top:4px;'>" | |
| f"<span style='display:inline-flex;align-items:center;gap:5px;'>" | |
| f"<span style='width:10px;height:10px;background:#16a34a;display:inline-block;border-radius:2px;'></span>" | |
| f"top {len(top_display)} with metadata/table</span>" | |
| f"<span style='display:inline-flex;align-items:center;gap:5px;'>" | |
| f"<span style='width:10px;height:10px;background:#bbf7d0;display:inline-block;border-radius:2px;'></span>" | |
| f"next {len(top_context)} top hits</span>" | |
| f"<span style='display:inline-flex;align-items:center;gap:5px;'>" | |
| f"<span style='width:10px;height:10px;background:#ef4444;display:inline-block;border-radius:2px;'></span>" | |
| f"bottom {len(bot_display)} with metadata/table</span>" | |
| f"<span style='display:inline-flex;align-items:center;gap:5px;'>" | |
| f"<span style='width:10px;height:10px;background:#fecaca;display:inline-block;border-radius:2px;'></span>" | |
| f"next {len(bot_context)} anti-correlated hits</span>" | |
| f"</div>" | |
| f"<div style='font-size:11px;color:#6b7280;margin-top:3px;'>" | |
| f"All retrieved hits are plotted here; metadata is fetched only for the table rows. " | |
| f"This is an extreme-neighbour view, not a full scan of every UniRef50 cluster.</div>" | |
| ) | |
| return ( | |
| f"<div style='margin:8px 0 14px 0;'>" | |
| f"<svg width='100%' viewBox='0 0 {width} {height}' preserveAspectRatio='none' " | |
| f"style='display:block;background:#fafafa;border:1px solid #e5e7eb;border-radius:4px;'>" | |
| f"{''.join(bars)}{labels}</svg>" | |
| f"{legend}</div>" | |
| ) | |
# Rows per results table (top and bottom). UniRef50 metadata is fetched only for
# these rows; the CRISPR view uses min(DISPLAY_K, n_ref // 2) so its tables never overlap.
DISPLAY_K = 25
POOL_K = 500  # larger FAISS fetch (no metadata) for the distribution histogram
| def _lookup_method_spec(method_choice): | |
| if method_choice.startswith("genomenet-twin"): | |
| aspect = "MF" if "(MF)" in method_choice else "BP" | |
| return { | |
| "is_twin": True, | |
| "aspect": aspect, | |
| "method_key": f"Twin-{aspect}", | |
| "method_label": f"genomenet-twin ({aspect})", | |
| "threshold": 0.70, | |
| "dim_info": f"1024-dim Twin-{aspect}", | |
| } | |
| return { | |
| "is_twin": False, | |
| "aspect": None, | |
| "method_key": "ESM2 baseline", | |
| "method_label": "ESM2 baseline", | |
| "threshold": 0.90, | |
| "dim_info": "1280-dim ESM2", | |
| } | |
| def _crispr_verdict_card(label, conf, top_cos, threshold, method_label): | |
| conf_colour = {"high": "#059669", "medium": "#d97706", "low": "#dc2626"}.get(conf, "#6b7280") | |
| return ( | |
| f"### CRISPR characterization ({method_label})\n" | |
| f"<div style='border-left:4px solid {conf_colour};padding:10px 14px;" | |
| f"background:#fafafa;border-radius:4px;margin:6px 0;'>" | |
| f"<div style='font-size:15px;'><b>{label}</b></div>" | |
| f"<div style='font-size:12px;color:#6b7280;margin-top:4px;'>" | |
| f"confidence: <b style='color:{conf_colour}'>{conf.upper()}</b> · " | |
| f"top-1 cosine = {top_cos:+.3f} · heuristic threshold = {threshold:.2f}" | |
| f"</div></div>" | |
| ) | |
def _make_crispr_query_plot(method_key, method_label, ref_meta, sims, top):
    """Plot the query's approximate position in the 2-D CRISPR reference projection.

    Reference proteins are drawn at the coordinates returned by
    ``get_crispr_umap(method_key)``, coloured by ``family`` and shaped by ``group``
    ("cas" = circle, "acr" = triangle, anything else = circle). The query has no
    projected coordinates of its own. It is placed at a softmax-weighted average
    (scale 8) of its five most similar references, with dashed lines to the top three.

    Args:
        method_key: Key passed to ``get_crispr_umap``.
        method_label: Human-readable method name used in the title.
        ref_meta: Reference metadata with ``family`` and ``group`` columns. Rows are
            matched to projection rows by position (as callers use ``.iloc``).
        sims: 1-D array of query-vs-reference similarities.
        top: Reference row positions sorted by decreasing similarity.

    Returns:
        A matplotlib ``Figure``, or ``None`` when no projection is available.
    """
    # A bare Figure (not plt.subplots) is not registered with pyplot, so figures
    # are not accumulated in global state across requests and no pyplot state is
    # shared between concurrent handler threads.
    from matplotlib.figure import Figure
    from matplotlib.lines import Line2D

    xy_ref = get_crispr_umap(method_key)
    if xy_ref is None:
        return None

    nearest = top[:5]
    weights = np.exp((sims[nearest] - sims[nearest].max()) * 8.0)
    weights /= weights.sum()
    q_xy = (xy_ref[nearest] * weights[:, None]).sum(axis=0)

    families = ref_meta["family"].tolist()
    all_fams = list(dict.fromkeys(families))
    cmap = plt.get_cmap("tab20", len(all_fams))
    fam_colour = {f: cmap(i) for i, f in enumerate(all_fams)}
    group_marker = {"cas": "o", "acr": "^"}
    markers = [group_marker.get(str(g), "o") for g in ref_meta["group"].tolist()]

    fig = Figure(figsize=(9, 6.5))
    ax = fig.subplots()

    # One scatter call per (family, marker) pair instead of one per protein.
    # Positions (not DataFrame index labels) select rows of xy_ref.
    buckets = {}
    for pos, key in enumerate(zip(families, markers)):
        buckets.setdefault(key, []).append(pos)
    for (fam, marker), positions in buckets.items():
        ax.scatter(
            xy_ref[positions, 0], xy_ref[positions, 1],
            color=fam_colour[fam],
            marker=marker,
            s=48, edgecolor="black", linewidth=0.35, alpha=0.85, zorder=2,
        )

    for i in top[:3]:
        ax.plot([q_xy[0], xy_ref[i, 0]], [q_xy[1], xy_ref[i, 1]],
                color="#444", lw=0.6, linestyle="--", alpha=0.6, zorder=3)
    ax.scatter(q_xy[0], q_xy[1], marker="*", s=320, color="#111",
               edgecolor="gold", linewidth=2.0, zorder=10)

    handles = [
        Line2D([0], [0], marker="o", color="w", markerfacecolor=fam_colour[f],
               markeredgecolor="black", markersize=7, label=f)
        for f in all_fams
    ]
    handles.append(Line2D([0], [0], marker="*", color="w", markerfacecolor="#111",
                          markeredgecolor="gold", markersize=14, label="query"))
    ax.legend(handles=handles, loc="center left", bbox_to_anchor=(1.02, 0.5),
              fontsize=7.2, frameon=False)
    ax.set_xlabel("projection-1")
    ax.set_ylabel("projection-2")
    ax.set_title(f"Query placement in the mixed Cas/Acr reference ({method_label})")
    ax.grid(alpha=0.25, linestyle=":")
    fig.tight_layout()
    return fig
def _lookup_error_summary(exc):
    """Return the first line of ``str(exc)``, or the exception type name if empty.

    ``str(exc).splitlines()[0]`` raises IndexError for exceptions with an empty
    message (e.g. a bare ``AssertionError()``), which would crash the error path.
    """
    lines = str(exc).splitlines()
    return lines[0] if lines else type(exc).__name__


def _lookup_uniref50(sequence, spec, trunc_note):
    """UniRef50 FAISS lookup: top/bottom neighbour tables plus extreme-hit histogram.

    Returns the 4-tuple ``(info_md, "", None, hits_html)``. Exceptions from the
    embedding, FAISS or metadata helpers propagate to the caller.
    """
    if spec["is_twin"]:
        if spec["aspect"] != "BP":
            return (
                f"{trunc_note}**Twin-MF × UniRef50** is not packaged yet.\n\n"
                "The full UniRef50 FAISS index currently available for the Twin model "
                "uses the **BP** aspect. Use `genomenet-twin (BP)` for full UniRef50 "
                "lookup, or use Twin-MF against the curated CRISPR reference.",
                "",
                None,
                _lookup_empty_html,
            )
        query_emb = embed_twin(sequence, aspect="BP")
        faiss_index_name = "twin-bp"
    else:
        query_emb = embed_esm2(sequence)
        faiss_index_name = "esm2"
    faiss_cfg = FAISS_CONFIGS[faiss_index_name]

    # Large pools (no metadata) feed the histogram; only the first DISPLAY_K of
    # each pool get UniProt metadata for the tables.
    top_pool = search_faiss(query_emb, k=POOL_K, fetch_metadata=False, index_name=faiss_index_name)
    bot_pool = search_faiss(query_emb, k=POOL_K, negate=True, fetch_metadata=False, index_name=faiss_index_name)
    top_display_ids = top_pool["uniref50_id"].head(DISPLAY_K).tolist()
    bot_display_ids = bot_pool["uniref50_id"].head(DISPLAY_K).tolist()
    meta = fetch_uniref_metadata(top_display_ids + bot_display_ids)

    def _to_display(pool_df):
        # The display rows are exactly the head of each pool, in pool order.
        sub = pool_df.head(DISPLAY_K)
        ids = sub["uniref50_id"].tolist()
        return pd.DataFrame({
            "rank": range(1, len(ids) + 1),
            "id": ids,
            "cosine": sub["cosine"].tolist(),
            "description": [meta.get(i, "") for i in ids],
            "link": ["https://www.uniprot.org/uniref/" + i for i in ids],
        })

    top_df = _to_display(top_pool)
    bot_df = _to_display(bot_pool)

    top_display_cos = top_df["cosine"].tolist()
    top_context_cos = top_pool["cosine"].iloc[DISPLAY_K:].tolist()
    bot_display_cos = bot_df["cosine"].tolist()
    bot_context_cos = bot_pool["cosine"].iloc[DISPLAY_K:].tolist()

    # Bars in both tables share one [lo, hi] scale.
    all_cos = top_display_cos + bot_display_cos
    lo, hi = (min(all_cos), max(all_cos)) if all_cos else (0.0, 1.0)
    info = (
        f"{trunc_note}"
        f"**UniRef50 × {faiss_cfg['label']}** — top-{DISPLAY_K} most similar + bottom-{DISPLAY_K} "
        f"most anti-correlated clusters (cosine on L2-normalized {faiss_cfg['dim_info']} embeddings). "
        f"The histogram plots all {len(top_pool) + len(bot_pool)} retrieved extreme hits "
        f"({len(top_pool)} top-neighbour hits + {len(bot_pool)} anti-correlated hits); "
        f"metadata is fetched only for the table rows. \n"
        f"Top range: **{top_df['cosine'].max():+.3f}** – **{top_df['cosine'].min():+.3f}** · "
        f"Bottom range: **{bot_df['cosine'].max():+.3f}** – **{bot_df['cosine'].min():+.3f}**"
    )
    hits_html = (
        _distribution_svg(top_display_cos, bot_display_cos, top_context_cos, bot_context_cos)
        + _hits_table_section(top_df, f"🟢 Top-{DISPLAY_K} most similar", lo, hi)
        + _hits_table_section(bot_df, f"🔴 Bottom-{DISPLAY_K} most anti-correlated", lo, hi)
    )
    return info, "", None, hits_html


def _lookup_crispr(sequence, spec, trunc_note):
    """Score the query against the curated CRISPR reference.

    Returns the 4-tuple ``(info_md, verdict_md, figure_or_None, hits_html)``.
    Exceptions propagate to the caller.
    """
    if spec["is_twin"]:
        query_emb = embed_twin(sequence, aspect=spec["aspect"])
    else:
        query_emb = embed_esm2(sequence)
    method_label = spec["method_label"]
    method_key = spec["method_key"]
    emb_dim_info = spec["dim_info"]

    # One call yields both embeddings and metadata.
    ref_emb, ref_meta = get_crispr_reference(method_key)
    n_ref = ref_emb.shape[0]

    # Dot product with the unit-normalised query.
    # NOTE(review): these are true cosines only if ref_emb rows are already
    # L2-normalised — confirm in get_crispr_reference.
    q = np.asarray(query_emb, dtype=np.float32).reshape(-1)
    q = q / (np.linalg.norm(q) + 1e-9)
    sims = ref_emb @ q
    order = np.argsort(-sims)

    # At most DISPLAY_K rows per table, and the top/bottom tables never overlap.
    k_eff = min(DISPLAY_K, n_ref // 2)
    top_idx = order[:k_eff]
    # order[n_ref - k_eff:] rather than order[-k_eff:]: with k_eff == 0 the
    # latter would select the whole array.
    bot_idx = order[n_ref - k_eff:][::-1]
    row_columns = ["rank", "id", "cosine", "description", "link"]

    def _display_rows(indices):
        rows = []
        for rank, idx in enumerate(indices, 1):
            m = ref_meta.iloc[idx]
            rows.append({
                "rank": rank,
                "id": _crispr_display_id(m),
                "cosine": round(float(sims[idx]), 4),
                "description": (
                    f"{_cell_text(m.get('family', ''))} · "
                    f"{_cell_text(m.get('type', ''))} · "
                    f"{_cell_text(m.get('organism', ''))}"
                ),
                "link": _crispr_protein_link(m),
            })
        # Explicit columns keep "cosine" present even when there are no rows.
        return pd.DataFrame(rows, columns=row_columns)

    top_df, bot_df = _display_rows(top_idx), _display_rows(bot_idx)
    # Histogram context = the proteins that appear in neither table.
    middle = [float(sims[i]) for i in order[k_eff:n_ref - k_eff]]
    all_cos = top_df["cosine"].tolist() + bot_df["cosine"].tolist()
    lo, hi = (min(all_cos), max(all_cos)) if all_cos else (0.0, 1.0)

    top10_idx = order[:10]
    top_rows = []
    for rank, idx in enumerate(top10_idx, 1):
        m = ref_meta.iloc[idx]
        top_rows.append({
            "rank": rank,
            "id": _crispr_display_id(m),
            "cosine": float(sims[idx]),
            "family": _cell_text(m.get("family", "")),
            "group": _cell_text(m.get("group", "")),
            "organism": _cell_text(m.get("organism", "")),
            "name": _cell_text(m.get("name", "")),
        })
    top_cos = top_rows[0]["cosine"]
    label, conf = _crispr_verdict(top_rows, top_cos, spec["threshold"])
    verdict_md = _crispr_verdict_card(label, conf, top_cos, spec["threshold"], method_label)
    fig = _make_crispr_query_plot(method_key, method_label, ref_meta, sims, top10_idx)

    # The middle proteins go in the histogram's "top context" series, which
    # _distribution_svg draws in pale green (#bbf7d0), not grey.
    info = (
        f"{trunc_note}"
        f"**CRISPR curated × {method_label}** — {n_ref}-protein mixed Cas/Acr set "
        f"(175 Pfam-verified Cas + 79 Anti-CRISPRdb v3/local Acr proteins). "
        f"Top-{k_eff} most similar + bottom-{k_eff} most anti-correlated (cosine on "
        f"{emb_dim_info} embeddings). Histogram also shows the {len(middle)} "
        f"middle proteins (pale green). \n"
        f"Top range: **{top_df['cosine'].max():+.3f}** – **{top_df['cosine'].min():+.3f}** · "
        f"Bottom range: **{bot_df['cosine'].max():+.3f}** – **{bot_df['cosine'].min():+.3f}**"
    )
    hits_html = (
        _distribution_svg(top_df["cosine"].tolist(), bot_df["cosine"].tolist(), middle, [])
        + _hits_table_section(top_df, f"🟢 Top-{k_eff} most similar in CRISPR set", lo, hi)
        + _hits_table_section(bot_df, f"🔴 Bottom-{k_eff} most anti-correlated in CRISPR set", lo, hi)
    )
    return info, verdict_md, fig, hits_html


def _lookup_handler(sequence, method_choice, index_choice):
    """Gradio handler for the "Characterize Unknown Proteins" search button.

    Args:
        sequence: Raw or FASTA-formatted protein sequence.
        method_choice: Method radio label (see ``_lookup_method_spec``).
        index_choice: "UniRef50..." for the FAISS background; anything else
            selects the curated CRISPR reference.

    Returns:
        A 4-tuple for ``[lookup_info, lookup_verdict, lookup_plot, lookup_hits]``.
        Lookup failures are reported in ``lookup_info`` rather than raised.
    """
    sequence = strip_fasta_header(sequence.strip())
    valid, err = validate_protein(sequence)
    if not valid:
        return f"**Error**: {err}", "", None, _lookup_empty_html
    trunc_note = (f"> ⚠️ Query truncated from {len(sequence)} to {ESM2_MAX_LEN} aa "
                  f"(ESM2 limit).\n\n" if len(sequence) > ESM2_MAX_LEN else "")
    spec = _lookup_method_spec(method_choice)

    if index_choice.startswith("UniRef50"):
        try:
            return _lookup_uniref50(sequence, spec, trunc_note)
        except Exception as e:
            print(f"UniRef50 lookup failed: {e!r}")
            return (f"{trunc_note}**FAISS lookup failed**: {_lookup_error_summary(e)}\n\n"
                    f"The UniRef50 FAISS index may not be available yet.",
                    "",
                    None,
                    _lookup_empty_html)

    try:
        return _lookup_crispr(sequence, spec, trunc_note)
    except Exception as e:
        print(f"CRISPR reference lookup failed: {e!r}")
        return (f"{trunc_note}**CRISPR reference lookup failed**: {_lookup_error_summary(e)}\n\n"
                f"The curated CRISPR embeddings may not be packaged into this Space yet.",
                "",
                None,
                _lookup_empty_html)
# Two-step click: first clear all four outputs and show a "Searching…" line, then
# run the real handler. Only the second step is exposed on the API, under
# api_name="lookup".
lookup_btn.click(
    lambda m, r: (f"⏳ Searching {m} × {r}…", "", None, _lookup_empty_html),
    inputs=[lookup_method_radio, lookup_index_radio],
    outputs=[lookup_info, lookup_verdict, lookup_plot, lookup_hits],
    show_progress="hidden",
).then(
    _lookup_handler,
    inputs=[lookup_seq_input, lookup_method_radio, lookup_index_radio],
    outputs=[lookup_info, lookup_verdict, lookup_plot, lookup_hits],
    api_name="lookup",
    show_progress="minimal",
)
# Example queries: each row is [sequence, method label, reference set] and fills
# the three inputs. example_labels align positionally with the rows.
gr.Examples(
    examples=[
        # UniRef50 × ESM2 (full-scale FAISS search over 4.3M clusters)
        [_CAS1_ECOLI, "ESM2 baseline", "UniRef50"],
        [_SPCAS9, "ESM2 baseline", "UniRef50"],
        [_ACRIIA4, "ESM2 baseline", "UniRef50"],
        # UniRef50 × Twin-BP (the only Twin aspect with a UniRef50 FAISS index)
        [_CAS1_ECOLI, "genomenet-twin (BP)", "UniRef50"],
        [_SPCAS9, "genomenet-twin (BP)", "UniRef50"],
        # CRISPR curated × Twin-MF (recommended for top-hit characterization)
        [_CAS1_ECOLI, "genomenet-twin (MF)", "CRISPR curated"],
        [_SPCAS9, "genomenet-twin (MF)", "CRISPR curated"],
        [_FNCAS12A, "genomenet-twin (MF)", "CRISPR curated"],
        [_LSHCAS13A, "genomenet-twin (MF)", "CRISPR curated"],
        [_ACRIIA4, "genomenet-twin (MF)", "CRISPR curated"],
        # CRISPR curated × ESM2 (sequence-level baseline for comparison)
        [_SPCAS9, "ESM2 baseline", "CRISPR curated"],
        [_ACRIIA4, "ESM2 baseline", "CRISPR curated"],
    ],
    inputs=[lookup_seq_input, lookup_method_radio, lookup_index_radio],
    example_labels=[
        "Cas1 × UniRef50 (ESM2) — adaptation, E. coli",
        "SpCas9 × UniRef50 (ESM2) — Type II-A effector",
        "AcrIIA4 × UniRef50 (ESM2) — anti-CRISPR",
        "Cas1 × UniRef50 (Twin-BP) — full functional-index lookup",
        "SpCas9 × UniRef50 (Twin-BP) — full functional-index lookup",
        "Cas1 × CRISPR (Twin-MF) — adaptation, E. coli",
        "SpCas9 × CRISPR (Twin-MF) — Type II-A effector",
        "FnCas12a × CRISPR (Twin-MF) — Type V-A effector",
        "LshCas13a × CRISPR (Twin-MF) — Type VI-A, RNA-targeting",
        "AcrIIA4 × CRISPR (Twin-MF) — anti-Cas9 inhibitor",
        "SpCas9 × CRISPR (ESM2) — sequence-level baseline",
        "AcrIIA4 × CRISPR (ESM2) — sequence-level baseline",
    ],
    label="Example queries — click to load a query × method × reference combo",
)
# NOTE(review): `if False:` switches off this whole "Characterize an unknown
# protein" section. The components, event handlers and (presumably) the
# "characterize" API endpoint defined under it are never created. Its exact
# extent cannot be confirmed from the indentation here; it presumably runs
# through the characterize examples. Consider deleting the dead code or gating
# it behind an explicit feature flag.
if False:
    # Intro text for the section: a pre-computed gallery and a
    # bring-your-own-sequence form.
    gr.Markdown(
        "### Characterize an unknown protein\n\n"
        "Designed for triaging uncharacterised ORFs: is this protein CRISPR-adjacent at "
        "all, and if so, which family does it look like? Two sections:\n"
        "1. **Pre-computed gallery** — we fetched 24 genuinely uncharacterised proteins "
        "from UniProt (phage and bacterial hypotheticals, prime candidates for novel "
        "Cas / anti-CRISPR discoveries) and ran both ESM2 and genomenet-twin (MF) on each. "
        "Results are shown below in an interactive 3-D UMAP plus a full table.\n"
        "2. **Bring your own sequence** — paste any protein at the bottom to run the same "
        "analysis interactively."
    )
# --- Section 1: pre-computed gallery -------------------------------
# Summary stats: how many gallery proteins each method calls, shown as a
# small Markdown table.
gallery_data = get_unknown_results()
n_unknown = gallery_data.get("n_queries", 0)
if n_unknown > 0:
    recs = gallery_data.get("records", [])

    def _count_by_verdict(method_key):
        """Count gallery records with a family-level call and with any CRISPR signal.

        Args:
            method_key: per-record result key, e.g. "twin-mf" or "esm2".

        Returns:
            (pred, cris): records whose verdict starts with "Predicted"
            (family-level call), and records whose confidence is not "low"
            (above-threshold CRISPR signal).
        """
        # Prefix match, not substring, so these counts agree with the
        # benchmark tab's counting of the same records.
        pred = sum(1 for r in recs if str(r[method_key]["verdict"]).startswith("Predicted"))
        cris = sum(1 for r in recs if r[method_key]["confidence"] != "low")
        return pred, cris

    pred_mf, cris_mf = _count_by_verdict("twin-mf")
    pred_es, cris_es = _count_by_verdict("esm2")
    # `<family>` is written as HTML entities: a raw `<family>` tag would be
    # stripped by the Markdown HTML sanitiser, leaving an empty placeholder.
    gr.Markdown(
        f"#### 1. Pre-computed gallery · {n_unknown} uncharacterised proteins\n\n"
        f"| method | 'Predicted: &lt;family&gt;' | any CRISPR signal (above threshold) |\n"
        f"|---|---|---|\n"
        f"| genomenet-twin (MF) | **{pred_mf}** / {n_unknown} | **{cris_mf}** / {n_unknown} |\n"
        f"| ESM2 baseline | {pred_es} / {n_unknown} | {cris_es} / {n_unknown} |\n\n"
        f"Twin-MF flags substantially more phage hypotheticals as CRISPR-related than "
        f"ESM2 — expected given the benchmark results, and biologically plausible: "
        f"phages often carry anti-CRISPR proteins to evade the host's CRISPR-Cas "
        f"defence. Rebuild this gallery from "
        f"`scripts/analyses/crispr/precompute_unknowns.py` (GPU job, ~3 min)."
    )
# Caption for the static signal-strength figure rendered just below it.
gr.Markdown(
    "**Signal-strength benchmark — ESM2 vs Twin-MF on the same unknowns.** "
    "Three views of the same data: histogram of top-1 cosines, cumulative "
    "\"fraction of unknowns with cosine ≥ x\", and verdict-confidence bucket "
    "counts. Twin-MF consistently shifts the distribution rightward and "
    "promotes more unknowns into the medium/high-confidence buckets than "
    "ESM2 does — the direct answer to *how much more signal does Twin find "
    "than ESM2 on actual unknowns*."
)
# Pre-rendered PNG loaded from a relative path; nothing is recomputed here.
gr.Image(
    value="data/benchmark/crispr_unknowns_distribution.png",
    label=None, show_label=False, container=False,
)
# Method selector shared by the 3-D plot and the table below. Its .change
# event re-renders both.
gallery_method_radio = gr.Radio(
    choices=["genomenet-twin (MF)", "ESM2 baseline"],
    value="genomenet-twin (MF)",
    label="Gallery view method",
    info="Switches the 3-D UMAP overview and the verdict column in the table below.",
)
# Output slots, filled by _make_gallery_3d_plot / _make_gallery_table.
gallery_plot = gr.Plot(label="3-D UMAP — reference (by family) + unknowns (by verdict)")
# Placeholder HTML shown until the gallery is loaded.
gallery_table = gr.HTML(
    "<div style='padding:12px;color:#6b7280;'>Click <b>load gallery</b> to render the pre-computed unknown-protein overview.</div>"
)
gallery_load_btn = gr.Button("load gallery", variant="secondary")
def _make_gallery_3d_plot(method_choice):
    """Build the interactive 3-D UMAP of the CRISPR reference plus gallery unknowns.

    Reference proteins are drawn as one ``Scatter3d`` trace per family. The
    unknowns have no UMAP coordinates of their own. Each one is placed at the
    softmax-weighted centroid (weights ``exp(8 * (cos - max_cos))``) of the
    3-D coordinates of its top-5 reference hits, and coloured by verdict
    confidence. Unknowns with fewer than three of their top-5 hits present in
    the reference are skipped.

    Args:
        method_choice: gallery radio value. Values starting with "genomenet"
            select Twin-MF; anything else selects the ESM2 baseline.

    Returns:
        A plotly ``Figure`` (empty when no 3-D UMAP exists for the method).
    """
    method_key = "Twin-MF" if method_choice.startswith("genomenet") else "ESM2 baseline"
    mkey = "twin-mf" if method_key == "Twin-MF" else "esm2"
    xyz = get_crispr_umap_3d(method_key)
    if xyz is None:
        return go.Figure()
    _, ref_meta = get_crispr_reference(method_key)
    fams = list(dict.fromkeys(ref_meta["family"].tolist()))
    # D3 category10 followed by its lighter category20 companions. Colours
    # cycle when there are more families than entries.
    palette = [
        "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b",
        "#e377c2", "#7f7f7f", "#bcbd22", "#17becf", "#aec7e8", "#ffbb78",
        "#98df8a", "#ff9896", "#c5b0d5", "#c49c94", "#f7b6d3", "#c7c7c7",
        "#dbdb8d", "#9edae5",
    ]
    fam_colour = {f: palette[i % len(palette)] for i, f in enumerate(fams)}
    traces = []
    for fam in fams:
        mask = (ref_meta["family"] == fam).to_numpy()
        sub = ref_meta[mask]
        hover = [f"<b>{r['acc']}</b><br>{r['name']}<br>{r['organism']}<br>family: {fam}"
                 for _, r in sub.iterrows()]
        traces.append(go.Scatter3d(
            x=xyz[mask, 0], y=xyz[mask, 1], z=xyz[mask, 2],
            mode="markers",
            marker=dict(size=4, color=fam_colour[fam], line=dict(width=0.5, color="black")),
            name=fam, legendgroup="ref", legendgrouptitle_text="reference (families)",
            hovertext=hover, hoverinfo="text",
        ))

    records = gallery_data.get("records", [])
    ux, uy, uz, htxt, ucol = [], [], [], [], []
    if records:
        # Accession -> row position (first occurrence wins). xyz is indexed
        # positionally, so DataFrame index labels must not be used here.
        acc_pos = {}
        for pos, acc in enumerate(ref_meta["acc"].tolist()):
            acc_pos.setdefault(acc, pos)
        conf_colour = {"low": "#dc2626", "medium": "#d97706", "high": "#059669"}
        for r in records:
            result = r[mkey]
            # Keep each cosine paired with its own reference row. Truncating
            # the top-5 cosines by count would misalign weights and rows
            # whenever an earlier hit is missing from the reference.
            hits = [(acc_pos[t["id"]], t["cosine"])
                    for t in result["top_10"][:5] if t["id"] in acc_pos]
            if len(hits) < 3:
                continue
            idxs = [p for p, _ in hits]
            sims = np.array([c for _, c in hits], dtype=float)
            w = np.exp((sims - sims.max()) * 8.0)
            w /= w.sum()
            centroid = (xyz[idxs] * w[:, None]).sum(axis=0)
            ux.append(centroid[0])
            uy.append(centroid[1])
            uz.append(centroid[2])
            verdict_short = result["verdict"].split(";")[0][:80]
            htxt.append(f"<b>{r['acc']}</b> ({result['confidence'].upper()})<br>"
                        f"{r['name']}<br>{r['organism']}<br>"
                        f"<b>verdict:</b> {verdict_short}<br>"
                        f"top-1: {result['top_10'][0]['family']} (cos = {result['top_cosine']:+.3f})")
            ucol.append(conf_colour.get(result["confidence"], "#6b7280"))
        traces.append(go.Scatter3d(
            x=ux, y=uy, z=uz, mode="markers",
            marker=dict(size=9, symbol="diamond", color=ucol,
                        line=dict(width=1.5, color="gold")),
            name="unknowns (by confidence)", legendgroup="unk",
            legendgrouptitle_text="unknowns (verdict confidence)",
            hovertext=htxt, hoverinfo="text",
        ))
    fig = go.Figure(data=traces)
    fig.update_layout(
        scene=dict(xaxis_title="UMAP-1", yaxis_title="UMAP-2", zaxis_title="UMAP-3"),
        height=600, margin=dict(l=0, r=0, t=30, b=0),
        # Counts come from the data so the title tracks the reference set and
        # the number of unknowns actually plotted.
        title=f"3-D UMAP · reference ({len(ref_meta)}) + unknowns ({len(ux)}) · {method_choice}",
        legend=dict(itemsizing="constant", groupclick="toggleitem"),
    )
    return fig
def _make_gallery_table(method_choice):
    """Render the pre-computed unknown-protein gallery as an HTML table.

    Rows are sorted by the selected method's top-1 cosine, highest first.
    Text fields taken from the precomputed records are HTML-escaped before
    interpolation, so stray ``<``, ``&`` or quotes cannot break the markup.
    These fields are the UniProt URL, accession, organism, source, verdict and
    family.

    Args:
        method_choice: gallery radio value. Values starting with "genomenet"
            select the Twin-MF results; anything else selects the ESM2 results.

    Returns:
        An HTML string: heading, table, and a footnote on how the unknowns
        were sourced.
    """
    # Imported under its own name: `html` is also used as a plain variable
    # name elsewhere in this file.
    from html import escape

    mkey = "twin-mf" if method_choice.startswith("genomenet") else "esm2"
    conf_colours = {"high": "#059669", "medium": "#d97706", "low": "#dc2626"}
    rows = []
    # Sort by top cosine descending
    recs_sorted = sorted(gallery_data.get("records", []), key=lambda r: -r[mkey]["top_cosine"])
    for r in recs_sorted:
        res = r[mkey]
        conf = str(res["confidence"])
        # Unknown confidence values fall back to neutral grey instead of raising.
        col = conf_colours.get(conf, "#6b7280")
        top1 = res["top_10"][0]
        # Truncate first, then escape, so an entity is never cut in half.
        verdict_short = escape(str(res["verdict"])[:90])
        rows.append(
            f"<tr>"
            f"<td style='padding:6px 10px;font-family:ui-monospace,monospace;white-space:nowrap;'>"
            f"<a href='{escape(str(r['uniprot']))}' target='_blank' style='color:#2563eb;text-decoration:none;'>{escape(str(r['acc']))}</a></td>"
            f"<td style='padding:6px 10px;font-size:12px;color:#374151;'>{escape(str(r['organism'])[:42])}</td>"
            f"<td style='padding:6px 10px;font-size:12px;color:#6b7280;'>{escape(str(r['length']))}</td>"
            f"<td style='padding:6px 10px;font-size:12px;color:#6b7280;'>{escape(str(r['source']))}</td>"
            f"<td style='padding:6px 10px;'><span style='background:{col};color:white;"
            f"padding:2px 6px;border-radius:3px;font-size:11px;'>{escape(conf.upper())}</span></td>"
            f"<td style='padding:6px 10px;font-size:12px;'>{verdict_short}</td>"
            f"<td style='padding:6px 10px;font-family:ui-monospace,monospace;font-size:12px;'>"
            f"{escape(str(top1['family']))} · {res['top_cosine']:+.3f}</td>"
            f"</tr>"
        )
    return (
        f"<h4 style='margin:12px 0 6px 0;'>Full gallery (sorted by top-1 cosine, "
        f"{method_choice})</h4>"
        f"<table style='width:100%;border-collapse:collapse;font-size:13px;'>"
        f"<thead><tr style='text-align:left;border-bottom:1px solid #e5e7eb;color:#6b7280;font-size:12px;'>"
        f"<th style='padding:6px 10px;'>UniProt</th>"
        f"<th style='padding:6px 10px;'>organism</th>"
        f"<th style='padding:6px 10px;'>aa</th>"
        f"<th style='padding:6px 10px;'>source</th>"
        f"<th style='padding:6px 10px;'>confidence</th>"
        f"<th style='padding:6px 10px;'>verdict</th>"
        f"<th style='padding:6px 10px;'>top hit · cosine</th>"
        f"</tr></thead>"
        f"<tbody>{''.join(rows)}</tbody></table>"
        f"<p style='font-size:11px;color:#6b7280;margin-top:8px;'>"
        f"Source queries used to fetch the unknowns: "
        f"phage hypotheticals ≤200 aa, phage hypotheticals 200–500 aa, "
        f"Streptococcus-phage and Pseudomonas-phage hypotheticals via "
        f"<code>virus_host_id</code> on UniProt, plus two bacterial-hypothetical "
        f"control sets from bacteria with experimental evidence (E. coli, B. subtilis, "
        f"M. tuberculosis). See "
        f"<code>scripts/analyses/crispr/precompute_unknowns.py</code>."
        f"</p>"
    )
# The method radio and the "load gallery" button render the same pair of
# outputs for the currently selected method, so both use one renderer. The
# listener callbacks stay lambdas so the registered functions are unchanged.
def _render_gallery(method_choice):
    """Return (3-D UMAP figure, gallery HTML table) for ``method_choice``."""
    return _make_gallery_3d_plot(method_choice), _make_gallery_table(method_choice)


for _gallery_event in (gallery_method_radio.change, gallery_load_btn.click):
    _gallery_event(
        fn=lambda m: _render_gallery(m),
        inputs=[gallery_method_radio],
        outputs=[gallery_plot, gallery_table],
        show_progress="minimal",
    )
# --- Section 2: bring-your-own-sequence ----------------------------
gr.Markdown("#### 2. Bring your own sequence")
# Two-column layout: inputs on the left (scale 1), results on the right
# (scale 2).
with gr.Row():
    with gr.Column(scale=1, min_width=320):
        # NOTE(review): "1022" in the hint is hard-coded, while the handler
        # truncates at ESM2_MAX_LEN. Keep the two in sync.
        char_seq_input = gr.Textbox(
            label="Protein sequence",
            placeholder="Paste protein sequence (amino acids)...",
            lines=6,
            value=EXAMPLE_PROTEIN,
            info="FASTA or raw sequence; > 1022 aa is truncated",
        )
        char_method_radio = gr.Radio(
            choices=["genomenet-twin (MF)", "ESM2 baseline"],
            value="genomenet-twin (MF)",
            label="Method",
            info="Twin-MF is the recommended aspect for CRISPR classification.",
        )
        char_btn = gr.Button("characterize", variant="primary")
    with gr.Column(scale=2, min_width=400):
        # Output slots filled by _characterize_handler: verdict Markdown,
        # UMAP placement plot, and top-10 hits HTML table.
        char_verdict = gr.Markdown()
        char_plot = gr.Plot(label="Query placement on CRISPR UMAP")
        char_hits = gr.HTML()
def _characterize_umap_figure(method_label, ref_meta, xy_ref, sims, top):
    """Plot the 2-D CRISPR UMAP with the query placed among its nearest references.

    The query has no UMAP coordinate of its own. It is drawn at the
    softmax-weighted centroid (weights ``exp(8 * (cos - max_cos))``) of its
    top-5 references and joined by dashed lines to its top-3.

    The figure is a plain ``matplotlib.figure.Figure`` rather than one made
    through pyplot. pyplot keeps every figure it creates alive until
    ``plt.close`` is called, so creating one per request leaked memory.

    Args:
        method_label: display name used in the plot title.
        ref_meta: reference metadata. Rows are assumed to be in the same order
            as ``xy_ref`` and ``sims`` (TODO confirm against the loaders).
        xy_ref: 2-D UMAP coordinates of the reference proteins.
        sims: cosine similarity of the query to every reference protein.
        top: reference positions sorted by decreasing similarity.

    Returns:
        The matplotlib ``Figure``.
    """
    from matplotlib.figure import Figure
    from matplotlib.lines import Line2D

    top5 = top[:5]
    w = np.exp((sims[top5] - sims[top5].max()) * 8.0)
    w /= w.sum()
    q_xy = (xy_ref[top5] * w[:, None]).sum(axis=0)

    all_fams = list(dict.fromkeys(ref_meta["family"].tolist()))
    cmap = plt.get_cmap("tab20", len(all_fams))
    fam_colour = {f: cmap(i) for i, f in enumerate(all_fams)}
    group_marker = {"cas": "o", "acr": "^"}

    # tight_layout=True runs the same layout pass as plt.tight_layout(), at
    # draw/savefig time.
    fig = Figure(figsize=(9, 6.5), tight_layout=True)
    ax = fig.subplots()
    families = ref_meta["family"].to_numpy()
    markers = np.array([group_marker.get(g, "o") for g in ref_meta["group"].tolist()])
    # One scatter call per (family, marker) instead of one call per protein.
    for fam in all_fams:
        fam_mask = families == fam
        for marker in dict.fromkeys(markers[fam_mask].tolist()):
            mask = fam_mask & (markers == marker)
            ax.scatter(xy_ref[mask, 0], xy_ref[mask, 1],
                       color=fam_colour[fam], marker=marker,
                       s=55, edgecolor="black", linewidth=0.4, alpha=0.85, zorder=2)
    # Lines to top-3 neighbours
    for i in top[:3]:
        ax.plot([q_xy[0], xy_ref[i, 0]], [q_xy[1], xy_ref[i, 1]],
                color="#444", lw=0.6, linestyle="--", alpha=0.6, zorder=3)
    # Query star
    ax.scatter(q_xy[0], q_xy[1], marker="*", s=320, color="#111",
               edgecolor="gold", linewidth=2.0, zorder=10)
    handles = [Line2D([0], [0], marker="o", color="w", markerfacecolor=fam_colour[f],
                      markeredgecolor="black", markersize=8, label=f) for f in all_fams]
    handles.append(Line2D([0], [0], marker="*", color="w", markerfacecolor="#111",
                          markeredgecolor="gold", markersize=14, label="query"))
    ax.legend(handles=handles, loc="center left", bbox_to_anchor=(1.02, 0.5),
              fontsize=7.5, frameon=False)
    ax.set_xlabel("UMAP-1")
    ax.set_ylabel("UMAP-2")
    ax.set_title(f"Query placed at softmax-weighted top-5 centroid ({method_label})")
    ax.grid(alpha=0.25, linestyle=":")
    return fig


def _characterize_hits_table(top_rows):
    """Render the top reference hits as an HTML table with a cosine gauge.

    Each gauge marker is placed by min-max scaling the cosine over the rows
    shown. Accession, family and organism are HTML-escaped.
    """
    from html import escape

    all_cos = [r["cosine"] for r in top_rows]
    lo, hi = min(all_cos), max(all_cos)

    def _row(r):
        pct = max(0.0, min(100.0, (r["cosine"] - lo) / (hi - lo + 1e-9) * 100.0))
        acc = escape(str(r["id"]))
        link = (f"<a href='https://www.uniprot.org/uniprotkb/{acc}' "
                f"target='_blank' style='color:#2563eb;text-decoration:none;'>↗</a>")
        bar = (f"<div style='display:flex;align-items:center;gap:10px;'>"
               f"<div style='position:relative;width:160px;height:10px;"
               f"background:linear-gradient(to right,#f87171,#facc15,#4ade80);border-radius:2px;'>"
               f"<div style='position:absolute;left:{pct:.1f}%;top:-3px;width:2px;height:16px;"
               f"background:#111;box-shadow:0 0 0 2px #fff;transform:translateX(-50%);'></div>"
               f"</div>"
               f"<span style='font-family:ui-monospace,monospace;font-size:12px;min-width:54px;'>{r['cosine']:+.4f}</span>"
               f"</div>")
        return (f"<tr><td style='padding:5px 10px;color:#555;'>{r['rank']}</td>"
                f"<td style='padding:5px 10px;font-family:ui-monospace,monospace;'>{acc}</td>"
                f"<td style='padding:5px 10px;'>{bar}</td>"
                f"<td style='padding:5px 10px;font-weight:600;'>{escape(str(r['family']))}</td>"
                f"<td style='padding:5px 10px;font-size:12px;color:#374151;'>{escape(str(r['organism'])[:50])}</td>"
                f"<td style='padding:5px 10px;'>{link}</td></tr>")

    return (f"<h4 style='margin:12px 0 6px 0;'>Top-10 nearest reference proteins</h4>"
            f"<table style='width:100%;border-collapse:collapse;font-size:13px;'>"
            f"<thead><tr style='text-align:left;border-bottom:1px solid #e5e7eb;color:#6b7280;font-size:12px;'>"
            f"<th style='padding:5px 10px;'>#</th><th style='padding:5px 10px;'>ID</th>"
            f"<th style='padding:5px 10px;'>cosine</th><th style='padding:5px 10px;'>family</th>"
            f"<th style='padding:5px 10px;'>organism</th><th style='padding:5px 10px;'></th></tr></thead>"
            f"<tbody>{''.join(_row(r) for r in top_rows)}</tbody></table>")


def _characterize_handler(sequence, method_choice):
    """Classify one protein against the curated CRISPR (Cas + Acr) reference.

    Steps:
    1. Strip a FASTA header and validate the sequence.
    2. Embed it with Twin-MF or ESM2, depending on ``method_choice``.
    3. Rank every reference protein by cosine similarity.
    4. Derive a verdict from the top-10 hits via ``_crispr_verdict``.
    5. When a 2-D UMAP of the reference exists for the method, plot the
       query's estimated position.

    Args:
        sequence: textbox content (FASTA or bare residues). ``None`` is
            treated as empty.
        method_choice: radio value. Values starting with "genomenet-twin"
            select Twin-MF (verdict threshold 0.70); anything else selects
            ESM2 (threshold 0.90).

    Returns:
        ``(verdict_markdown, figure_or_None, hits_html)``. On invalid input,
        returns ``(error_markdown, None, "")``.
    """
    sequence = strip_fasta_header((sequence or "").strip())
    valid, err = validate_protein(sequence)
    if not valid:
        return (f"**Error**: {err}", None, "")
    trunc_note = (f"> Query truncated from {len(sequence)} to {ESM2_MAX_LEN} aa "
                  f"(ESM2 limit).\n\n" if len(sequence) > ESM2_MAX_LEN else "")
    if method_choice.startswith("genomenet-twin"):
        query_emb = embed_twin(sequence, aspect="MF")
        method_key = "Twin-MF"
        threshold = 0.70
        method_label = "genomenet-twin (MF)"
    else:
        query_emb = embed_esm2(sequence)
        method_key = "ESM2 baseline"
        threshold = 0.90
        method_label = "ESM2 baseline"
    # Rank the reference. Only the query is L2-normalised here; the reference
    # rows are presumably stored pre-normalised (TODO confirm), otherwise
    # these are dot products rather than cosines.
    ref_emb, ref_meta = get_crispr_reference(method_key)
    q = query_emb / (np.linalg.norm(query_emb) + 1e-9)
    sims = ref_emb @ q
    top = np.argsort(-sims)[:10]
    top_rows = []
    for rank, idx in enumerate(top, 1):
        m = ref_meta.iloc[idx]
        top_rows.append({
            "rank": rank, "id": m["acc"], "cosine": float(sims[idx]),
            "family": m["family"], "group": m["group"],
            "organism": m["organism"], "name": m["name"],
        })
    top_cos = top_rows[0]["cosine"]
    label, conf = _crispr_verdict(top_rows, top_cos, threshold)
    conf_colour = {"high": "#059669", "medium": "#d97706", "low": "#dc2626"}.get(conf, "#6b7280")
    verdict_md = (
        f"{trunc_note}"
        f"### Verdict ({method_label})\n"
        f"<div style='border-left:4px solid {conf_colour};padding:10px 14px;"
        f"background:#fafafa;border-radius:4px;margin:6px 0;'>"
        f"<div style='font-size:15px;'><b>{label}</b></div>"
        f"<div style='font-size:12px;color:#6b7280;margin-top:4px;'>"
        f"confidence: <b style='color:{conf_colour}'>{conf.upper()}</b> · "
        f"top-1 cosine = {top_cos:+.3f} · threshold = {threshold:.2f}"
        f"</div></div>"
    )
    # The UMAP scatter is only drawn when a 2-D UMAP exists for this method.
    xy_ref = get_crispr_umap(method_key)
    fig = (None if xy_ref is None
           else _characterize_umap_figure(method_label, ref_meta, xy_ref, sims, top))
    return verdict_md, fig, _characterize_hits_table(top_rows)
# Two-step characterize flow: show an instant progress banner, then run
# _characterize_handler (published as the "characterize" API endpoint). Both
# steps fill the same three outputs. The banner stays a lambda so the
# registered callback is unchanged.
_char_outputs = (char_verdict, char_plot, char_hits)
_char_started = char_btn.click(
    lambda m: (f"<div style='padding:16px;color:#666;'>⏳ Characterizing against CRISPR reference with {m}…</div>", None, ""),
    inputs=[char_method_radio],
    outputs=list(_char_outputs),
    show_progress="hidden",
)
_char_started.then(
    _characterize_handler,
    inputs=[char_seq_input, char_method_radio],
    outputs=list(_char_outputs),
    api_name="characterize",
    show_progress="minimal",
)
# One row per example: (query sequence, method radio value, label). Each label
# states the verdict the query is expected to receive.
_char_example_rows = [
    (_SPCAS9, "genomenet-twin (MF)", "SpCas9 — expected: Cas9 (Type II-A effector)"),
    (_FNCAS12A, "genomenet-twin (MF)", "FnCas12a — expected: Cas12 (Type V-A effector)"),
    (_LSHCAS13A, "genomenet-twin (MF)", "LshCas13a — expected: Cas13 (Type VI-A, RNA-targeting)"),
    (_CAS1_ECOLI, "genomenet-twin (MF)", "Cas1 (E. coli) — expected: Cas1 (adaptation)"),
    (_ACRIIA4, "genomenet-twin (MF)", "AcrIIA4 — expected: AcrIIA family (anti-Cas9)"),
    (_SPCAS9, "ESM2 baseline", "SpCas9 — same query, ESM2 method (compare)"),
    (_ACRIIA4, "ESM2 baseline", "AcrIIA4 — same query, ESM2 method (compare)"),
]
gr.Examples(
    examples=[[seq, method] for seq, method, _ in _char_example_rows],
    inputs=[char_seq_input, char_method_radio],
    example_labels=[label for _, _, label in _char_example_rows],
    label="Example queries — each known protein; expected verdict is in the label",
)
# Benchmark tab: a written report of the CRISPR protein benchmark.
# NOTE(review): every number and "which model wins" claim in these Markdown
# strings is hard-coded. They must be refreshed by hand whenever the
# benchmark is rerun.
with gr.Tab("Benchmark"):
    # Question, scope, reference sets, and how to read each metric.
    gr.Markdown("""
## CRISPR protein benchmark
**Question.** Does the GO-supervised `genomenet-twin` embedding recover useful
CRISPR protein structure beyond the pretrained ESM2 representation?
This tab evaluates only **protein sequences**. It does not use the CRISPR array
detector and does not run DIAMOND, MMseqs, BLAST, or any other sequence-search
baseline. The benchmark asks whether cosine geometry in each embedding space
recovers known CRISPR protein labels.
The first UniProt mixed benchmark is now superseded: broad text queries pulled
known contaminants, especially acriflavin-resistance proteins from `AcrIF*`, and
a few Cas entries were mislabeled. The current benchmark therefore separates the
problem into clean reference views and then recombines them for the main figure.
### Reference sets
| reference view | proteins | source and purpose |
|---|---:|---|
| **Mixed Cas + Acr** | **254** | Main overview. Concatenates 175 Pfam-verified Cas proteins with 79 Anti-CRISPRdb v3 Acr proteins. |
| **Acr-only** | **79** | Anti-CRISPR family retrieval. Uses typed Anti-CRISPRdb v3 records plus local curated AcrIB examples. |
| **Cas-only** | **175** | Within-Cas subfamily resolution. Uses UniProt Pfam cross-references for Cas1-Cas12. |
### How to read the metrics
| metric | interpretation |
|---|---|
| **within − between** | Mean within-family cosine minus between-family cosine. High values mean block structure is strong, but this is not by itself a retrieval metric. |
| **5-NN recall** | For each protein, fraction of its nearest neighbours that share the family label, adjusted for small families. This is the practical "find more proteins like this" metric. |
| **LOO top-1** | Leave-one-out nearest-neighbour family assignment accuracy. This asks whether the closest other protein has the correct family. |
| **LOO MRR** | Leave-one-out mean reciprocal rank of the first correct-family neighbour. Higher means correct-family proteins appear earlier in the ranked list. |
| **AUC family** | Pairwise same-family versus different-family discrimination from cosine similarity. |
| **AUC group** | Pairwise broader-group discrimination. In the mixed reference this is Cas versus Acr; in the Acr-only reference it is inhibited CRISPR class. |
| **silhouette** | Scale-normalized cluster quality. Positive values mean proteins are closer to their own family than to neighbouring families. |
The important point is that these metrics answer different biological questions.
Nearest-neighbour metrics evaluate annotation and retrieval. Pairwise AUCs test
global ranking of same-label pairs. Silhouette tests cluster geometry. A single
"winner" is therefore less informative than the pattern across metrics.
""")
# Figure 1: mixed Cas + Acr reference. Each gr.Image shows a pre-rendered PNG
# from a relative path, and the gr.Markdown after it is that figure's caption.
# NOTE(review): Table 1 values and the prose about which model wins are
# hand-copied. Regenerate both together when the benchmark changes.
gr.Markdown("""
---
## Figure 1. Mixed Cas + anti-CRISPR benchmark
The mixed reference combines Pfam-verified Cas proteins and Anti-CRISPRdb v3 Acr
proteins without re-embedding. This is the closest view to a general CRISPR
protein benchmark: the model must separate Cas from Acr while also preserving
family-level structure inside each group.
""")
gr.Image(value="data/benchmark/crispr_umap_esm2_vs_twin.png", label="Mixed reference projection", show_label=True, container=False)
gr.Markdown("""
**Figure 1a. Two-dimensional projection of the mixed reference.** Each point is
one protein and colours encode family labels. The projection is a visualization
aid, not the primary statistic: distances are distorted by dimensionality
reduction. The useful readout is whether same-label proteins form coherent
neighbourhoods and whether Cas and Acr regions are visibly separated.
""")
gr.Image(value="data/benchmark/crispr_esm2_vs_twin_heatmap.png", label="Mixed reference heatmap", show_label=True, container=False)
gr.Markdown("""
**Figure 1b. All-by-all cosine similarity heatmap.** Rows and columns are
ordered by family. A strong embedding produces diagonal blocks: high similarity
within a family and lower similarity between unrelated families. Twin-MF gives a
much stronger within-minus-between contrast than ESM2, but the retrieval tables
below are needed to decide whether that block structure helps annotation.
""")
gr.Image(value="data/benchmark/crispr_aspect_comparison.png", label="Mixed reference model comparison", show_label=True, container=False)
gr.Markdown("""
**Table 1. Quantitative performance on the mixed reference.**
| model | within − between | 5-NN recall | LOO top-1 | LOO MRR | AUC family | AUC Cas-vs-Acr | silhouette |
|---|---:|---:|---:|---:|---:|---:|---:|
| ESM2 baseline | 0.038 | 0.608 | 0.724 | 0.795 | 0.823 | 0.700 | **0.131** |
| Twin-BP | 0.169 | **0.618** | 0.697 | 0.771 | 0.829 | 0.689 | 0.082 |
| Twin-CC | **0.404** | 0.480 | 0.618 | 0.713 | 0.784 | 0.554 | 0.016 |
| Twin-MF | 0.321 | 0.604 | **0.748** | **0.808** | **0.843** | **0.711** | 0.084 |
**Interpretation.** Twin-MF is the strongest aspect for family assignment by
leave-one-out nearest-neighbour tests and for pairwise family or Cas-vs-Acr
separation. Twin-BP is marginally best for 5-neighbour recall. ESM2 has the best
silhouette, meaning its clusters are more compact relative to their nearest
alternative cluster in this reference. This is a mixed result: Twin-MF improves
several annotation-oriented ranking metrics, but ESM2 remains competitive and
should stay as the baseline in any manuscript claim.
""")
gr.Image(value="data/benchmark/crispr_subcluster_comparison.png", label="Mixed reference family retrieval detail", show_label=True, container=False)
gr.Markdown("""
**Figure 1c. Per-family nearest-neighbour behaviour.** The lower panel shows
where the aggregate metrics come from. Twin-MF improves retrieval for several
Cas families, including Cas3 and Cas4, but loses for others such as Cas8 and
some Acr families. This argues against reporting only a global average: the
embedding improvement is family-dependent.
""")
gr.Image(value="data/benchmark/crispr_silhouette.png", label="Mixed reference silhouette", show_label=True, container=False)
gr.Markdown("""
**Figure 1d. Silhouette analysis.** Silhouette scores summarize whether each
protein is closer to its own family than to the nearest other family. ESM2 has
the highest mixed-reference mean silhouette, while Twin-MF has better
leave-one-out family assignment. In manuscript language, this means Twin-MF
improves ranking of the first correct family hit but does not create uniformly
cleaner clusters.
""")
gr.Image(value="data/benchmark/crispr_roc_esm2_vs_twin.png", label="Mixed reference ROC", show_label=True, container=False)
gr.Markdown("""
**Figure 1e. Pairwise ROC curves.** ROC curves evaluate all protein pairs rather
than query-level retrieval. Twin-MF has the best overall family AUC in Table 1,
whereas this ROC panel compares ESM2 and Twin-BP specifically. Pairwise AUC is
useful for global separation, but nearest-neighbour metrics are more directly
linked to annotation workflows.
""")
# Figure 2: Acr-only reference (Anti-CRISPRdb v3). The images are
# pre-rendered PNGs under data/benchmark/acrdb_v3/. The Table 2 numbers and
# interpretation are hard-coded.
gr.Markdown("""
---
## Figure 2. Acr-only benchmark
The Acr-only reference asks a narrower question: can the embeddings recover
known anti-CRISPR family labels? This is harder than simple Cas-versus-Acr
separation because many Acr families are short, diverse, and defined by target
system rather than shared fold.
""")
gr.Image(value="data/benchmark/acrdb_v3/crispr_aspect_comparison.png", label="Acr-only model comparison", show_label=True, container=False)
gr.Markdown("""
**Figure 2a and Table 2. Anti-CRISPRdb v3 family benchmark.**
""")
gr.Image(value="data/benchmark/acrdb_v3/crispr_subcluster_comparison.png", label="Acr-only retrieval detail", show_label=True, container=False)
gr.Markdown("""
| model | within − between | 5-NN recall | LOO top-1 | LOO MRR | AUC family | AUC inhibited class | silhouette |
|---|---:|---:|---:|---:|---:|---:|---:|
| ESM2 baseline | 0.008 | **0.306** | **0.506** | **0.631** | 0.658 | 0.612 | **-0.022** |
| Twin-BP | 0.103 | 0.256 | 0.380 | 0.523 | 0.657 | 0.529 | -0.075 |
| Twin-CC | 0.157 | 0.294 | 0.418 | 0.557 | **0.708** | 0.577 | -0.040 |
| Twin-MF | **0.185** | 0.284 | 0.481 | 0.606 | 0.701 | **0.646** | -0.075 |
**Interpretation.** Acr-only retrieval currently favours ESM2. Twin-CC and
Twin-MF improve pairwise AUCs, but they do not improve the practical
nearest-neighbour task. This may reflect real biology: anti-CRISPR family names
often group proteins by inhibited CRISPR system rather than by a single
evolutionary or structural family. For Acr discovery, Twin scores may still be
useful as an additional signal, but the current benchmark does not justify
replacing ESM2 for Acr family lookup.
""")
# Figure 3: Cas-only reference (Pfam-selected), followed by the overall
# conclusion. The images are pre-rendered PNGs under data/benchmark/cas_pfam/.
# The Table 3 numbers and conclusions are hard-coded.
gr.Markdown("""
---
## Figure 3. Cas-only benchmark
The Cas-only reference removes the easy Cas-versus-Acr split and tests
within-Cas subfamily resolution. Proteins were selected by Pfam cross-reference,
not by broad protein-name search, to reduce label noise.
""")
gr.Image(value="data/benchmark/cas_pfam/crispr_umap_esm2_vs_twin.png", label="Cas-only projection", show_label=True, container=False)
gr.Markdown("""
**Figure 3a. Cas-only two-dimensional projection.** The plot visualizes whether
Cas families occupy separate regions. It should be read together with Table 3,
because projection layout can exaggerate or hide neighbourhood structure.
""")
gr.Image(value="data/benchmark/cas_pfam/crispr_silhouette.png", label="Cas-only silhouette", show_label=True, container=False)
gr.Markdown("""
**Figure 3b. Cas-only silhouette.** ESM2 produces the best cluster compactness
on this reference, while Twin-MF performs best on top-hit family assignment.
This difference is important: compact global clusters and the identity of the
nearest correct family hit are related but not identical objectives.
""")
gr.Image(value="data/benchmark/cas_pfam/crispr_aspect_comparison.png", label="Cas-only model comparison", show_label=True, container=False)
gr.Markdown("""
**Table 3. Cas-only performance.**
| model | within − between | 5-NN recall | LOO top-1 | LOO MRR | AUC family | silhouette |
|---|---:|---:|---:|---:|---:|---:|
| ESM2 baseline | 0.033 | 0.763 | 0.834 | 0.887 | 0.798 | **0.221** |
| Twin-BP | 0.171 | **0.802** | 0.851 | 0.896 | **0.842** | 0.206 |
| Twin-CC | **0.407** | 0.610 | 0.749 | 0.824 | 0.791 | 0.093 |
| Twin-MF | 0.274 | 0.759 | **0.874** | **0.911** | 0.821 | 0.192 |
**Interpretation.** Twin-MF is the best model when the task is leave-one-out
Cas family assignment: the first correct-family neighbour appears earlier and
the nearest neighbour is correct more often. Twin-BP is best for 5-NN recall and
pairwise family AUC. ESM2 is best for silhouette. Twin-CC has high
within-minus-between cosine but poor retrieval, so it should not be used as the
default CRISPR aspect.
### Overall conclusion
For a manuscript, the defensible conclusion is nuanced:
- **Cas annotation:** Twin-MF is useful for top-hit family assignment; Twin-BP is
useful for broader neighbour recall.
- **Acr annotation:** ESM2 remains the stronger nearest-neighbour baseline on
the current Anti-CRISPRdb v3 reference.
- **Mixed CRISPR retrieval:** Twin-MF improves several ranking metrics, but ESM2
has better silhouette. Report both rather than claiming a universal win.
- **Aspect choice:** CC is consistently weak for retrieval; BP and MF are the
only plausible Twin aspects for CRISPR proteins.
The practical recommendation is to expose both ESM2 and Twin-MF/BP in the tool:
ESM2 as the conservative sequence-representation baseline, Twin-MF for Cas-like
top-hit annotation, and Twin-BP when broader same-family neighbour retrieval is
the goal.
""")
| gr.Markdown(""" | |
| --- | |
| ## Figure 4. Applying the models to unknown proteins | |
| The benchmark above uses proteins with known Cas or Acr labels. As a more | |
| realistic discovery-style check, we also screen uncharacterised proteins from a | |
| larger phage/bacterial/archaeal candidate FASTA and score each query against the | |
| same mixed Cas/Acr reference. This is not a validated discovery set; it is a | |
| triage view that asks which unknown proteins land closest to known CRISPR | |
| families under ESM2 versus Twin-MF. Because these proteins are unlabeled, the | |
| screen does **not** decide which embedding is correct. The distribution plot | |
| shows all scores, but the candidate list below reports only high-confidence | |
| hits where the nearest-neighbour evidence is strongest. | |
| """) | |
unknown_data = get_unknown_results()
n_unknown = unknown_data.get("n_queries", 0)
unknown_records = unknown_data.get("records", [])
if n_unknown > 0:
    def _unknown_count_by_verdict(method_key):
        """Tally the unknown-screen records for one method in a single pass.

        Args:
            method_key: Per-record result key (``"twin-mf"`` or ``"esm2"``).

        Returns:
            ``(pred, signal, high)``: records whose verdict starts with
            ``"Predicted"``, records whose confidence is not ``"low"``, and
            records whose confidence is ``"high"``.
        """
        pred = signal = high = 0
        for rec in unknown_records:
            res = rec[method_key]
            if str(res["verdict"]).startswith("Predicted"):
                pred += 1
            if res["confidence"] != "low":
                signal += 1
            if res["confidence"] == "high":
                high += 1
        return pred, signal, high

    pred_mf, signal_mf, high_mf = _unknown_count_by_verdict("twin-mf")
    pred_es, signal_es, high_es = _unknown_count_by_verdict("esm2")
    gr.Markdown(
        f"**Precomputed screen.** {n_unknown} uncharacterised proteins were embedded "
        f"with ESM2 and Twin-MF, then ranked against the mixed 254-protein CRISPR "
        f"reference.\n\n"
        f"| method | family-level verdict | any above-threshold CRISPR signal | high-confidence calls |\n"
        f"|---|---:|---:|---:|\n"
        f"| Twin-MF | **{pred_mf}** / {n_unknown} | **{signal_mf}** / {n_unknown} | **{high_mf}** / {n_unknown} |\n"
        f"| ESM2 baseline | {pred_es} / {n_unknown} | {signal_es} / {n_unknown} | {high_es} / {n_unknown} |\n\n"
        # The first column is the method name; the diagnostic counts are the
        # family-level and above-threshold columns.
        f"The family-level and above-threshold columns are diagnostic counts. Only the "
        f"high-confidence candidates are listed in the candidate table below, with "
        f"protein name and sequence, because "
        f"medium-confidence calls are useful for exploration but too weak to present "
        f"as report-level candidates. High cosine and agreement among top neighbours "
        f"make a protein worth follow-up; they do not prove CRISPR function."
    )
else:
    gr.Markdown(
        "**Precomputed screen unavailable.** Rebuild it with "
        "`scripts/analyses/crispr/precompute_unknowns.py --ref_dir data/crispr_reference_mixed_v3`."
    )
| gr.Image( | |
| value="data/benchmark/crispr_unknowns_distribution.png", | |
| label="Unknown-protein screen score distribution", | |
| show_label=True, | |
| container=False, | |
| ) | |
| with gr.Row(): | |
| unknown_method_radio = gr.Radio( | |
| choices=["genomenet-twin (MF)", "ESM2 baseline"], | |
| value="genomenet-twin (MF)", | |
| label="Unknown-screen view", | |
| info="Controls the candidate table and 3-D placement overview.", | |
| ) | |
| unknown_load_btn = gr.Button("load unknown screen", variant="secondary") | |
| unknown_plot = gr.Plot(label="3-D reference projection + unknown-protein candidates") | |
| unknown_table = gr.HTML( | |
| "<div style='padding:12px;color:#6b7280;'>Click <b>load unknown screen</b> to render the high-confidence candidate list.</div>" | |
| ) | |
def _make_unknown_screen_plot(method_choice):
    """Build the 3-D reference projection overlaid with high-confidence unknowns.

    Reference proteins from ``get_crispr_reference`` are drawn as one trace per
    family at the coordinates returned by ``get_crispr_umap_3d``. Each unknown
    protein whose ``confidence`` for the selected method is ``"high"`` is placed
    at a softmax-weighted mean of the coordinates of its top-5 reference
    neighbours, with weights taken from their cosine scores. Unknowns for which
    fewer than three of those neighbours are found in the reference are not drawn.

    Args:
        method_choice: Radio label. Labels starting with ``"genomenet"`` select
            the Twin-MF view; anything else selects the ESM2 baseline.

    Returns:
        A plotly ``Figure``. The figure is empty when no 3-D projection is
        available or no unknown-screen records were packaged.
    """
    method_key = "Twin-MF" if method_choice.startswith("genomenet") else "ESM2 baseline"
    result_key = "twin-mf" if method_key == "Twin-MF" else "esm2"
    xyz = get_crispr_umap_3d(method_key)
    if xyz is None or not unknown_records:
        return go.Figure()
    _, ref_meta = get_crispr_reference(method_key)

    # Reference families in first-appearance order, each with a fixed colour.
    fams = list(dict.fromkeys(ref_meta["family"].tolist()))
    palette = [
        "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b",
        "#e377c2", "#7f7f7f", "#bcbd22", "#17becf", "#aec7e8", "#ffbb78",
        "#98df8a", "#ff9896", "#c5b0d5", "#c49c94", "#f7b6d3", "#c7c7c7",
        "#dbdb8d", "#9edae5",
    ]
    fam_colour = {f: palette[i % len(palette)] for i, f in enumerate(fams)}
    traces = []
    for fam in fams:
        # Positional boolean mask: indexes rows of ``xyz`` directly.
        mask = (ref_meta["family"] == fam).values
        sub = ref_meta[mask]
        hover = [
            f"<b>{html.escape(_crispr_display_id(r))}</b><br>"
            f"{html.escape(_cell_text(r.get('name', '')))}<br>"
            f"{html.escape(_cell_text(r.get('organism', '')))}<br>"
            f"family: {html.escape(str(fam))}"
            for _, r in sub.iterrows()
        ]
        traces.append(go.Scatter3d(
            x=xyz[mask, 0], y=xyz[mask, 1], z=xyz[mask, 2],
            mode="markers",
            marker=dict(size=4, color=fam_colour[fam], line=dict(width=0.5, color="black")),
            name=fam, legendgroup="ref", legendgrouptitle_text="reference families",
            hovertext=hover, hoverinfo="text",
        ))

    ranked_records = [
        r for r in sorted(unknown_records, key=lambda r: -r[result_key]["top_cosine"])
        if r[result_key]["confidence"] == "high"
    ]
    conf_colour = {"low": "#dc2626", "medium": "#d97706", "high": "#059669"}

    # Neighbour ids may be raw accessions or display ids. Map both to
    # *positional* row numbers so they index ``xyz`` the same way the family
    # masks above do; DataFrame index labels would misplace points whenever
    # ``ref_meta`` does not carry a default 0..n-1 index.
    ref_ids = ref_meta["acc"].astype(str).tolist()
    ref_id_to_idx = {rid: i for i, rid in enumerate(ref_ids)}
    display_to_idx = {
        _crispr_display_id(row): pos
        for pos, (_, row) in enumerate(ref_meta.iterrows())
    }

    # Softmax temperature for neighbour weighting; larger values pull the
    # unknown closer to its best-scoring neighbour.
    placement_sharpness = 8.0
    ux, uy, uz, htxt, ucol = [], [], [], [], []
    for r in ranked_records:
        result = r[result_key]
        idxs, top_sims = [], []
        for t in result["top_10"][:5]:
            tid = str(t["id"])
            idx = ref_id_to_idx.get(tid, display_to_idx.get(tid))
            if idx is not None:
                idxs.append(idx)
                top_sims.append(float(t["cosine"]))
        if len(idxs) < 3:
            # Too few anchors in the reference to place this protein sensibly.
            continue
        top_sims = np.array(top_sims, dtype=np.float32)
        w = np.exp((top_sims - top_sims.max()) * placement_sharpness)
        w /= w.sum()
        pos = (xyz[idxs] * w[:, None]).sum(axis=0)
        ux.append(pos[0])
        uy.append(pos[1])
        uz.append(pos[2])
        top1 = result["top_10"][0]
        verdict_short = str(result["verdict"]).split(";")[0][:90]
        htxt.append(
            f"<b>{html.escape(str(r['acc']))}</b> ({result['confidence'].upper()})<br>"
            f"{html.escape(str(r.get('organism', '')))}<br>"
            f"<b>verdict:</b> {html.escape(verdict_short)}<br>"
            f"top-1: {html.escape(str(top1['family']))} (cos = {result['top_cosine']:+.3f})"
        )
        ucol.append(conf_colour[result["confidence"]])
    traces.append(go.Scatter3d(
        x=ux, y=uy, z=uz, mode="markers",
        marker=dict(size=6, symbol="diamond", color=ucol, line=dict(width=1.0, color="gold")),
        name=f"high-confidence unknowns ({len(ux)})",
        legendgroup="unknowns",
        legendgrouptitle_text="unknown candidates",
        hovertext=htxt, hoverinfo="text",
    ))
    fig = go.Figure(data=traces)
    fig.update_layout(
        scene=dict(xaxis_title="projection-1", yaxis_title="projection-2", zaxis_title="projection-3"),
        height=620, margin=dict(l=0, r=0, t=35, b=0),
        title=f"Mixed CRISPR reference + high-confidence unknown candidates · {method_choice}",
        legend=dict(itemsizing="constant", groupclick="toggleitem"),
    )
    return fig
def _make_unknown_screen_table(method_choice):
    """Render the high-confidence unknown-protein candidates as an HTML table.

    Records are taken from ``unknown_records``, filtered to those whose
    ``confidence`` for the selected method is ``"high"``, and sorted by
    descending top-1 cosine. Every value interpolated into the markup is
    HTML-escaped, including the UniProt link target.

    Args:
        method_choice: Radio label. Labels starting with ``"genomenet"`` select
            the Twin-MF results; anything else selects ESM2.

    Returns:
        An HTML string: either the candidate table or a short placeholder
        message when no records or no high-confidence candidates exist.
    """
    if not unknown_records:
        return "<div style='padding:12px;color:#6b7280;'>No unknown-screen records packaged.</div>"
    result_key = "twin-mf" if method_choice.startswith("genomenet") else "esm2"
    recs_sorted = [
        r for r in sorted(unknown_records, key=lambda r: -r[result_key]["top_cosine"])
        if r[result_key]["confidence"] == "high"
    ]
    safe_method = html.escape(method_choice)
    if not recs_sorted:
        return (
            f"<div style='padding:12px;color:#6b7280;'>No high-confidence unknown-protein "
            f"candidates for {safe_method} under the current thresholds. "
            f"Use the distribution plot above to compare lower-confidence score behaviour.</div>"
        )
    seqs = get_unknown_sequences()
    conf_colour = {"high": "#059669", "medium": "#d97706", "low": "#dc2626"}
    rows = []
    for r in recs_sorted:
        res = r[result_key]
        conf = res["confidence"]
        col = conf_colour[conf]
        top1 = res["top_10"][0]
        # Truncate before escaping so entities are never cut in half.
        verdict_short = html.escape(str(res["verdict"])[:110])
        acc = html.escape(str(r["acc"]))
        name = html.escape(str(r.get("name", ""))[:80])
        org = html.escape(str(r.get("organism", ""))[:48])
        source = html.escape(str(r.get("source", "")))
        top_family = html.escape(str(top1["family"]))
        # Escape the link target too; quote=True protects the quoted attribute.
        href = html.escape(str(r["uniprot"]), quote=True)
        length = html.escape(str(r["length"]))
        conf_label = html.escape(str(conf).upper())
        seq = html.escape(_wrap_sequence_for_html(seqs.get(str(r["acc"]), ""), width=80))
        sequence_block = (
            f"<details><summary style='cursor:pointer;color:#2563eb;'>sequence</summary>"
            f"<pre style='white-space:pre-wrap;word-break:break-word;font-size:11px;"
            f"line-height:1.35;background:#f9fafb;border:1px solid #e5e7eb;"
            f"border-radius:4px;padding:8px;margin:6px 0 0 0;'>{seq}</pre></details>"
            if seq else ""
        )
        rows.append(
            f"<tr>"
            f"<td style='padding:6px 10px;font-family:ui-monospace,monospace;white-space:nowrap;'>"
            f"<a href='{href}' target='_blank' style='color:#2563eb;text-decoration:none;'>{acc}</a></td>"
            f"<td style='padding:6px 10px;font-size:12px;color:#374151;'>{name}{sequence_block}</td>"
            f"<td style='padding:6px 10px;font-size:12px;color:#374151;'>{org}</td>"
            f"<td style='padding:6px 10px;font-size:12px;color:#6b7280;'>{length}</td>"
            f"<td style='padding:6px 10px;font-size:12px;color:#6b7280;'>{source}</td>"
            f"<td style='padding:6px 10px;'><span style='background:{col};color:white;"
            f"padding:2px 6px;border-radius:3px;font-size:11px;'>{conf_label}</span></td>"
            f"<td style='padding:6px 10px;font-size:12px;'>{verdict_short}</td>"
            f"<td style='padding:6px 10px;font-family:ui-monospace,monospace;font-size:12px;'>"
            f"{top_family} · {res['top_cosine']:+.3f}</td>"
            f"</tr>"
        )
    return (
        f"<h4 style='margin:12px 0 6px 0;'>High-confidence unknown-protein candidates "
        f"({len(recs_sorted)} of {len(unknown_records)} screened; {safe_method}, sorted by top-1 cosine)</h4>"
        f"<table style='width:100%;border-collapse:collapse;font-size:13px;'>"
        f"<thead><tr style='text-align:left;border-bottom:1px solid #e5e7eb;color:#6b7280;font-size:12px;'>"
        f"<th style='padding:6px 10px;'>UniProt</th>"
        f"<th style='padding:6px 10px;'>name and sequence</th>"
        f"<th style='padding:6px 10px;'>organism</th>"
        f"<th style='padding:6px 10px;'>aa</th>"
        f"<th style='padding:6px 10px;'>source</th>"
        f"<th style='padding:6px 10px;'>confidence</th>"
        f"<th style='padding:6px 10px;'>verdict</th>"
        f"<th style='padding:6px 10px;'>top hit · cosine</th>"
        f"</tr></thead><tbody>{''.join(rows)}</tbody></table>"
    )
# Both the method radio and the load button re-render the same pair of
# outputs: the 3-D overview (``unknown_plot``) and the candidate table
# (``unknown_table``) for the currently selected method.
# NOTE(review): Gradio may derive a listener's default API endpoint name from
# ``fn.__name__``; replacing these lambdas with a named function could rename
# those endpoints — confirm (or pass ``api_name`` explicitly) before refactoring.
unknown_method_radio.change(
    fn=lambda m: (_make_unknown_screen_plot(m), _make_unknown_screen_table(m)),
    inputs=[unknown_method_radio],
    outputs=[unknown_plot, unknown_table],
    show_progress="minimal",
)
unknown_load_btn.click(
    fn=lambda m: (_make_unknown_screen_plot(m), _make_unknown_screen_table(m)),
    inputs=[unknown_method_radio],
    outputs=[unknown_plot, unknown_table],
    show_progress="minimal",
)
| with gr.Tab("API"): | |
| gr.Markdown(""" | |
| ### API | |
| ```python | |
| from gradio_client import Client | |
| import numpy as np | |
| client = Client("genomenet/functional-distance") | |
| result = client.predict( | |
| sequence="MALWMRLLPLLALLALWG...", # protein sequence | |
| top_k=10, | |
| twin_aspect="BP", # "BP" | "CC" | "MF" | |
| api_name="/compare" | |
| ) | |
| summary, esm2_path, twin_path, *plots, hits = result | |
| esm2_emb = np.load(esm2_path) # (1280,) | |
| twin_emb = np.load(twin_path) # (1024,) | |
| # hits: DataFrame with columns [rank, uniref50_id, cosine, uniprot] | |
| # Twin pairwise distance (the model's native trained task) | |
| distance_md = client.predict( | |
| seq_a="MALWMRLLPLLALLALWG...", | |
| seq_b="MGKISSLPTQLFKCCFCDFL...", | |
| aspect="BP", # "BP" | "CC" | "MF" | |
| api_name="/distance", | |
| ) | |
| ``` | |
| ### Nearest-neighbor search | |
| The lookup tab can query FAISS indexes over the GO-annotated UniRef50 subset: | |
| ESM2 uses 1280-dim ESM2 embeddings, and `genomenet-twin (BP)` uses 1024-dim | |
| Twin-BP embeddings. Both indexes use cosine similarity on L2-normalized vectors. | |
| Top-k hits link to the corresponding UniRef50 cluster on UniProt. | |
| ### Models | |
| | Model | Dimension | Description | | |
| |-------|-----------|-------------| | |
| | ESM2 | 1280 | `esm2_t33_650M_UR50D` pretrained on UniRef50 | | |
| | Twin | 1024 | Resnik-contrastive fine-tune; one checkpoint per GO aspect (BP/CC/MF) | | |
| ### Comparison | |
| The Twin model is trained to produce embeddings where functionally similar proteins | |
| (sharing GO terms) have similar embeddings. ESM2 is the pretrained baseline without | |
| this GO supervision. | |
| """) | |
| with gr.Tab("About"): | |
| gr.Markdown(""" | |
| ### ESM2 vs Twin | |
| **ESM2** (`esm2_t33_650M_UR50D`): | |
| - 650M parameter protein language model | |
| - Pretrained on UniRef50 with masked language modeling | |
| - General-purpose protein representation | |
| **Twin Network** (`train_point_{BP,CC,MF}_20251221_std_ft_bs32ga4`): | |
| - Two-tower contrastive encoder: custom AA Transformer + **fine-tuned** ESM2 backbone | |
| - **One checkpoint per GO aspect**: Biological Process (BP), Cellular Component (CC), | |
| Molecular Function (MF). Pick aspect via the Twin GO aspect radio button. | |
| - Trained on Resnik GO-semantic similarity within each aspect | |
| - Output: `concat(custom_proj, esm_proj)` → 1024-dim; L2 distance on L2-normalized | |
| embeddings ≈ functional distance in that aspect | |
| ### Gene Ontology | |
| GO has three aspects: | |
| - **MF** (Molecular Function): Biochemical activity | |
| - **BP** (Biological Process): Biological objective | |
| - **CC** (Cellular Component): Subcellular location | |
| ### Links | |
| - ESM2: [github.com/facebookresearch/esm](https://github.com/facebookresearch/esm) | |
| - BERT DNA: [genomenet/bert-embedding](https://huggingface.co/spaces/genomenet/bert-embedding) | |
| """) | |
if __name__ == "__main__":
    def _preload(announcement, loader, failure_prefix, ready_message=None):
        """Print *announcement*, then call *loader* once as a best-effort warm-up.

        A failure is reported as ``"<failure_prefix>: <error>"`` and otherwise
        ignored, so the app still starts and the resource is retried lazily.
        On success, *ready_message* is printed when one is given.
        """
        print(announcement)
        try:
            loader()
        except Exception as err:
            print(f"{failure_prefix}: {err}")
            return
        if ready_message is not None:
            print(ready_message)

    print("Functional Distance - ESM2 vs Twin")
    print(f"Device: {get_device()}")
    _preload(
        "Loading ESM2...",
        get_esm2,
        "ESM2 load failed (will load on first request)",
        "ESM2 ready!",
    )
    _preload(
        f"Loading ESM2 FAISS index from {ESM2_FAISS_REPO_ID}...",
        lambda: get_faiss("esm2"),
        "ESM2 FAISS load failed (will retry on first request)",
    )
    _preload(
        f"Loading default Twin aspect ({TWIN_DEFAULT_ASPECT}) from {TWIN_REPO_ID}...",
        lambda: get_twin(TWIN_DEFAULT_ASPECT),
        "Twin load failed (will retry on first request)",
        f"Twin/{TWIN_DEFAULT_ASPECT} ready!",
    )

    # Serve the benchmark figures referenced by gr.Image(value="data/benchmark/...").
    benchmark_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "benchmark")
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        allowed_paths=[benchmark_dir],
        theme=gr.themes.Base(
            primary_hue=gr.themes.colors.zinc,
            neutral_hue=gr.themes.colors.zinc,
        ),
    )