File size: 3,844 Bytes
660dc20
 
 
 
6a51705
 
 
 
660dc20
 
 
6a51705
 
660dc20
6a51705
660dc20
 
6a51705
 
660dc20
 
 
6a51705
660dc20
 
 
 
 
 
85aeb06
bcb033b
a7c204b
85aeb06
660dc20
6a51705
660dc20
 
 
 
 
6a51705
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
660dc20
 
 
6a51705
660dc20
 
6a51705
 
 
 
 
 
660dc20
 
6a51705
2b69191
4831657
2b69191
 
660dc20
4831657
4c036b8
85aeb06
6a51705
660dc20
 
 
 
85aeb06
 
 
7c0621c
85aeb06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# ---- BOOTSTRAP: keep storage under control on Spaces ----
import os
import shutil
import subprocess

# Put caches in /data and make sure dirs exist
os.makedirs("/data/.cache/huggingface/hub", exist_ok=True)
os.makedirs("/data/snapshots", exist_ok=True)

# These must be set BEFORE huggingface_hub is imported: the library resolves
# its cache locations from the environment once, at import time, so setting
# them afterwards has no effect on its defaults.
os.environ.setdefault("XDG_CACHE_HOME", "/data/.cache")
os.environ.setdefault("HF_HOME", "/data/.cache/huggingface")
os.environ.setdefault("HF_HUB_CACHE", "/data/.cache/huggingface/hub")
# TRANSFORMERS_CACHE is deliberately not set (deprecated); HF_HOME is enough.

from huggingface_hub import scan_cache_dir, snapshot_download  # noqa: E402

# Drop every revision currently stored in the HF hub cache. The models used by
# this app are downloaded into /data/snapshots instead, so nothing here is
# needed at runtime. Best-effort: any failure is reported and ignored.
try:
    cache = scan_cache_dir(os.environ["HF_HUB_CACHE"])
    # Revisions live on each cached repo, not on the top-level cache info.
    cached_hashes = [
        rev.commit_hash for repo in cache.repos for rev in repo.revisions
    ]
    if cached_hashes:
        # delete_revisions() takes commit hashes as positional arguments and
        # only builds a deletion plan; execute() actually frees the space.
        cache.delete_revisions(*cached_hashes).execute()
except Exception as e:
    print(f"[cache prune] skipped: {e}")

# Light pip cache cleanup (best-effort; a missing pip must not stop startup)
try:
    subprocess.run(["pip", "cache", "purge"], check=False)
except Exception as e:
    print(f"[pip cache purge] skipped: {e}")
# ---- END BOOTSTRAP ----

import gradio as gr
import sys
import pandas as pd
from transformers import AutoTokenizer, AutoModel, AutoConfig

# Hub repositories used by this app.
MODEL_ID = "ChatterjeeLab/MetaLATTE"
TOKENIZER_ID = "facebook/esm2_t33_650M_UR50D"

# Optional revision pins, read from the environment (e.g. Space Variables).
# An empty string means "no pin".
MODEL_REV = os.environ.get("MODEL_REV", "")          # e.g. "a1b2c3d"
TOKENIZER_REV = os.environ.get("TOKENIZER_REV", "")  # e.g. "9f8e7d6"

def snapshot_to(local_name, repo_id, revision, allow_patterns):
    """Fetch selected files of a Hub repo into ``/data/snapshots/<local_name>``.

    Args:
        local_name: Folder name to create under ``/data/snapshots``.
        repo_id: Hub repository id handed to ``snapshot_download``.
        revision: Revision to fetch; a falsy value (e.g. ``""``) is passed
            on as ``None``.
        allow_patterns: Patterns restricting which files are downloaded.

    Returns:
        Whatever ``snapshot_download`` returns for this download.
    """
    target_dir = "/data/snapshots/{}".format(local_name)
    os.makedirs(target_dir, exist_ok=True)

    # NOTE: there is no ignore_regex option; use ignore_patterns if files
    # ever need to be excluded.
    download_kwargs = {
        "repo_id": repo_id,
        "revision": revision or None,
        "allow_patterns": allow_patterns,
        "local_dir": target_dir,
        # Copy real files into local_dir instead of symlinking to the cache,
        # which makes the on-disk size easier to manage.
        "local_dir_use_symlinks": False,
    }
    return snapshot_download(**download_kwargs)

