Spaces:
Runtime error
Runtime error
# ---- BOOTSTRAP: stable cache to /data, minimal downloads ----
import os
import subprocess
import sys

# Point the Hugging Face caches at the persistent /data volume.
# These variables must be set BEFORE huggingface_hub (or transformers) is imported:
# huggingface_hub reads them once, at import time, when computing its cache paths.
# setdefault keeps any value already provided by the environment.
os.environ.setdefault("XDG_CACHE_HOME", "/data/.cache")
os.environ.setdefault("HF_HOME", "/data/.cache/huggingface")
os.environ.setdefault("HF_HUB_CACHE", "/data/.cache/huggingface/hub")
os.makedirs("/data/.cache/huggingface/hub", exist_ok=True)
os.makedirs("/data/snapshots", exist_ok=True)

from huggingface_hub import snapshot_download  # noqa: E402  (must follow the env setup)

# Optional, best effort: purge pip's download cache to keep disk usage small.
# Run pip via the current interpreter so the right environment's cache is purged;
# failures (e.g. pip unavailable) are ignored because this step is non-essential.
try:
    subprocess.run([sys.executable, "-m", "pip", "cache", "purge"], check=False)
except (OSError, subprocess.SubprocessError):
    pass
# ---- END BOOTSTRAP ----
| import gradio as gr | |
| import sys | |
| import pandas as pd | |
| import torch | |
| from transformers import AutoTokenizer, AutoModel, AutoConfig | |
# Hub repositories to pull from. Revisions can be pinned through the Space's
# Settings → Variables (MODEL_REV / TOKENIZER_REV), which helps avoid repeated downloads.
_DEFAULT_MODEL_REV = "ad1716045c768b30ce87eb6b3963d58578fa5401"

MODEL_ID = "ChatterjeeLab/MetaLATTE"
TOKENIZER_ID = "facebook/esm2_t33_650M_UR50D"
MODEL_REV = os.getenv("MODEL_REV", _DEFAULT_MODEL_REV)
# An empty value leaves the tokenizer unpinned (snapshot_to then passes revision=None).
TOKENIZER_REV = os.getenv("TOKENIZER_REV", "")
def snapshot_to(local_name, repo_id, revision, allow_patterns):
    """Download the files of *repo_id* matching *allow_patterns* into /data/snapshots/<local_name>.

    Args:
        local_name: Sub-folder name under /data/snapshots.
        repo_id: Hugging Face Hub repository id.
        revision: Commit/branch/tag to fetch; any falsy value (e.g. "") is sent as
            ``None`` so the hub's default revision is used.
        allow_patterns: Glob patterns restricting which repo files are downloaded.

    Returns:
        Whatever ``snapshot_download`` returns for the snapshot folder.
    """
    target_dir = f"/data/snapshots/{local_name}"
    os.makedirs(target_dir, exist_ok=True)
    pinned_revision = revision or None
    # local_dir alone is enough: recent huggingface_hub versions ignore the old
    # symlink flag when downloading into a local directory.
    return snapshot_download(
        repo_id=repo_id,
        revision=pinned_revision,
        allow_patterns=allow_patterns,
        local_dir=target_dir,
    )
# Download the ESM-2 tokenizer files (no weight files) into /data/snapshots/esm2_tokenizer.
# Repo ids and revisions come from the TOKENIZER_ID / TOKENIZER_REV and
# MODEL_ID / MODEL_REV settings, so there is a single source of truth for them.
esm_local = snapshot_to(
    "esm2_tokenizer",
    TOKENIZER_ID,
    TOKENIZER_REV,
    allow_patterns=[
        "tokenizer.json", "tokenizer_config.json", "vocab.*", "merges.*",
        "special_tokens_map.json", "*.model", "tokenizer*.txt", "spiece.*",
        "*.tiktoken", "config.json",
    ],
)
# Download the MetaLATTE config plus every weight-file layout the local loader
# looks for: repo root and model/ sub-folder, main and stage-1 checkpoints.
metalatte_local = snapshot_to(
    "metalatte_model",
    MODEL_ID,
    MODEL_REV,
    allow_patterns=[
        "config.json",
        "pytorch_model.bin",
        "model/pytorch_model.bin",
        "model.safetensors",
        "model/model.safetensors",
        "stage1_model.bin",
        "model/stage1_model.bin",
    ],
)
| import os, sys, torch, pandas as pd, gradio as gr | |
| from transformers import AutoTokenizer, AutoModel, AutoConfig | |
# --- Local package: MetaLATTE config/model classes shipped next to this script ---
# Resolve the package directory from this file's location instead of ".", so the
# imports below work regardless of the process's current working directory.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
if _APP_DIR not in sys.path:
    sys.path.insert(0, _APP_DIR)
from configuration import MetaLATTEConfig
from modeling_metalatte import MultitaskProteinModel

# Register the custom types with the Auto* factories BEFORE any from_pretrained call:
# AutoConfig maps model_type "metalatte" to MetaLATTEConfig, and AutoModel maps that
# config class to MultitaskProteinModel.
# NOTE(review): AutoConfig.register requires MetaLATTEConfig.model_type == "metalatte".
AutoConfig.register("metalatte", MetaLATTEConfig)
AutoModel.register(MetaLATTEConfig, MultitaskProteinModel)
# ---- Monkey-patch: make MultitaskProteinModel.from_pretrained accept local snapshot dirs ----
# Weight files searched inside a local snapshot, in priority order:
# .safetensors, then PyTorch .bin, then the stage-1 checkpoint; for each, the
# model/ sub-folder variant is tried before the repo-root one.
_WEIGHT_CANDIDATES = (
    "model/model.safetensors", "model.safetensors",
    "model/pytorch_model.bin", "pytorch_model.bin",
    "model/stage1_model.bin", "stage1_model.bin",
)


