File size: 2,597 Bytes
b1231eb
06617ae
d0e50a8
06617ae
 
 
b1231eb
7abf9a4
 
915c127
06617ae
 
915c127
06617ae
 
d0e50a8
06617ae
 
 
 
d0e50a8
 
06617ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
915c127
06617ae
6fa5243
06617ae
 
41e103f
06617ae
 
 
d0e50a8
06617ae
 
 
d6212ed
b1231eb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import json
import os
import time

# The Hugging Face cache location must be configured BEFORE gradio/transformers
# are imported: huggingface_hub reads HF_HOME / HF_HUB_CACHE into module-level
# constants on first import, so assigning them afterwards has no effect and the
# cache would land in the default (~/.cache/huggingface) instead of /tmp/hf.
os.environ["HF_HOME"] = "/tmp/hf"
os.environ["HF_HUB_CACHE"] = "/tmp/hf/hub"

import gradio as gr  # noqa: E402  (must follow the cache env setup above)
import torch  # noqa: E402
from transformers import AutoModel  # noqa: E402

# Reference protein sequence; every single-residue substitution of it is scored.
SEQ = "SKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTLSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK"
# The 20 canonical amino-acid one-letter codes, in alphabetical order; this order
# defines the column order of every score row.
AAS = "ACDEFGHIKLMNPQRSTVWY"

ORACLE_UUID = "HaUuRwfE"    # Finetuned / Local / GFP / 1D (70k examples)
DESIGN_UUID = "YoQkzoLD"    # GFP DESIGN / Finetuned / Local / GFP / 1D (64 examples)

def _score_variants(metl, seq, variants):
    """Run one batched forward pass and return one float per variant string."""
    if not variants:
        # Nothing to score (e.g. the alphabet only contains the wild-type residue).
        return []
    encoded = metl.encoder.encode_variants(seq, variants)
    with torch.no_grad():
        pred = metl(torch.tensor(encoded))
    # One row per variant; keep the first output value of each row (the
    # single-target models used here produce exactly one value per variant).
    return pred.reshape(len(variants), -1)[:, 0].tolist()


def compute_all_scores(metl, uuid, label="", *, seq=None, alphabet=None, log_every=50):
    """Score every single-residue substitution of a sequence with one METL model.

    Loads the weights identified by ``uuid`` into ``metl`` (mutating it in
    place), puts it in eval mode, and predicts a score for each substitution
    written as ``<wt><pos><aa>`` with a 0-based position.  All substitutions at
    one position are scored in a single batched forward pass.

    Args:
        metl: Model object providing ``load_from_uuid``, ``eval``,
            ``encoder.encode_variants`` and ``__call__`` (as used below).
        uuid: Identifier passed to ``metl.load_from_uuid``.
        label: Human-readable name used only in progress messages.
        seq: Sequence to mutate; defaults to the module-level ``SEQ``.
        alphabet: Candidate residues, one per output column; defaults to
            the module-level ``AAS``.
        log_every: Print progress after every this-many positions;
            ``0`` or a negative value disables the periodic message.

    Returns:
        A list with one row per position of ``seq``.  Each row has one entry
        per letter of ``alphabet``: ``None`` where the letter equals the
        wild-type residue, otherwise the prediction rounded to 4 decimals.
    """
    seq = SEQ if seq is None else seq
    alphabet = AAS if alphabet is None else alphabet

    print(f"Loading model {uuid} ({label})...")
    metl.load_from_uuid(uuid)
    metl.eval()

    scores = []
    t0 = time.perf_counter()
    for pos, wt in enumerate(seq):
        substitutes = [aa for aa in alphabet if aa != wt]
        preds = _score_variants(metl, seq, [f"{wt}{pos}{aa}" for aa in substitutes])
        by_aa = {aa: round(p, 4) for aa, p in zip(substitutes, preds)}
        # The wild-type residue is absent from by_aa, so .get() yields None there.
        scores.append([by_aa.get(aa) for aa in alphabet])
        if log_every > 0 and (pos + 1) % log_every == 0:
            elapsed = time.perf_counter() - t0
            print(f"  {label}: {pos+1}/{len(seq)} positions ({elapsed:.1f}s)")

    print(f"  {label}: Done in {time.perf_counter()-t0:.1f}s")
    return scores


def build_html(oracle_scores, design_scores, path="frontend.html"):
    """Return the frontend page with both score matrices injected.

    Reads the HTML template at ``path`` as UTF-8 and replaces the literal
    placeholders ``/*__ORACLE_SCORES__*/`` and ``/*__DESIGN_SCORES__*/`` with
    the JSON encoding of the corresponding scores (``None`` becomes ``null``).

    Args:
        oracle_scores: JSON-serialisable scores from the oracle model.
        design_scores: JSON-serialisable scores from the design model.
        path: Template file, resolved relative to the current working directory.

    Returns:
        The filled-in HTML as a string.
    """
    # Explicit encoding: the default is locale-dependent and would mis-decode
    # (or fail on) non-ASCII characters in the template on some platforms.
    with open(path, "r", encoding="utf-8") as f:
        template = f.read()

    for placeholder, data in (
        ("/*__ORACLE_SCORES__*/", oracle_scores),
        ("/*__DESIGN_SCORES__*/", design_scores),
    ):
        if placeholder not in template:
            # Without the placeholder the page would silently render no data.
            print(f"  WARNING: placeholder {placeholder} not found in {path}")
        template = template.replace(placeholder, json.dumps(data))
    return template


# ── Startup: load models and pre-compute ──
# Everything below runs at module import time, in this order: load the METL
# wrapper from the Hugging Face Hub, score every single substitution with two
# fine-tuned models, render the frontend, then start the Gradio app.
# NOTE(review): there is no `if __name__ == "__main__":` guard, so merely
# importing this module triggers all of the above — confirm that is intended.
print("=" * 60)
print("GFP Mutation Explorer - Pre-computing scores...")
print("=" * 60)

# Base METL model object from the Hub repo 'gitter-lab/METL'. trust_remote_code=True
# lets transformers execute model code shipped in that repo. Fine-tuned weights are
# loaded into this same object later, per UUID, by compute_all_scores.
metl = AutoModel.from_pretrained('gitter-lab/METL', trust_remote_code=True)

# Both calls reuse the single `metl` object; each one reloads weights for its
# UUID, so they must run one after the other (not concurrently).
oracle_scores = compute_all_scores(metl, ORACLE_UUID, "Oracle (70k)")
design_scores = compute_all_scores(metl, DESIGN_UUID, "GFP-Design (64)")

# Inject both score matrices into the static HTML template.
print("\nBuilding frontend...")
html_content = build_html(oracle_scores, design_scores)
print(f"Frontend ready ({len(html_content)} chars)")

# ── Gradio app ──
# The css string hides any <footer> element on the page.
# NOTE(review): if frontend.html relies on inline <script> tags, confirm they
# actually execute when the markup is rendered through gr.HTML.
with gr.Blocks(title="GFP Mutation Explorer", css="footer{display:none!important}") as demo:
    gr.HTML(html_content)

demo.launch()