File size: 6,188 Bytes
40e900b
 
 
660dc20
6a51705
 
660dc20
 
 
 
40e900b
660dc20
 
 
 
 
 
85aeb06
bcb033b
a7c204b
40e900b
440d372
660dc20
40e900b
660dc20
 
40e900b
 
660dc20
6a51705
 
 
 
 
 
 
40e900b
6a51705
 
fa4c075
6a51705
fa4c075
660dc20
 
fa4c075
6a51705
660dc20
 
fa4c075
6a51705
fa4c075
 
 
 
 
 
 
 
 
 
660dc20
 
fa4c075
 
 
 
 
2b69191
660dc20
fa4c075
 
e161171
 
660dc20
fa4c075
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40e900b
fa4c075
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40e900b
fa4c075
 
85aeb06
40e900b
85aeb06
 
7c0621c
85aeb06
40e900b
 
85aeb06
 
 
 
 
 
 
 
40e900b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# ---- BOOTSTRAP: stable cache to /data, minimal downloads ----
import os
import subprocess

from huggingface_hub import snapshot_download

# Persist HF caches on the /data volume so Space restarts reuse downloads.
for _cache_dir in ("/data/.cache/huggingface/hub", "/data/snapshots"):
    os.makedirs(_cache_dir, exist_ok=True)

# setdefault: respect values already set in the Space's environment.
for _var, _val in (
    ("XDG_CACHE_HOME", "/data/.cache"),
    ("HF_HOME", "/data/.cache/huggingface"),
    ("HF_HUB_CACHE", "/data/.cache/huggingface/hub"),
):
    os.environ.setdefault(_var, _val)

# Optional: keep pip cache small (best effort; any failure is ignored).
try:
    subprocess.run(["pip", "cache", "purge"], check=False)
except Exception:
    pass
# ---- END BOOTSTRAP ----

import gradio as gr
import sys
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig

# Pin via Space → Settings → Variables if you want (helps avoid repeated downloads)
MODEL_ID = "ChatterjeeLab/MetaLATTE"  # HF repo holding the MetaLATTE config/weights
TOKENIZER_ID = "facebook/esm2_t33_650M_UR50D"  # ESM-2 tokenizer repo used by the model
MODEL_REV = os.getenv("MODEL_REV", "ad1716045c768b30ce87eb6b3963d58578fa5401")  # pinned commit; override via env var
TOKENIZER_REV = os.getenv("TOKENIZER_REV", "")  # empty string → latest revision

def snapshot_to(local_name, repo_id, revision, allow_patterns):
    """Download a filtered snapshot of *repo_id* into /data/snapshots/<local_name>.

    An empty/falsy *revision* means "latest". Returns the local snapshot path.
    """
    target = os.path.join("/data/snapshots", local_name)
    os.makedirs(target, exist_ok=True)
    return snapshot_download(
        repo_id=repo_id,
        revision=revision or None,
        allow_patterns=allow_patterns,
        local_dir=target,  # new hub ignores the symlink flag; local_dir alone is enough
    )

# Download tokenizer files only — just vocab/config, no ESM-2 weights.
# Uses the TOKENIZER_ID/TOKENIZER_REV constants defined above instead of
# re-hardcoding the repo id and re-reading the env var (kept in one place).
esm_local = snapshot_to(
    "esm2_tokenizer", TOKENIZER_ID, TOKENIZER_REV,
    allow_patterns=[
        "tokenizer.json", "tokenizer_config.json", "vocab.*", "merges.*",
        "special_tokens_map.json", "*.model", "tokenizer*.txt", "spiece.*", "*.tiktoken", "config.json",
    ],
)

# Download MetaLATTE: include both main and stage1 weight files in case the
# loader uses them. Same single-source-of-truth constants as above.
metalatte_local = snapshot_to(
    "metalatte_model", MODEL_ID, MODEL_REV,
    allow_patterns=[
        "config.json",
        "pytorch_model.bin",
        "model/pytorch_model.bin",
        "model.safetensors",
        "model/model.safetensors",
        "stage1_model.bin",
        "model/stage1_model.bin",
    ],
)

# NOTE: the duplicate `import os, sys, torch, pandas, gradio` / transformers
# imports that used to sit here were removed — every one of those names is
# already imported at the top of the file.

# --- your local package ---
# Make the working directory importable so the Space's local model code is found.
sys.path.insert(0, ".")
from configuration import MetaLATTEConfig
from modeling_metalatte import MultitaskProteinModel

# Register the custom architecture with the Auto* factories BEFORE any
# from_pretrained call, so AutoConfig/AutoModel can resolve "metalatte".
AutoConfig.register("metalatte", MetaLATTEConfig)
AutoModel.register(MetaLATTEConfig, MultitaskProteinModel)