def _find_local_weights(model_dir):
    """Return the path of the first existing ``_WEIGHT_CANDIDATES`` entry under *model_dir*.

    Raises:
        FileNotFoundError: if none of the candidate files exists.
    """
    for rel_path in _WEIGHT_CANDIDATES:
        path = os.path.join(model_dir, rel_path)
        if os.path.exists(path):
            return path
    raise FileNotFoundError(
        f"No weights found in {model_dir}; tried {list(_WEIGHT_CANDIDATES)}"
    )


def _read_state_dict(weight_path):
    """Load a state dict onto the CPU from a ``.safetensors`` file or a ``torch.save`` file."""
    if weight_path.endswith(".safetensors"):
        # Imported lazily: only needed when a .safetensors file is actually used.
        from safetensors.torch import load_file as load_safetensors

        return load_safetensors(weight_path, device="cpu")
    # NOTE(security): torch.load unpickles the file; only load checkpoints from
    # trusted repos (newer PyTorch defaults to weights_only=True, which limits this).
    return torch.load(weight_path, map_location="cpu")


def _local_aware_from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
    """Drop-in replacement for ``MultitaskProteinModel.from_pretrained``.

    If *pretrained_model_name_or_path* is an existing directory, build
    ``cls(config)``, load its weights from disk with a strict key match, and put
    the model in eval mode. Otherwise defer to the original ``from_pretrained``.

    Args:
        cls: The model class being loaded (bound by ``classmethod``).
        pretrained_model_name_or_path: Local snapshot directory or hub repo id.
        *args, **kwargs: Forwarded unchanged to the original loader for non-local
            paths. For local directories only ``kwargs["config"]`` is used; when
            it is absent the config is read from the directory.

    Returns:
        The loaded model.

    Raises:
        FileNotFoundError: the directory contains none of the candidate weight files.
        RuntimeError: the checkpoint's keys do not match the model (strict load).
    """
    if not os.path.isdir(pretrained_model_name_or_path):
        # _orig_from_pretrained is the raw function behind the original classmethod,
        # so cls must be passed explicitly as its first argument.
        return _orig_from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs)

    config = kwargs.get("config")
    if config is None:
        try:
            # Works because the "metalatte" type is registered with AutoConfig.
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, local_files_only=True)
        except Exception:
            # Fallback in case AutoConfig cannot resolve the type.
            config = MetaLATTEConfig.from_pretrained(pretrained_model_name_or_path, local_files_only=True)

    # Locate the weights before building the (possibly large) model so a missing
    # checkpoint fails fast.
    weight_path = _find_local_weights(pretrained_model_name_or_path)
    model = cls(config)
    # strict=True makes load_state_dict itself raise RuntimeError listing missing /
    # unexpected keys, so mismatches are never silently skipped.
    model.load_state_dict(_read_state_dict(weight_path), strict=True)
    model.eval()
    return model
# Capture the raw function behind the original from_pretrained classmethod first
# (the patched loader delegates to it for non-directory paths), then install the
# local-aware loader in its place.
_orig_from_pretrained = getattr(MultitaskProteinModel, "from_pretrained").__func__
setattr(MultitaskProteinModel, "from_pretrained", classmethod(_local_aware_from_pretrained))
# --------------------------------------------------------------------
# Load config, tokenizer and model strictly from the local snapshot folders.
# Through the AutoModel registration, loading resolves to
# MultitaskProteinModel.from_pretrained, i.e. the local-aware loader.
_offline = dict(local_files_only=True)
config = AutoConfig.from_pretrained(metalatte_local, **_offline)
tokenizer = AutoTokenizer.from_pretrained(esm_local, **_offline)
model = AutoModel.from_pretrained(metalatte_local, config=config, **_offline)
model.eval()
def predict(sequence):
    """Predict metal-binding labels for a single protein sequence.

    Args:
        sequence: Raw text from the Gradio textbox. All whitespace (including the
            line breaks of a pasted, wrapped sequence) is removed and letters are
            upper-cased, since the ESM-2 vocabulary uses upper-case residue letters.

    Returns:
        A one-row ``pandas.DataFrame`` with one column per label in
        ``config.id2label``; a cell holds ``'✓'`` where the model predicts 1 for
        that label and is empty otherwise.

    Raises:
        gr.Error: if the input is empty after whitespace removal (shown in the UI).
    """
    cleaned = "".join((sequence or "").split()).upper()
    if not cleaned:
        raise gr.Error("Please enter a protein sequence.")
    inputs = tokenizer(cleaned, return_tensors="pt")
    # Inference only: disable autograd so no computation graph is built or kept.
    with torch.no_grad():
        _, predictions = model.predict(**inputs)
    id2label = config.id2label
    row = {id2label[i]: ('✓' if int(pred) == 1 else '') for i, pred in enumerate(predictions[0])}
    return pd.DataFrame([row])
# Web UI: a free-text protein sequence in, a one-row table of label marks out.
# The table's column headers are the label names from config.id2label.
_label_headers = list(config.id2label.values())
_sequence_box = gr.Textbox(lines=3, placeholder="Enter protein sequence here...")
iface = gr.Interface(
    fn=predict,
    inputs=_sequence_box,
    outputs=gr.Dataframe(headers=_label_headers),
    title="MetaLATTE: Metal Binding Prediction",
    description="Enter a protein sequence to predict its metal binding properties.",
)
iface.launch()