# ---- BOOTSTRAP: stable cache to /data, minimal downloads ----
import os, subprocess

# Create the persistent cache directories and point Hugging Face at them
# *before* importing huggingface_hub, which resolves these variables at import time.
os.makedirs("/data/.cache/huggingface/hub", exist_ok=True)
os.makedirs("/data/snapshots", exist_ok=True)
os.environ.setdefault("XDG_CACHE_HOME", "/data/.cache")
os.environ.setdefault("HF_HOME", "/data/.cache/huggingface")
os.environ.setdefault("HF_HUB_CACHE", "/data/.cache/huggingface/hub")

from huggingface_hub import snapshot_download
# Optional: keep pip cache small
try:
    subprocess.run(["pip", "cache", "purge"], check=False)
except Exception:
    pass
# ---- END BOOTSTRAP ----
import gradio as gr
import sys
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
# Pin revisions via Space → Settings → Variables to avoid repeated downloads.
MODEL_ID = "ChatterjeeLab/MetaLATTE"
TOKENIZER_ID = "facebook/esm2_t33_650M_UR50D"
MODEL_REV = os.getenv("MODEL_REV", "ad1716045c768b30ce87eb6b3963d58578fa5401")  # pinned MetaLATTE commit
TOKENIZER_REV = os.getenv("TOKENIZER_REV", "")
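# An empty revision string falls through to the repo's default branch (see snapshot_to below).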
def snapshot_to(local_name, repo_id, revision, allow_patterns):
    """Download only the matching files of a repo into /data/snapshots/<local_name>."""
    local_dir = f"/data/snapshots/{local_name}"
    os.makedirs(local_dir, exist_ok=True)
    return snapshot_download(
        repo_id=repo_id,
        revision=revision if revision else None,
        allow_patterns=allow_patterns,
        local_dir=local_dir,  # recent huggingface_hub copies real files into local_dir; no symlink flag needed
    )
# Download tokenizer files (small)
esm_local = snapshot_to(
    "esm2_tokenizer",
    TOKENIZER_ID,
    TOKENIZER_REV,
    allow_patterns=[
        "tokenizer.json", "tokenizer_config.json", "vocab.*", "merges.*",
        "special_tokens_map.json", "*.model", "tokenizer*.txt", "spiece.*", "*.tiktoken",
        "config.json",  # some tokenizers use it
    ],
)
# Download MetaLATTE weights + config ONLY (skip stage1 blob)
metalatte_local = snapshot_to(
    "metalatte_model",
    MODEL_ID,
    MODEL_REV,
    allow_patterns=["config.json", "pytorch_model.bin"],
)
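# Both snapshots land under /data/snapshots, so they persist across restarts
# when the Space has persistent storage attached.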
# Local custom model code shipped with the Space (configuration.py, modeling_metalatte.py)
metalatte_path = '.'
sys.path.insert(0, metalatte_path)
from configuration import MetaLATTEConfig
from modeling_metalatte import MultitaskProteinModel
AutoConfig.register("metalatte", MetaLATTEConfig)
AutoModel.register(MetaLATTEConfig, MultitaskProteinModel)
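# Registration lets AutoConfig.from_pretrained below resolve the custom
# "metalatte" model_type to these local classes instead of remote code.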
# Load config + instantiate model (no network)
config = AutoConfig.from_pretrained(metalatte_local, local_files_only=True)
# Find the weight file locally
weight_candidates = [
    "pytorch_model.bin",
    "model/pytorch_model.bin",
    "model.safetensors",
    "model/model.safetensors",
    "stage1_model.bin",
    "model/stage1_model.bin",
]
weight_path = None
for c in weight_candidates:
    p = os.path.join(metalatte_local, c)
    if os.path.exists(p):
        weight_path = p
        break
if weight_path is None:
    raise FileNotFoundError(f"No weights found in {metalatte_local}. Looked for: {weight_candidates}")
# Build model and load the local state dict
model = MultitaskProteinModel(config)
if weight_path.endswith(".safetensors"):
    from safetensors.torch import load_file
    state_dict = load_file(weight_path, device="cpu")
else:
    state_dict = torch.load(weight_path, map_location="cpu")
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if missing or unexpected:
    print(f"[load_state_dict] missing={len(missing)} unexpected={len(unexpected)}")
model.eval()
# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(esm_local, local_files_only=True)
@torch.inference_mode()
def predict(sequence):
    inputs = tokenizer(sequence, return_tensors="pt")
    raw_probs, predictions = model.predict(**inputs)
    id2label = config.id2label
    row = {id2label[i]: ('✓' if int(pred) == 1 else '') for i, pred in enumerate(predictions[0])}
    return pd.DataFrame([row])
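# Example (hypothetical sequence): predict("MKTAYIAK...") returns a one-row
# DataFrame with a '✓' under each metal label the model flags as binding.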
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, placeholder="Enter protein sequence here..."),
    outputs=gr.Dataframe(headers=list(config.id2label.values())),
    title="MetaLATTE: Metal Binding Prediction",
    description="Enter a protein sequence to predict its metal binding properties.",
)
iface.launch()