# ---- Monkey-patch: make your from_pretrained support local dirs ----
def _local_aware_from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
    """Load the model directly from a local snapshot directory when given one.

    Anything that is not an existing directory (hub repo id, single file, ...)
    falls through to the class's original ``from_pretrained``.
    Raises FileNotFoundError when a local dir contains no recognized weights.
    """
    from safetensors.torch import load_file as load_safetensors  # lazy: only needed on this path

    # If a local directory is passed, load directly from disk.
    if os.path.isdir(pretrained_model_name_or_path):
        config = kwargs.get("config", None)
        if config is None:
            try:
                # Works because "metalatte" was registered with AutoConfig above.
                config = AutoConfig.from_pretrained(pretrained_model_name_or_path, local_files_only=True)
            except Exception:
                # Fallback in case AutoConfig resolution isn't enough.
                config = MetaLATTEConfig.from_pretrained(pretrained_model_name_or_path, local_files_only=True)

        model = cls(config)

        # Look for weights in common locations; prefer .safetensors > pytorch .bin > stage1.
        candidates = [
            "model/model.safetensors", "model.safetensors",
            "model/pytorch_model.bin", "pytorch_model.bin",
            "model/stage1_model.bin", "stage1_model.bin",
        ]
        weight_path = next(
            (os.path.join(pretrained_model_name_or_path, c)
             for c in candidates
             if os.path.exists(os.path.join(pretrained_model_name_or_path, c))),
            None,
        )
        if weight_path is None:
            raise FileNotFoundError(f"No weights found in {pretrained_model_name_or_path}; tried {candidates}")

        # Load state dict (STRICT to catch any mismatch instead of silently skipping)
        if weight_path.endswith(".safetensors"):
            state = load_safetensors(weight_path, device="cpu")
        else:
            # NOTE(review): torch.load unpickles arbitrary objects — only safe on trusted snapshots.
            state = torch.load(weight_path, map_location="cpu")

        # strict=True already raises on mismatch; the explicit check is
        # belt-and-braces in case a subclass overrides load_state_dict.
        missing, unexpected = model.load_state_dict(state, strict=True)
        if missing or unexpected:
            raise RuntimeError(f"State dict mismatch. missing={missing[:5]}... unexpected={unexpected[:5]}...")
        model.eval()
        return model

    # Otherwise, fall back to the original remote/HF logic.
    # BUG FIX: the original call dropped `cls`, so the repo id was bound as the
    # class argument of the unbound classmethod function and the real path/args
    # shifted by one. Pass `cls` explicitly to the raw function.
    return _orig_from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs)

# Swap in the monkey patch (but keep a handle to the original).
# .__func__ unwraps the classmethod descriptor so the raw function can be
# called later with an explicit cls, then re-wrapped as a classmethod below.
_orig_from_pretrained = MultitaskProteinModel.from_pretrained.__func__
MultitaskProteinModel.from_pretrained = classmethod(_local_aware_from_pretrained)
# --------------------------------------------------------------------

# Load config and model exactly like before (now it will use the local-aware loader).
# `metalatte_local` / `esm_local` are the snapshot directories downloaded above,
# so these resolve entirely offline (local_files_only=True).
config = AutoConfig.from_pretrained(metalatte_local, local_files_only=True)
tokenizer = AutoTokenizer.from_pretrained(esm_local, local_files_only=True)
model = AutoModel.from_pretrained(metalatte_local, config=config, local_files_only=True)
model.eval()  # inference mode: disables dropout etc.

@torch.inference_mode()
def predict(sequence):
    """Run MetaLATTE on one protein sequence.

    Returns a one-row DataFrame: one column per label from config.id2label,
    containing '✓' where the model predicts positive and '' otherwise.
    """
    encoded = tokenizer(sequence, return_tensors="pt")
    raw_probs, predictions = model.predict(**encoded)
    labels = config.id2label
    marks = {}
    for idx, flag in enumerate(predictions[0]):
        marks[labels[idx]] = '✓' if int(flag) == 1 else ''
    return pd.DataFrame([marks])

# Gradio UI: a single sequence textbox in, a one-row label table out.
_seq_input = gr.Textbox(lines=3, placeholder="Enter protein sequence here...")
_pred_output = gr.Dataframe(headers=list(config.id2label.values()))
iface = gr.Interface(
    fn=predict,
    inputs=_seq_input,
    outputs=_pred_output,
    title="MetaLATTE: Metal Binding Prediction",
    description="Enter a protein sequence to predict its metal binding properties.",
)
iface.launch()