# ESM-2 tokenizer: only the tokenizer-related files are fetched.
esm_local = snapshot_to(
    local_name="esm2_tokenizer",
    repo_id=TOKENIZER_ID,
    revision=TOKENIZER_REV,
    allow_patterns=[
        "tokenizer.json",
        "tokenizer_config.json",
        "vocab.*",
        "merges.*",
        "special_tokens_map.json",
        "*.model",
        "tokenizer*.txt",
        "spiece.*",
        "*.tiktoken",
    ],
)

# MetaLATTE checkpoint: weight and config style files only.
metalatte_local = snapshot_to(
    local_name="metalatte_model",
    repo_id=MODEL_ID,
    revision=MODEL_REV,
    allow_patterns=[
        "*.json",
        "*.safetensors",
        "*.bin",
        "*.model",
        "*.txt",
    ],
)

# Local package: put the current working directory first on sys.path so the
# project modules imported below resolve to the copies shipped next to this
# script. This must happen before those imports.
metalatte_path = '.'
sys.path.insert(0, metalatte_path)

# Project-local modules (presumably configuration.py and
# modeling_metalatte.py in the working directory -- confirm against the repo).
from configuration import MetaLATTEConfig
from modeling_metalatte import MultitaskProteinModel
# Register the custom classes with the transformers Auto* factories so the
# from_pretrained calls below can resolve them; registration has to come
# before loading.
AutoConfig.register("metalatte", MetaLATTEConfig)
AutoModel.register(MetaLATTEConfig, MultitaskProteinModel)

# Load from the downloaded dirs (no network, no extra cache growth).
# local_files_only=True restricts loading to the directories fetched earlier.
tokenizer = AutoTokenizer.from_pretrained(esm_local, local_files_only=True)
config = AutoConfig.from_pretrained(metalatte_local, local_files_only=True)
model = AutoModel.from_pretrained(metalatte_local, config=config, local_files_only=True)


def predict(sequence):
    """Predict metal binding for one protein sequence.

    Args:
        sequence: Protein sequence text as entered in the UI. Leading and
            trailing whitespace is ignored.

    Returns:
        A single-row ``pandas.DataFrame`` with one column per label in
        ``config.id2label``; a cell holds ``'✓'`` when the model's
        prediction for that label equals 1 and ``''`` otherwise.

    Raises:
        gr.Error: If the input is empty or whitespace only, so the UI shows
            a message instead of a prediction for a blank sequence.
    """
    sequence = (sequence or "").strip()
    if not sequence:
        raise gr.Error("Please enter a protein sequence.")

    inputs = tokenizer(sequence, return_tensors="pt")
    # model.predict returns (raw probabilities, predictions); only the
    # predictions are shown, so the probabilities are not read here.
    _raw_probs, predictions = model.predict(**inputs)

    id2label = config.id2label
    # predictions[0] is the row for the single input sequence.
    results = {
        id2label[i]: '✓' if pred == 1 else ''
        for i, pred in enumerate(predictions[0])
    }

    return pd.DataFrame([results])

# Gradio UI: one text box in, one table out (one column per label in
# config.id2label), wired to predict().
iface = gr.Interface(
    title="MetaLATTE: Metal Binding Prediction",
    description="Enter a protein sequence to predict its metal binding properties.",
    fn=predict,
    inputs=gr.Textbox(placeholder="Enter protein sequence here...", lines=3),
    outputs=gr.Dataframe(headers=[*config.id2label.values()]),
)

# Start serving the app.
iface.launch()