# ---- BOOTSTRAP: stable cache to /data, minimal downloads ----
import os
import subprocess

# BUGFIX: the cache environment variables must be exported BEFORE
# huggingface_hub is imported — huggingface_hub.constants resolves
# HF_HOME / HF_HUB_CACHE at import time, so the original ordering
# (import first, setdefault after) silently left the cache in the
# default location instead of the persistent /data volume.
os.makedirs("/data/.cache/huggingface/hub", exist_ok=True)
os.makedirs("/data/snapshots", exist_ok=True)
os.environ.setdefault("XDG_CACHE_HOME", "/data/.cache")
os.environ.setdefault("HF_HOME", "/data/.cache/huggingface")
os.environ.setdefault("HF_HUB_CACHE", "/data/.cache/huggingface/hub")

from huggingface_hub import snapshot_download

# Optional: keep pip cache small (best-effort; never fatal).
try:
    subprocess.run(["pip", "cache", "purge"], check=False)
except Exception:
    pass
# ---- END BOOTSTRAP ----

import gradio as gr
import sys
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoConfig

# Pin via Space → Settings → Variables if you want (helps avoid repeated downloads)
MODEL_ID = "ChatterjeeLab/MetaLATTE"
TOKENIZER_ID = "facebook/esm2_t33_650M_UR50D"
MODEL_REV = os.getenv("MODEL_REV", "ad1716045c768b30ce87eb6b3963d58578fa5401")  # from your screenshot
TOKENIZER_REV = os.getenv("TOKENIZER_REV", "")


def snapshot_to(local_name, repo_id, revision, allow_patterns):
    """Download selected files from a Hub repo into /data/snapshots/<local_name>.

    Args:
        local_name: subdirectory name under /data/snapshots to download into.
        repo_id: Hub repository id, e.g. "ChatterjeeLab/MetaLATTE".
        revision: commit sha / branch / tag; an empty string means the
            repo's default branch (mapped to None for snapshot_download).
        allow_patterns: glob patterns limiting which files are fetched.

    Returns:
        The local directory path containing the downloaded snapshot.
    """
    local_dir = f"/data/snapshots/{local_name}"
    os.makedirs(local_dir, exist_ok=True)
    return snapshot_download(
        repo_id=repo_id,
        revision=revision if revision else None,
        allow_patterns=allow_patterns,
        local_dir=local_dir,  # new hub ignores symlink flag; this is enough
    )


# Download tokenizer files (small)
esm_local = snapshot_to(
    "esm2_tokenizer",
    TOKENIZER_ID,
    TOKENIZER_REV,
    allow_patterns=[
        "tokenizer.json", "tokenizer_config.json", "vocab.*", "merges.*",
        "special_tokens_map.json", "*.model", "tokenizer*.txt", "spiece.*", "*.tiktoken",
        "config.json",  # some tokenizers use it
    ],
)

# Download MetaLATTE weights + config ONLY (skip stage1 blob)
metalatte_local = snapshot_to(
    "metalatte_model",
    MODEL_ID,
    MODEL_REV,
    allow_patterns=["config.json", "pytorch_model.bin"],
)

# Your local custom code
metalatte_path = '.'
sys.path.insert(0, metalatte_path)
from configuration import MetaLATTEConfig
from modeling_metalatte import MultitaskProteinModel

# Instantiate the config from the already-downloaded snapshot — no network.
config = AutoConfig.from_pretrained(metalatte_local, local_files_only=True)

# Locate the weight file: the repo has shipped weights under several
# names/layouts over time, so probe each known location in order.
weight_candidates = [
    "pytorch_model.bin", "model/pytorch_model.bin",
    "model.safetensors", "model/model.safetensors",
    "stage1_model.bin", "model/stage1_model.bin",
]
weight_path = next(
    (
        candidate_path
        for candidate_path in (
            os.path.join(metalatte_local, name) for name in weight_candidates
        )
        if os.path.exists(candidate_path)
    ),
    None,
)
if weight_path is None:
    raise FileNotFoundError(f"No weights found in {metalatte_local}. Looked for: {weight_candidates}")

# Build the model skeleton, then load the local state dict onto CPU.
model = MultitaskProteinModel(config)
if weight_path.endswith(".safetensors"):
    from safetensors.torch import load_file

    state_dict = load_file(weight_path, device="cpu")
else:
    state_dict = torch.load(weight_path, map_location="cpu")

# strict=False tolerates key drift between checkpoint and model; report it.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if missing or unexpected:
    print(f"[load_state_dict] missing={len(missing)} unexpected={len(unexpected)}")
model.eval()

# Tokenizer comes from the local ESM-2 snapshot — again, no network.
tokenizer = AutoTokenizer.from_pretrained(esm_local, local_files_only=True)


@torch.inference_mode()
def predict(sequence):
    """Predict metal-binding labels for one protein sequence.

    Returns a one-row pandas DataFrame with a '✓' under each label the
    model predicts positive, and an empty string otherwise.
    """
    encoded = tokenizer(sequence, return_tensors="pt")
    _, predictions = model.predict(**encoded)
    labels = config.id2label
    row = {}
    for idx, flag in enumerate(predictions[0]):
        row[labels[idx]] = '✓' if int(flag) == 1 else ''
    return pd.DataFrame([row])


iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, placeholder="Enter protein sequence here..."),
    outputs=gr.Dataframe(headers=list(config.id2label.values())),
    title="MetaLATTE: Metal Binding Prediction",
    description="Enter a protein sequence to predict its metal binding properties.",
)
iface.launch()