File size: 4,097 Bytes
40e900b
 
 
660dc20
6a51705
 
660dc20
 
 
 
40e900b
660dc20
 
 
 
 
 
85aeb06
bcb033b
a7c204b
40e900b
440d372
660dc20
40e900b
660dc20
 
40e900b
 
660dc20
6a51705
 
 
 
 
 
 
40e900b
6a51705
 
40e900b
6a51705
 
 
 
660dc20
 
40e900b
 
6a51705
660dc20
 
40e900b
6a51705
 
 
 
40e900b
660dc20
 
40e900b
2b69191
4831657
2b69191
660dc20
e161171
 
40e900b
660dc20
 
40e900b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85aeb06
40e900b
85aeb06
 
7c0621c
85aeb06
40e900b
 
85aeb06
 
 
 
 
 
 
 
40e900b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# ---- BOOTSTRAP: stable cache to /data, minimal downloads ----
import os, subprocess
from huggingface_hub import snapshot_download

# Persist all Hugging Face caches on the /data volume so Space restarts
# reuse previous downloads instead of re-fetching.
_CACHE_ROOT = "/data/.cache"
_HF_HOME = "/data/.cache/huggingface"
_HUB_CACHE = "/data/.cache/huggingface/hub"
os.makedirs(_HUB_CACHE, exist_ok=True)
os.makedirs("/data/snapshots", exist_ok=True)
os.environ.setdefault("XDG_CACHE_HOME", _CACHE_ROOT)
os.environ.setdefault("HF_HOME", _HF_HOME)
os.environ.setdefault("HF_HUB_CACHE", _HUB_CACHE)

# Optional: keep pip cache small (best effort — never fatal if pip is absent).
try:
    subprocess.run(["pip", "cache", "purge"], check=False)
except Exception:
    pass
# ---- END BOOTSTRAP ----

import gradio as gr
import sys
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig

# Pin via Space → Settings → Variables if you want (helps avoid repeated downloads)
# Repo holding the MetaLATTE config/weights, and the ESM-2 repo that provides
# the tokenizer files used to encode input protein sequences.
MODEL_ID = "ChatterjeeLab/MetaLATTE"
TOKENIZER_ID = "facebook/esm2_t33_650M_UR50D"
# Revisions: an empty string means "use the repo's default branch".
MODEL_REV = os.getenv("MODEL_REV", "ad1716045c768b30ce87eb6b3963d58578fa5401")  # from your screenshot
TOKENIZER_REV = os.getenv("TOKENIZER_REV", "")

def snapshot_to(local_name, repo_id, revision, allow_patterns):
    """Download selected files from *repo_id* into /data/snapshots/<local_name>.

    Only files matching *allow_patterns* are fetched. An empty *revision*
    means "use the repo's default branch". Returns the local snapshot
    directory path as reported by ``snapshot_download``.
    """
    target_dir = os.path.join("/data/snapshots", local_name)
    os.makedirs(target_dir, exist_ok=True)
    pinned_rev = revision or None  # "" -> None (default branch)
    return snapshot_download(
        repo_id=repo_id,
        revision=pinned_rev,
        allow_patterns=allow_patterns,
        local_dir=target_dir,  # new hub ignores symlink flag; this is enough
    )

# Download tokenizer files (small)
# Broad allow_patterns cover every tokenizer file layout ESM-style repos use;
# files that don't exist in the repo are simply skipped.
esm_local = snapshot_to(
    "esm2_tokenizer",
    TOKENIZER_ID,
    TOKENIZER_REV,
    allow_patterns=[
        "tokenizer.json","tokenizer_config.json","vocab.*","merges.*",
        "special_tokens_map.json","*.model","tokenizer*.txt","spiece.*","*.tiktoken",
        "config.json"  # some tokenizers use it
    ],
)

# Download MetaLATTE weights + config ONLY (skip stage1 blob)
# Keeping the pattern list tight avoids pulling large unused checkpoints.
metalatte_local = snapshot_to(
    "metalatte_model",
    MODEL_ID,
    MODEL_REV,
    allow_patterns=["config.json", "pytorch_model.bin"],
)

