Spaces:
Build error
Build error
File size: 2,597 Bytes
import json
import os
import time

# Point the Hugging Face caches at /tmp BEFORE importing gradio/transformers:
# huggingface_hub and transformers resolve their cache (and remote-code module)
# directories from these environment variables when they are imported, so
# assigning them after the imports has no effect.
os.environ["HF_HOME"] = "/tmp/hf"
os.environ["HF_HUB_CACHE"] = "/tmp/hf/hub"

import gradio as gr  # noqa: E402  (must follow the cache env setup above)
import torch  # noqa: E402
from transformers import AutoModel  # noqa: E402

# Reference GFP sequence; every scored variant is a single substitution of it.
SEQ = "SKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTLSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK"
# The 20 standard one-letter amino-acid codes; this order defines the columns
# of every score matrix.
AAS = "ACDEFGHIKLMNPQRSTVWY"
# METL model identifiers passed to `load_from_uuid`.
ORACLE_UUID = "HaUuRwfE" # Finetuned / Local / GFP / 1D (70k examples)
DESIGN_UUID = "YoQkzoLD" # GFP DESIGN / Finetuned / Local / GFP / 1D (64 examples)
def compute_all_scores(metl, uuid, label=""):
    """Compute scores for all single mutations of ``SEQ`` with one METL model.

    Calls ``metl.load_from_uuid(uuid)`` (mutating *metl*, which the caller
    reuses for several models), switches the model to eval mode, and predicts
    a score for replacing each residue of ``SEQ`` by every other letter of
    ``AAS``.

    Args:
        metl: Model obtained from ``AutoModel.from_pretrained('gitter-lab/METL',
            trust_remote_code=True)``; used via ``load_from_uuid``, ``eval``,
            ``encoder.encode_variants`` and a forward call.
        uuid: Identifier passed to ``metl.load_from_uuid``.
        label: Name shown in progress messages only.

    Returns:
        A list with one row per position of ``SEQ``. Each row has one entry per
        letter of ``AAS``, in the same order. An entry is the model's
        prediction rounded to 4 decimals, or ``None`` where the letter equals
        the wild-type residue at that position.
    """
    print(f"Loading model {uuid} ({label})...")
    metl.load_from_uuid(uuid)
    metl.eval()
    scores = []
    t0 = time.perf_counter()
    # Inference only: disable autograd once for the whole sweep.
    with torch.no_grad():
        for pos, wt in enumerate(SEQ):
            mutants = [aa for aa in AAS if aa != wt]
            # Variant notation "<wt><pos><mut>" with a 0-based position.
            variants = [f"{wt}{pos}{aa}" for aa in mutants]
            # Score all substitutions at this position in a single batched
            # encode + forward pass instead of one model call per variant.
            encoded = metl.encoder.encode_variants(SEQ, variants)
            preds = metl(torch.tensor(encoded))
            # One output row per variant; keep its first value, as the
            # per-variant version did with ``pred.item()`` / ``pred[0]``.
            values = preds.reshape(len(variants), -1)[:, 0].tolist()
            by_aa = dict(zip(mutants, values))
            scores.append(
                [None if aa == wt else round(by_aa[aa], 4) for aa in AAS]
            )
            if (pos + 1) % 50 == 0:
                elapsed = time.perf_counter() - t0
                print(f" {label}: {pos+1}/{len(SEQ)} positions ({elapsed:.1f}s)")
    print(f" {label}: Done in {time.perf_counter()-t0:.1f}s")
    return scores
def build_html(oracle_scores, design_scores, template_path="frontend.html"):
    """Read the frontend template and inject real scores.

    Args:
        oracle_scores: JSON-serialisable value written in place of the
            ``/*__ORACLE_SCORES__*/`` placeholder.
        design_scores: JSON-serialisable value written in place of the
            ``/*__DESIGN_SCORES__*/`` placeholder.
        template_path: HTML template to read. Defaults to ``"frontend.html"``,
            resolved against the current working directory.

    Returns:
        The template text with every occurrence of each placeholder replaced
        by ``json.dumps`` of the corresponding scores. The oracle placeholder
        is substituted first.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale-dependent default encoding.
    with open(template_path, "r", encoding="utf-8") as f:
        html = f.read()
    for placeholder, scores in (
        ("/*__ORACLE_SCORES__*/", oracle_scores),
        ("/*__DESIGN_SCORES__*/", design_scores),
    ):
        if placeholder not in html:
            # Without the placeholder the page would silently render with no
            # data; surface that at startup instead.
            print(f"WARNING: placeholder {placeholder} not found in {template_path}")
        html = html.replace(placeholder, json.dumps(scores))
    return html
# ── Startup: load models and pre-compute ──
# Runs at import time: both score matrices are computed once, before the UI is
# built, and baked into the page as static data.
print("=" * 60, "GFP Mutation Explorer - Pre-computing scores...", "=" * 60, sep="\n")

# A single METL wrapper is shared; compute_all_scores loads each UUID into it.
metl = AutoModel.from_pretrained("gitter-lab/METL", trust_remote_code=True)
oracle_scores = compute_all_scores(metl, ORACLE_UUID, label="Oracle (70k)")
design_scores = compute_all_scores(metl, DESIGN_UUID, label="GFP-Design (64)")

print("\nBuilding frontend...")
html_content = build_html(oracle_scores, design_scores)
print(f"Frontend ready ({len(html_content)} chars)")

# ── Gradio app ──
# The CSS rule hides the page's <footer> element.
# NOTE(review): gr.HTML inserts its value as markup; confirm that any <script>
# in frontend.html actually executes inside this component.
with gr.Blocks(title="GFP Mutation Explorer", css="footer{display:none!important}") as demo:
    gr.HTML(value=html_content)

demo.launch()
|