yinuozhang committed on
Commit
fa4c075
·
verified ·
1 Parent(s): 3008831

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -46
app.py CHANGED
@@ -37,68 +37,96 @@ def snapshot_to(local_name, repo_id, revision, allow_patterns):
37
  local_dir=local_dir, # new hub ignores symlink flag; this is enough
38
  )
39
 
40
- # Download tokenizer files (small)
41
  esm_local = snapshot_to(
42
- "esm2_tokenizer",
43
- TOKENIZER_ID,
44
- TOKENIZER_REV,
45
  allow_patterns=[
46
  "tokenizer.json","tokenizer_config.json","vocab.*","merges.*",
47
- "special_tokens_map.json","*.model","tokenizer*.txt","spiece.*","*.tiktoken",
48
- "config.json" # some tokenizers use it
49
  ],
50
  )
51
 
52
- # Download MetaLATTE weights + config ONLY (skip stage1 blob)
53
  metalatte_local = snapshot_to(
54
- "metalatte_model",
55
- MODEL_ID,
56
- MODEL_REV,
57
- allow_patterns=["config.json", "pytorch_model.bin"],
 
 
 
 
 
 
58
  )
59
 
60
- # Your local custom code
61
- metalatte_path = '.'
62
- sys.path.insert(0, metalatte_path)
 
 
63
  from configuration import MetaLATTEConfig
64
  from modeling_metalatte import MultitaskProteinModel
 
 
65
  AutoConfig.register("metalatte", MetaLATTEConfig)
66
  AutoModel.register(MetaLATTEConfig, MultitaskProteinModel)
67
- # Load config + instantiate model (no network)
68
- config = AutoConfig.from_pretrained(metalatte_local, local_files_only=True)
69
 
70
- # Find the weight file locally
71
- weight_candidates = [
72
- "pytorch_model.bin",
73
- "model/pytorch_model.bin",
74
- "model.safetensors",
75
- "model/model.safetensors",
76
- "stage1_model.bin",
77
- "model/stage1_model.bin",
78
- ]
79
- weight_path = None
80
- for c in weight_candidates:
81
- p = os.path.join(metalatte_local, c)
82
- if os.path.exists(p):
83
- weight_path = p
84
- break
85
- if weight_path is None:
86
- raise FileNotFoundError(f"No weights found in {metalatte_local}. Looked for: {weight_candidates}")
87
-
88
- # Build model and load the local state dict
89
- model = MultitaskProteinModel(config)
90
- if weight_path.endswith(".safetensors"):
91
- from safetensors.torch import load_file
92
- state_dict = load_file(weight_path, device="cpu", weights_only=False)
93
- else:
94
- state_dict = torch.load(weight_path, map_location="cpu", weights_only=False)
95
- missing, unexpected = model.load_state_dict(state_dict, strict=False)
96
- if missing or unexpected:
97
- print(f"[load_state_dict] missing={len(missing)} unexpected={len(unexpected)}")
98
- model.eval()
99
 
100
- # Tokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  tokenizer = AutoTokenizer.from_pretrained(esm_local, local_files_only=True)
 
 
102
 
103
  @torch.inference_mode()
104
  def predict(sequence):
 
37
  local_dir=local_dir, # new hub ignores symlink flag; this is enough
38
  )
39
 
40
+ # Download tokenizer (unchanged)
41
  esm_local = snapshot_to(
42
+ "esm2_tokenizer", "facebook/esm2_t33_650M_UR50D", os.getenv("TOKENIZER_REV",""),
 
 
43
  allow_patterns=[
44
  "tokenizer.json","tokenizer_config.json","vocab.*","merges.*",
45
+ "special_tokens_map.json","*.model","tokenizer*.txt","spiece.*","*.tiktoken","config.json"
 
46
  ],
47
  )
48
 
