# ---- BOOTSTRAP: stable cache on /data, minimal downloads ----
import os, subprocess

# Point every cache at the persistent /data volume. These must be set BEFORE
# huggingface_hub is imported, because the hub reads them at import time.
os.makedirs("/data/.cache/huggingface/hub", exist_ok=True)
os.makedirs("/data/snapshots", exist_ok=True)
os.environ.setdefault("XDG_CACHE_HOME", "/data/.cache")
os.environ.setdefault("HF_HOME", "/data/.cache/huggingface")
os.environ.setdefault("HF_HUB_CACHE", "/data/.cache/huggingface/hub")

from huggingface_hub import snapshot_download

# Optional: keep the pip cache small.
try:
    subprocess.run(["pip", "cache", "purge"], check=False)
except Exception:
    pass
# ---- END BOOTSTRAP ----
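# Optional sanity check (a sketch; assumes a huggingface_hub version that exposes
# `constants.HF_HUB_CACHE`): confirm the hub actually picked up the /data cache
# before anything is downloaded. Harmless to delete once the Space boots cleanly.
from huggingface_hub import constants as _hf_constants
assert _hf_constants.HF_HUB_CACHE == "/data/.cache/huggingface/hub", _hf_constants.HF_HUB_CACHE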
import gradio as gr
import sys
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
# Pin revisions via Space → Settings → Variables to avoid surprise updates and
# repeated downloads.
MODEL_ID = "ChatterjeeLab/MetaLATTE"
TOKENIZER_ID = "facebook/esm2_t33_650M_UR50D"
MODEL_REV = os.getenv("MODEL_REV", "ad1716045c768b30ce87eb6b3963d58578fa5401")  # known-good commit
TOKENIZER_REV = os.getenv("TOKENIZER_REV", "")  # empty -> default branch
def snapshot_to(local_name, repo_id, revision, allow_patterns):
    """Download a filtered snapshot of `repo_id` into /data/snapshots/<local_name>."""
    local_dir = f"/data/snapshots/{local_name}"
    os.makedirs(local_dir, exist_ok=True)
    return snapshot_download(
        repo_id=repo_id,
        revision=revision if revision else None,
        allow_patterns=allow_patterns,
        local_dir=local_dir,  # recent hub versions ignore the symlink flag; local_dir alone suffices
    )
# Download the ESM-2 tokenizer files only (no weights).
esm_local = snapshot_to(
    "esm2_tokenizer", TOKENIZER_ID, TOKENIZER_REV,
    allow_patterns=[
        "tokenizer.json", "tokenizer_config.json", "vocab.*", "merges.*",
        "special_tokens_map.json", "*.model", "tokenizer*.txt", "spiece.*",
        "*.tiktoken", "config.json",
    ],
)
# Download MetaLATTE: include both the main and stage1 weights in case the loader needs either.
metalatte_local = snapshot_to(
    "metalatte_model", MODEL_ID, MODEL_REV,
    allow_patterns=[
        "config.json",
        "pytorch_model.bin",
        "model/pytorch_model.bin",
        "model.safetensors",
        "model/model.safetensors",
        "stage1_model.bin",
        "model/stage1_model.bin",
    ],
)
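# Optional debug aid (a sketch): list what actually landed on disk so a missing
# allow_patterns entry is easy to spot in the Space logs.
for _root, _dirs, _files in os.walk(metalatte_local):
    for _f in _files:
        print("downloaded:", os.path.join(_root, _f))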
# --- your local package ---
sys.path.insert(0, ".")
from configuration import MetaLATTEConfig
from modeling_metalatte import MultitaskProteinModel
# Register types BEFORE loading
AutoConfig.register("metalatte", MetaLATTEConfig)
AutoModel.register(MetaLATTEConfig, MultitaskProteinModel)
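# After this registration, any config.json whose "model_type" is "metalatte"
# resolves to the classes above, so the stock Auto* entry points work:
#   AutoConfig.from_pretrained(path) -> MetaLATTEConfig
#   AutoModel.from_pretrained(path)  -> MultitaskProteinModel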
# ---- Monkey-patch: make from_pretrained support local directories ----
def _local_aware_from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
    from safetensors.torch import load_file as load_safetensors

    # If a local directory is passed, load directly from disk.
    if os.path.isdir(pretrained_model_name_or_path):
        config = kwargs.get("config", None)
        if config is None:
            try:
                # Works because "metalatte" is registered with AutoConfig above.
                config = AutoConfig.from_pretrained(pretrained_model_name_or_path, local_files_only=True)
            except Exception:
                # Fallback in case AutoConfig resolution isn't enough.
                config = MetaLATTEConfig.from_pretrained(pretrained_model_name_or_path, local_files_only=True)
        model = cls(config)

        # Look for weights in common locations; prefer .safetensors > pytorch .bin > stage1.
        candidates = [
            "model/model.safetensors", "model.safetensors",
            "model/pytorch_model.bin", "pytorch_model.bin",
            "model/stage1_model.bin", "stage1_model.bin",
        ]
        weight_path = next(
            (os.path.join(pretrained_model_name_or_path, c)
             for c in candidates
             if os.path.exists(os.path.join(pretrained_model_name_or_path, c))),
            None,
        )
        if weight_path is None:
            raise FileNotFoundError(f"No weights found in {pretrained_model_name_or_path}; tried {candidates}")

        # Load the state dict. With strict=True, load_state_dict itself raises on
        # any missing or unexpected key instead of silently skipping it.
        if weight_path.endswith(".safetensors"):
            state = load_safetensors(weight_path, device="cpu")
        else:
            state = torch.load(weight_path, map_location="cpu")
        model.load_state_dict(state, strict=True)
        model.eval()
        return model

    # Otherwise fall back to the original hub loading logic. `_orig_from_pretrained`
    # is the unbound classmethod function, so `cls` must be passed explicitly.
    return _orig_from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs)

# Swap in the monkey patch, keeping a handle to the original classmethod.
_orig_from_pretrained = MultitaskProteinModel.from_pretrained.__func__
MultitaskProteinModel.from_pretrained = classmethod(_local_aware_from_pretrained)
# --------------------------------------------------------------------
# --------------------------------------------------------------------
# Load config, tokenizer, and model from the local snapshots; the patched
# loader above handles the local directory path.
config = AutoConfig.from_pretrained(metalatte_local, local_files_only=True)
tokenizer = AutoTokenizer.from_pretrained(esm_local, local_files_only=True)
model = AutoModel.from_pretrained(metalatte_local, config=config, local_files_only=True)
model.eval()
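# Optional: log basic model facts so the Space logs confirm a successful load
# (pure introspection; no assumptions about the architecture).
_n_params = sum(p.numel() for p in model.parameters())
print(f"MetaLATTE loaded: {_n_params / 1e6:.1f}M parameters, dtype={next(model.parameters()).dtype}")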
@torch.inference_mode()
def predict(sequence):
    inputs = tokenizer(sequence, return_tensors="pt")
    raw_probs, predictions = model.predict(**inputs)
    id2label = config.id2label
    # One row: a checkmark for each metal the model predicts the sequence binds.
    row = {id2label[i]: ("✓" if int(pred) == 1 else "") for i, pred in enumerate(predictions[0])}
    return pd.DataFrame([row])
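# Optional smoke test before launching the UI (a sketch: SMOKE_TEST is a
# hypothetical env var, and the peptide below is an arbitrary example sequence,
# not a curated metal-binding protein).
if os.getenv("SMOKE_TEST"):
    print(predict("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"))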
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, placeholder="Enter protein sequence here..."),
    outputs=gr.Dataframe(headers=list(config.id2label.values())),
    title="MetaLATTE: Metal Binding Prediction",
    description="Enter a protein sequence to predict its metal binding properties.",
)
iface.launch()