File size: 3,826 Bytes
660dc20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85aeb06
bcb033b
a7c204b
85aeb06
660dc20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b69191
4831657
2b69191
 
 
660dc20
4831657
4c036b8
85aeb06
660dc20
 
 
 
 
85aeb06
 
 
7c0621c
85aeb06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# ---- BOOTSTRAP: keep storage under control on Spaces ----
# Route every cache into /data (the Space's persistent volume) and prune
# stale Hugging Face cache revisions so rebuilds don't fill the disk.
import os, shutil, subprocess

# 1) Export cache env vars BEFORE importing huggingface_hub: the library
#    resolves HF_HOME / HF_HUB_CACHE into module-level constants at import
#    time, so setting them after the import has no effect on where
#    snapshot_download() actually writes.
os.makedirs("/data/.cache", exist_ok=True)
os.environ.setdefault("XDG_CACHE_HOME", "/data/.cache")
os.environ.setdefault("HF_HOME", "/data/.cache/huggingface")
os.environ.setdefault("HF_HUB_CACHE", "/data/.cache/huggingface/hub")
os.environ.setdefault("TRANSFORMERS_CACHE", "/data/.cache/huggingface/transformers")
os.environ.setdefault("DATASETS_CACHE", "/data/.cache/huggingface/datasets")

from huggingface_hub import scan_cache_dir, snapshot_download

# 2) Prune old HF cache revisions, keeping the newest revision of each repo.
#    Revisions live on each entry of HFCacheInfo.repos (the cache object has
#    no .revisions attribute), and delete_revisions() only builds a strategy
#    — it must be .execute()d to actually free space.
try:
    cache = scan_cache_dir(os.environ["HF_HUB_CACHE"])
    stale = []
    for repo in cache.repos:
        revs = sorted(repo.revisions, key=lambda r: r.last_modified or 0)
        stale.extend(r.commit_hash for r in revs[:-1])  # all but the newest
    if stale:
        cache.delete_revisions(*stale).execute()
except Exception as e:
    # Best-effort housekeeping: never block app startup on cache maintenance.
    print(f"[cache prune] skipped: {e}")

# (Optional) light guard: trim pip wheel cache
try:
    subprocess.run(["pip", "cache", "purge"], check=False)
except Exception:
    pass
# ---- END BOOTSTRAP ----

import gradio as gr
import sys
import pandas as pd
from transformers import AutoTokenizer, AutoModel, AutoConfig

# If you want fully reproducible rebuilds, set these in Space → Settings → Variables
# (or leave blank to use latest)
# MetaLATTE ships its own weights but reuses the ESM-2 tokenizer, hence the
# two separate repo IDs below.
MODEL_ID = "ChatterjeeLab/MetaLATTE"
TOKENIZER_ID = "facebook/esm2_t33_650M_UR50D"
# Optional commit pins; empty string means "use the repo's default branch".
MODEL_REV = os.getenv("MODEL_REV", "")         # e.g. "a1b2c3d"
TOKENIZER_REV = os.getenv("TOKENIZER_REV", "") # e.g. "9f8e7d6"

# Prefer downloading *exactly* what you need to /data and load locally.
# This avoids multiple revision copies over time.
def maybe_snapshot(repo_id, revision, allow_patterns):
    """Download a filtered snapshot of *repo_id* into the HF cache in /data.

    Parameters
    ----------
    repo_id : str
        Hub repository to fetch.
    revision : str
        Commit/branch/tag to pin; an empty string falls back to the repo's
        default branch (keyword omitted entirely so the library default wins).
    allow_patterns : list[str]
        Glob patterns restricting which files are downloaded — this is what
        keeps multiple-revision bloat under control.

    Returns
    -------
    str
        The resolved local snapshot directory.
    """
    # NOTE: the previous version passed ignore_regex=None; that argument was
    # removed from snapshot_download in modern huggingface_hub and raises
    # TypeError. local_dir=None was also a redundant default. Both dropped.
    kwargs = {"repo_id": repo_id, "allow_patterns": allow_patterns}
    if revision:
        kwargs["revision"] = revision
    return snapshot_download(**kwargs)

# Download tokenizer files only (small)
# Patterns cover every tokenizer format transformers might need (fast JSON
# tokenizer, BPE vocab/merges, SentencePiece models, tiktoken) so this works
# regardless of which variant the repo ships.
esm_local = maybe_snapshot(
    TOKENIZER_ID, TOKENIZER_REV,
    allow_patterns=[
        "tokenizer.json","tokenizer_config.json","vocab.*","merges.*",
        "special_tokens_map.json","*.model","tokenizer*.txt","spiece.*","*.tiktoken"
    ]
)

# Download MetaLATTE (weights + config only)
# Deliberately excludes README/images/etc. to minimize what lands on /data.
metalatte_local = maybe_snapshot(
    MODEL_ID, MODEL_REV,
    allow_patterns=["*.json","*.safetensors","*.bin","*.model","*.txt"]  # keep it tight
)

# Add the current directory to the system path for your custom code
# (configuration.py / modeling_metalatte.py are expected to sit next to
# this script in the Space repo).
metalatte_path = '.'
sys.path.insert(0, metalatte_path)

# Import the custom configuration and model
from configuration import MetaLATTEConfig
from modeling_metalatte import MultitaskProteinModel
# Registration must happen BEFORE AutoConfig/AutoModel.from_pretrained so
# the "metalatte" model_type in the downloaded config resolves to these
# custom classes.
AutoConfig.register("metalatte", MetaLATTEConfig)
AutoModel.register(MetaLATTEConfig, MultitaskProteinModel)

# Load from the local snapshot dirs (avoids re-downloading on rebuilds)
# local_files_only=True guarantees no network access here — everything was
# fetched by the snapshot calls above.
tokenizer = AutoTokenizer.from_pretrained(esm_local, local_files_only=True)
config = AutoConfig.from_pretrained(metalatte_local, local_files_only=True)
model = AutoModel.from_pretrained(metalatte_local, config=config, local_files_only=True)


def predict(sequence):
    """Run MetaLATTE metal-binding prediction on one protein sequence.

    Parameters
    ----------
    sequence : str
        Amino-acid sequence, tokenized with the ESM-2 tokenizer.

    Returns
    -------
    pandas.DataFrame
        A single-row frame, one column per metal label from config.id2label,
        containing '✓' where the model predicts binding and '' otherwise.
    """
    inputs = tokenizer(sequence, return_tensors="pt")
    # model.predict is MetaLATTE's custom head: per-label probabilities plus
    # thresholded 0/1 predictions for a batch of size 1.
    raw_probs, predictions = model.predict(**inputs)

    # Fix: the old loop also computed raw_probs[0][i].item() into an unused
    # local every iteration — dead work, removed. raw_probs is kept in case
    # probabilities are surfaced later.
    id2label = config.id2label
    results = {
        id2label[i]: ('✓' if pred == 1 else '')
        for i, pred in enumerate(predictions[0])
    }
    return pd.DataFrame([results])

# Gradio UI: one textbox in, one table out. Column headers come from the
# model config so they always match what predict() emits.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, placeholder="Enter protein sequence here..."),
    outputs=gr.Dataframe(headers=list(config.id2label.values())),
    title="MetaLATTE: Metal Binding Prediction",
    description="Enter a protein sequence to predict its metal binding properties."
)

# Blocking call; on Spaces this serves the app on the default port.
iface.launch()