# MetaLATTE-demo / app.py
# (Hugging Face Space page header preserved from the web export:
#  uploaded by yinuozhang — "Update app.py", commit 440d372 verified, 4.1 kB)
# ---- BOOTSTRAP: stable cache to /data, minimal downloads ----
# The HF cache environment variables must be exported BEFORE huggingface_hub
# is imported: the library resolves its cache locations from the environment
# at import time, so importing first (as the original did) silently ignores
# HF_HOME / HF_HUB_CACHE and falls back to the ephemeral default cache.
import os, subprocess

os.makedirs("/data/.cache/huggingface/hub", exist_ok=True)
os.makedirs("/data/snapshots", exist_ok=True)
os.environ.setdefault("XDG_CACHE_HOME", "/data/.cache")
os.environ.setdefault("HF_HOME", "/data/.cache/huggingface")
os.environ.setdefault("HF_HUB_CACHE", "/data/.cache/huggingface/hub")

from huggingface_hub import snapshot_download

# Optional: keep pip cache small (best-effort; never fatal on a Space).
try:
    subprocess.run(["pip", "cache", "purge"], check=False)
except Exception:
    pass
# ---- END BOOTSTRAP ----
import gradio as gr
import sys
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
# Pin via Space → Settings → Variables if you want (helps avoid repeated downloads)
# Hub repo IDs: the MetaLATTE weights and the ESM-2 tokenizer it pairs with.
MODEL_ID = "ChatterjeeLab/MetaLATTE"
TOKENIZER_ID = "facebook/esm2_t33_650M_UR50D"
# Optional git revisions (env-overridable) so rebuilds resolve a fixed
# snapshot instead of re-downloading a moving `main` branch.
MODEL_REV = os.getenv("MODEL_REV", "ad1716045c768b30ce87eb6b3963d58578fa5401") # from your screenshot
TOKENIZER_REV = os.getenv("TOKENIZER_REV", "")  # "" → repo default branch
def snapshot_to(local_name, repo_id, revision, allow_patterns):
    """Download a filtered snapshot of *repo_id* into /data/snapshots/<local_name>.

    Only files matching *allow_patterns* are fetched. An empty *revision*
    string falls back to the repo's default branch. Returns the local
    snapshot directory path reported by ``snapshot_download``.
    """
    target_dir = os.path.join("/data/snapshots", local_name)
    os.makedirs(target_dir, exist_ok=True)
    resolved_rev = revision or None  # "" → let the hub pick the default branch
    return snapshot_download(
        repo_id=repo_id,
        revision=resolved_rev,
        allow_patterns=allow_patterns,
        # Recent huggingface_hub ignores the symlink flag; local_dir suffices.
        local_dir=target_dir,
    )
# Download tokenizer files (small) — vocab/merges/config only, no weights.
_TOKENIZER_PATTERNS = [
    "tokenizer.json", "tokenizer_config.json", "vocab.*", "merges.*",
    "special_tokens_map.json", "*.model", "tokenizer*.txt", "spiece.*", "*.tiktoken",
    "config.json",  # some tokenizers use it
]
esm_local = snapshot_to(
    "esm2_tokenizer",
    TOKENIZER_ID,
    TOKENIZER_REV,
    allow_patterns=_TOKENIZER_PATTERNS,
)
# Download MetaLATTE weights + config ONLY (skip stage1 blob).
# Restricting allow_patterns keeps the large stage1 training artifact out.
metalatte_local = snapshot_to("metalatte_model", MODEL_ID, MODEL_REV,
                              allow_patterns=["config.json", "pytorch_model.bin"])
# Your local custom code
# Make the Space's working directory importable so the custom MetaLATTE
# modules shipped next to this app can be found.
metalatte_path = '.'
sys.path.insert(0, metalatte_path)
from configuration import MetaLATTEConfig
from modeling_metalatte import MultitaskProteinModel
# Register the custom architecture with transformers' Auto* factories so
# AutoConfig/AutoModel recognize the "metalatte" model_type from config.json.
AutoConfig.register("metalatte", MetaLATTEConfig)
AutoModel.register(MetaLATTEConfig, MultitaskProteinModel)
# Load config + instantiate model (no network)
config = AutoConfig.from_pretrained(metalatte_local, local_files_only=True)

# Locate the weight file inside the snapshot. Several historical repo
# layouts are accepted, checked in priority order (bin before safetensors,
# final weights before the stage1 checkpoint).
weight_candidates = [
    "pytorch_model.bin",
    "model/pytorch_model.bin",
    "model.safetensors",
    "model/model.safetensors",
    "stage1_model.bin",
    "model/stage1_model.bin",
]
weight_path = next(
    (path
     for path in (os.path.join(metalatte_local, name) for name in weight_candidates)
     if os.path.exists(path)),
    None,
)
if weight_path is None:
    raise FileNotFoundError(f"No weights found in {metalatte_local}. Looked for: {weight_candidates}")
# Build model and load the local state dict.
model = MultitaskProteinModel(config)
if weight_path.endswith(".safetensors"):
    from safetensors.torch import load_file
    state_dict = load_file(weight_path, device="cpu")
else:
    # NOTE(review): torch.load on a .bin file unpickles arbitrary objects —
    # this trusts the pinned MODEL_REV snapshot. Consider weights_only=True
    # if the checkpoint contains only tensors; verify before changing.
    state_dict = torch.load(weight_path, map_location="cpu")
# strict=False tolerates benign head/buffer mismatches; surface the counts
# so silent partial loads are at least visible in the Space logs.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if missing or unexpected:
    print(f"[load_state_dict] missing={len(missing)} unexpected={len(unexpected)}")
model.eval()
# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(esm_local, local_files_only=True)
@torch.inference_mode()
def predict(sequence):
    """Predict metal-binding classes for a single protein sequence.

    Args:
        sequence: Amino-acid sequence string. Whitespace and newlines are
            stripped, so line-wrapped sequences pasted from FASTA bodies work.

    Returns:
        A one-row pandas DataFrame with one column per label from
        ``config.id2label``; a cell holds '✓' when that class is predicted,
        '' otherwise.

    Raises:
        gr.Error: if the input is empty after stripping whitespace.
    """
    # Users typically paste line-wrapped sequences; collapse all whitespace
    # instead of tokenizing stray newlines/spaces.
    sequence = "".join(sequence.split())
    if not sequence:
        raise gr.Error("Please enter a protein sequence.")
    inputs = tokenizer(sequence, return_tensors="pt")
    # model.predict returns (per-class probabilities, thresholded 0/1
    # predictions) — custom MetaLATTE API; only the hard predictions are shown.
    raw_probs, predictions = model.predict(**inputs)
    id2label = config.id2label
    row = {id2label[i]: ('✓' if int(pred) == 1 else '') for i, pred in enumerate(predictions[0])}
    return pd.DataFrame([row])
# Wire up the Gradio UI: one free-text sequence box in, a one-row label
# table out, then block serving requests.
_output_headers = list(config.id2label.values())
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, placeholder="Enter protein sequence here..."),
    outputs=gr.Dataframe(headers=_output_headers),
    title="MetaLATTE: Metal Binding Prediction",
    description="Enter a protein sequence to predict its metal binding properties.",
)
iface.launch()