# MetaLATTE-demo / app.py
# ---- BOOTSTRAP: stable cache to /data, minimal downloads ----
import os
import subprocess

# Point every Hugging Face cache at the persistent /data volume. These must be
# set BEFORE importing huggingface_hub, which resolves its cache paths at
# import time.
os.makedirs("/data/.cache/huggingface/hub", exist_ok=True)
os.makedirs("/data/snapshots", exist_ok=True)
os.environ.setdefault("XDG_CACHE_HOME", "/data/.cache")
os.environ.setdefault("HF_HOME", "/data/.cache/huggingface")
os.environ.setdefault("HF_HUB_CACHE", "/data/.cache/huggingface/hub")

from huggingface_hub import snapshot_download

# Optional: keep the pip cache small
try:
    subprocess.run(["pip", "cache", "purge"], check=False)
except Exception:
    pass
# ---- END BOOTSTRAP ----
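# Optional sanity check (a minimal sketch, safe to delete): huggingface_hub
# resolves its cache locations at import time, so this should print the /data
# path configured above. Assumes huggingface_hub exposes `constants.HF_HUB_CACHE`.
from huggingface_hub import constants as _hf_constants
print("HF hub cache:", _hf_constants.HF_HUB_CACHE)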
import gradio as gr
import sys
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
# Pin revisions via Space → Settings → Variables to avoid repeated downloads.
MODEL_ID = "ChatterjeeLab/MetaLATTE"
TOKENIZER_ID = "facebook/esm2_t33_650M_UR50D"
MODEL_REV = os.getenv("MODEL_REV", "ad1716045c768b30ce87eb6b3963d58578fa5401")  # pinned commit
TOKENIZER_REV = os.getenv("TOKENIZER_REV", "")  # empty string -> default branch
def snapshot_to(local_name, repo_id, revision, allow_patterns):
    """Download a (optionally pinned) snapshot of repo_id into /data/snapshots/<local_name>."""
    local_dir = f"/data/snapshots/{local_name}"
    os.makedirs(local_dir, exist_ok=True)
    return snapshot_download(
        repo_id=repo_id,
        revision=revision if revision else None,
        allow_patterns=allow_patterns,
        local_dir=local_dir,  # recent hub versions ignore the symlink flag; local_dir alone is enough
    )
# Download the tokenizer files only.
esm_local = snapshot_to(
    "esm2_tokenizer", TOKENIZER_ID, TOKENIZER_REV,
    allow_patterns=[
        "tokenizer.json", "tokenizer_config.json", "vocab.*", "merges.*",
        "special_tokens_map.json", "*.model", "tokenizer*.txt", "spiece.*", "*.tiktoken", "config.json",
    ],
)
# Download MetaLATTE: include both the main and stage1 weights in case the loader needs either.
metalatte_local = snapshot_to(
    "metalatte_model", MODEL_ID, MODEL_REV,
    allow_patterns=[
        "config.json",
        "pytorch_model.bin",
        "model/pytorch_model.bin",
        "model.safetensors",
        "model/model.safetensors",
        "stage1_model.bin",
        "model/stage1_model.bin",
    ],
)
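# Optional diagnostic (safe to delete): list what actually landed in each
# snapshot, so an allow_patterns mistake (e.g. a missed weight file) surfaces
# at startup instead of as a confusing load-time failure.
for _root in (esm_local, metalatte_local):
    print(f"--- snapshot contents: {_root} ---")
    for _dirpath, _dirnames, _filenames in os.walk(_root):
        for _name in _filenames:
            print(" ", os.path.relpath(os.path.join(_dirpath, _name), _root))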
# --- local package (shipped alongside this Space) ---
sys.path.insert(0, ".")
from configuration import MetaLATTEConfig
from modeling_metalatte import MultitaskProteinModel

# Register the custom types BEFORE any from_pretrained call so the Auto*
# factories can dispatch on model_type "metalatte".
AutoConfig.register("metalatte", MetaLATTEConfig)
AutoModel.register(MetaLATTEConfig, MultitaskProteinModel)
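# A minimal sketch of what the registrations buy us: the Auto* factories can
# now dispatch on the custom model_type. (AutoConfig.for_model builds a default
# config from the registered class; left commented out since the default
# constructor may not match this checkpoint's hyperparameters.)
#   assert isinstance(AutoConfig.for_model("metalatte"), MetaLATTEConfig)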
# ---- Monkey-patch: make MultitaskProteinModel.from_pretrained support local dirs ----
def _local_aware_from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
    from safetensors.torch import load_file as load_safetensors

    # If a local directory is passed, load config and weights directly from disk.
    if os.path.isdir(pretrained_model_name_or_path):
        config = kwargs.get("config", None)
        if config is None:
            try:
                # Works because the type was registered with AutoConfig above.
                config = AutoConfig.from_pretrained(pretrained_model_name_or_path, local_files_only=True)
            except Exception:
                # Fallback in case AutoConfig dispatch isn't enough.
                config = MetaLATTEConfig.from_pretrained(pretrained_model_name_or_path, local_files_only=True)
        model = cls(config)

        # Look for weights in common locations; prefer .safetensors > pytorch .bin > stage1.
        candidates = [
            "model/model.safetensors", "model.safetensors",
            "model/pytorch_model.bin", "pytorch_model.bin",
            "model/stage1_model.bin", "stage1_model.bin",
        ]
        weight_path = next(
            (os.path.join(pretrained_model_name_or_path, c)
             for c in candidates
             if os.path.exists(os.path.join(pretrained_model_name_or_path, c))),
            None,
        )
        if weight_path is None:
            raise FileNotFoundError(
                f"No weights found in {pretrained_model_name_or_path}; tried {candidates}"
            )

        # Load the state dict. strict=True makes load_state_dict raise a RuntimeError
        # on any missing/unexpected keys instead of silently skipping them.
        if weight_path.endswith(".safetensors"):
            state = load_safetensors(weight_path, device="cpu")
        else:
            state = torch.load(weight_path, map_location="cpu")
        model.load_state_dict(state, strict=True)
        model.eval()
        return model

    # Otherwise fall back to the original remote/HF logic. _orig_from_pretrained is
    # the unbound function behind the original classmethod, so cls is passed explicitly.
    return _orig_from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs)

# Swap in the monkey patch (keeping a handle to the original).
_orig_from_pretrained = MultitaskProteinModel.from_pretrained.__func__
MultitaskProteinModel.from_pretrained = classmethod(_local_aware_from_pretrained)
# --------------------------------------------------------------------
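# A minimal sketch of the dispatch the patch enables (assumes the snapshot
# directory exists): a local directory takes the on-disk branch, anything else
# falls through to the stock Hugging Face loader.
#   MultitaskProteinModel.from_pretrained(metalatte_local)  # local branch
#   MultitaskProteinModel.from_pretrained(MODEL_ID)         # original HF logic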
# Load config, tokenizer, and model from the local snapshots; AutoModel now
# routes through the local-aware loader installed above.
config = AutoConfig.from_pretrained(metalatte_local, local_files_only=True)
tokenizer = AutoTokenizer.from_pretrained(esm_local, local_files_only=True)
model = AutoModel.from_pretrained(metalatte_local, config=config, local_files_only=True)
model.eval()
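# Optional startup diagnostic (safe to delete): the parameter total helps
# confirm the intended checkpoint loaded; an ESM2-650M-scale backbone should
# come in at roughly 650M parameters.
print(f"Model parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")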
@torch.inference_mode()
def predict(sequence):
    inputs = tokenizer(sequence, return_tensors="pt")
    raw_probs, predictions = model.predict(**inputs)
    id2label = config.id2label
    row = {id2label[i]: ("✓" if int(pred) == 1 else "") for i, pred in enumerate(predictions[0])}
    return pd.DataFrame([row])
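# Quick end-to-end smoke test with a hypothetical fragment (not a validated
# metal binder); uncomment to exercise tokenizer -> model -> DataFrame locally
# before serving:
# print(predict("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"))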
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, placeholder="Enter protein sequence here..."),
    outputs=gr.Dataframe(headers=list(config.id2label.values())),
    title="MetaLATTE: Metal Binding Prediction",
    description="Enter a protein sequence to predict its metal binding properties.",
)
iface.launch()