49
+ # Download MetaLATTE: include both main and stage1 in case your loader uses them
50
  metalatte_local = snapshot_to(
51
+ "metalatte_model", "ChatterjeeLab/MetaLATTE", os.getenv("MODEL_REV", "ad1716045c768b30ce87eb6b3963d58578fa5401"),
52
+ allow_patterns=[
53
+ "config.json",
54
+ "pytorch_model.bin",
55
+ "model/pytorch_model.bin",
56
+ "model.safetensors",
57
+ "model/model.safetensors",
58
+ "stage1_model.bin",
59
+ "model/stage1_model.bin",
60
+ ],
61
  )
62
 
63
+ import os, sys, torch, pandas as pd, gradio as gr
64
+ from transformers import AutoTokenizer, AutoModel, AutoConfig
65
+
66
+ # --- your local package ---
67
+ sys.path.insert(0, ".")
68
  from configuration import MetaLATTEConfig
69
  from modeling_metalatte import MultitaskProteinModel
70
+
71
+ # Register types BEFORE loading
72
  AutoConfig.register("metalatte", MetaLATTEConfig)
73
  AutoModel.register(MetaLATTEConfig, MultitaskProteinModel)
 
 
74
 
75
+ # ---- Monkey-patch: make your from_pretrained support local dirs ----
76
+ def _local_aware_from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
77
+ import os
78
+ from transformers import AutoConfig
79
+ from safetensors.torch import load_file as load_safetensors
80
+
81
+ # If a local directory is passed, load directly from disk
82
+ if os.path.isdir(pretrained_model_name_or_path):
83
+ config = kwargs.get("config", None)
84
+ if config is None:
85
+ try:
86
+ # works because we registered the type above
87
+ config = AutoConfig.from_pretrained(pretrained_model_name_or_path, local_files_only=True)
88
+ except Exception:
89
+ # fallback in case AutoConfig isn't enough
90
+ config = MetaLATTEConfig.from_pretrained(pretrained_model_name_or_path, local_files_only=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
+ model = cls(config)
93
+
94
+ # Look for weights in common locations; prefer .safetensors > pytorch .bin > stage1
95
+ candidates = [
96
+ "model/model.safetensors", "model.safetensors",
97
+ "model/pytorch_model.bin", "pytorch_model.bin",
98
+ "model/stage1_model.bin", "stage1_model.bin",
99
+ ]
100
+ weight_path = next((os.path.join(pretrained_model_name_or_path, c) for c in candidates if os.path.exists(os.path.join(pretrained_model_name_or_path, c))), None)
101
+ if weight_path is None:
102
+ raise FileNotFoundError(f"No weights found in {pretrained_model_name_or_path}; tried {candidates}")
103
+
104
+ # Load state dict (STRICT to catch any mismatch instead of silently skipping)
105
+ if weight_path.endswith(".safetensors"):
106
+ state = load_safetensors(weight_path, device="cpu")
107
+ else:
108
+ state = torch.load(weight_path, map_location="cpu")
109
+
110
+ missing, unexpected = model.load_state_dict(state, strict=True)
111
+ if missing or unexpected:
112
+ raise RuntimeError(f"State dict mismatch. missing={missing[:5]}... unexpected={unexpected[:5]}...")
113
+ model.eval()
114
+ return model
115
+
116
+ # Otherwise, fall back to the original remote/HF loading logic that the class already had
117
+ # NOTE: We call the original classmethod via the unbound function on the class
118
+ return _orig_from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
119
+
120
+ # Swap in the monkey patch (but keep a handle to the original)
121
+ _orig_from_pretrained = MultitaskProteinModel.from_pretrained.__func__
122
+ MultitaskProteinModel.from_pretrained = classmethod(_local_aware_from_pretrained)
123
+ # --------------------------------------------------------------------
124
+
125
+ # Load config and model exactly like before (now it will use the local-aware loader)
126
+ config = AutoConfig.from_pretrained(metalatte_local, local_files_only=True)
127
  tokenizer = AutoTokenizer.from_pretrained(esm_local, local_files_only=True)
128
+ model = AutoModel.from_pretrained(metalatte_local, config=config, local_files_only=True)
129
+ model.eval()
130
 
131
  @torch.inference_mode()
132
  def predict(sequence):