# Your local custom code
# The model class and config live next to this script, not in the hub repo,
# so put the working directory on sys.path and register them with Auto* so
# from_pretrained can resolve the "metalatte" model_type.
metalatte_path = '.'
sys.path.insert(0, metalatte_path)
from configuration import MetaLATTEConfig
from modeling_metalatte import MultitaskProteinModel
AutoConfig.register("metalatte", MetaLATTEConfig)
AutoModel.register(MetaLATTEConfig, MultitaskProteinModel)
# Load config + instantiate model (no network)
config = AutoConfig.from_pretrained(metalatte_local, local_files_only=True)

# Find the weight file locally — probe known layouts in priority order and
# take the first file that actually exists in the snapshot.
weight_candidates = [
    "pytorch_model.bin",
    "model/pytorch_model.bin",
    "model.safetensors",
    "model/model.safetensors",
    "stage1_model.bin",
    "model/stage1_model.bin",
]
weight_path = next(
    (
        candidate_path
        for candidate_path in (
            os.path.join(metalatte_local, rel) for rel in weight_candidates
        )
        if os.path.exists(candidate_path)
    ),
    None,
)
if weight_path is None:
    raise FileNotFoundError(f"No weights found in {metalatte_local}. Looked for: {weight_candidates}")

# Build model and load the local state dict
# Instantiate from config first, then load weights manually so we can pick
# either .bin or .safetensors and tolerate partial key matches.
model = MultitaskProteinModel(config)
if weight_path.endswith(".safetensors"):
    from safetensors.torch import load_file
    state_dict = load_file(weight_path, device="cpu")
else:
    state_dict = torch.load(weight_path, map_location="cpu")
# strict=False: checkpoint may carry extra heads or miss buffers; log the
# counts so mismatches are visible in the Space logs instead of crashing.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if missing or unexpected:
    print(f"[load_state_dict] missing={len(missing)} unexpected={len(unexpected)}")
model.eval()

# Tokenizer
# Loaded from the pre-downloaded snapshot; local_files_only avoids network.
tokenizer = AutoTokenizer.from_pretrained(esm_local, local_files_only=True)

@torch.inference_mode()
def predict(sequence):
    """Predict metal-binding labels for one protein sequence.

    Args:
        sequence: Amino-acid sequence. Whitespace and newlines are stripped,
            so FASTA-style pasted sequences with line breaks are tolerated.

    Returns:
        A one-row pandas DataFrame with one column per label; a cell holds
        '✓' when that metal is predicted to bind, '' otherwise.

    Raises:
        gr.Error: if the cleaned sequence is empty (shown in the Gradio UI).
    """
    # Tolerate pasted line breaks/spaces; fail fast on empty input instead of
    # running an effectively empty batch through the model.
    cleaned = "".join(sequence.split()) if sequence else ""
    if not cleaned:
        raise gr.Error("Please enter a non-empty protein sequence.")
    inputs = tokenizer(cleaned, return_tensors="pt")
    _probs, predictions = model.predict(**inputs)
    id2label = config.id2label
    # HF configs sometimes deserialize id2label with string keys ("0", "1", ...);
    # accept either int or str keys.
    def _label(i):
        return id2label[i] if i in id2label else id2label[str(i)]
    row = {_label(i): ('✓' if int(pred) == 1 else '') for i, pred in enumerate(predictions[0])}
    return pd.DataFrame([row])

# Wire the prediction function into a minimal Gradio UI and start serving.
_seq_input = gr.Textbox(lines=3, placeholder="Enter protein sequence here...")
_table_output = gr.Dataframe(headers=list(config.id2label.values()))
iface = gr.Interface(
    fn=predict,
    inputs=_seq_input,
    outputs=_table_output,
    title="MetaLATTE: Metal Binding Prediction",
    description="Enter a protein sequence to predict its metal binding properties.",
)
iface.launch()