Joblib
ynuozhang committed
Commit ba4d3fd · 1 Parent(s): 6778ebd
Files changed (2):
  1. README.md +2 -2
  2. inference.py +67 -21
README.md CHANGED
````diff
@@ -435,8 +435,8 @@ huggingface-cli download ChatterjeeLab/PeptiVerse \
   --local-dir . \
   --local-dir-use-symlinks False
 ```
-### TODOs
-Bug loading transformer half-life model now, will fix soon.
+### Trouble installing cuML
+If you hit CUDA-library errors, reinstall `torch` after installing `cuML`.
 
 ## Citation
 
````
 
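A quick way to verify the reinstall took: both libraries should import from the same environment, and `torch` should still see the GPU. A minimal sanity-check sketch, not part of the repo, assuming `cuml` and a CUDA-enabled `torch` are installed:

```python
# Post-reinstall sanity check: cuML and torch must coexist in one environment.
# A sketch, not repo code; a failed `import cuml` is the symptom being fixed.
import torch
import cuml  # noqa: F401

print("torch", torch.__version__, "| built for CUDA", torch.version.cuda)
print("CUDA available:", torch.cuda.is_available())
```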
inference.py CHANGED
````diff
@@ -1,4 +1,3 @@
-# peptiverse_infer.py
 from __future__ import annotations
 
 import csv, re, json
@@ -14,7 +13,8 @@ import xgboost as xgb
 
 from transformers import EsmModel, EsmTokenizer, AutoModelForMaskedLM
 from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
-
+from lightning.pytorch import seed_everything
+seed_everything(1986)
 
 # -----------------------------
 # Manifest
@@ -138,7 +138,7 @@ def load_artifact(model_dir: Path, device: torch.device) -> Tuple[str, Any, Path]:
 
     if art.suffix == ".json":
         booster = xgb.Booster()
-        print(str(art))
+        #print(str(art))
         booster.load_model(str(art))
         return "xgb", booster, art
 
@@ -226,6 +226,41 @@ def _infer_in_dim_from_sd(sd: dict, model_name: str) -> int:
         return int(sd["proj.weight"].shape[1])
     raise ValueError(model_name)
 
+def _infer_num_layers_from_sd(sd: dict, prefix: str = "enc.layers.") -> int:
+    # enc.layers.0.*, enc.layers.1.*, ...
+    idxs = set()
+    for k in sd.keys():
+        if k.startswith(prefix):
+            rest = k[len(prefix):]
+            m = re.match(r"(\d+)\.", rest)
+            if m:
+                idxs.add(int(m.group(1)))
+    return (max(idxs) + 1) if idxs else 1
+
+def _infer_transformer_arch_from_sd(sd: dict) -> Tuple[int, int, int]:
+    """
+    Returns (d_model, layers, ff) inferred from weights.
+    - d_model from proj.weight (shape: [d_model, in_dim])
+    - layers from count of enc.layers.*
+    - ff from enc.layers.0.linear1.weight (shape: [ff, d_model])
+    """
+    if "proj.weight" not in sd:
+        raise KeyError("Missing proj.weight in state_dict; cannot infer transformer d_model.")
+    d_model = int(sd["proj.weight"].shape[0])
+    layers = _infer_num_layers_from_sd(sd, prefix="enc.layers.")
+    if "enc.layers.0.linear1.weight" in sd:
+        ff = int(sd["enc.layers.0.linear1.weight"].shape[0])
+    else:
+        ff = 4 * d_model
+    return d_model, layers, ff
+
+def _pick_nhead(d_model: int) -> int:
+    # prefer common head counts; must divide d_model
+    for h in (8, 6, 4, 3, 2, 1):
+        if d_model % h == 0:
+            return h
+    return 1
+
 def build_torch_model_from_ckpt(model_name: str, ckpt: dict, device: torch.device) -> nn.Module:
     params = ckpt["best_params"]
     sd = ckpt["state_dict"]
@@ -238,25 +273,30 @@ def build_torch_model_from_ckpt(model_name: str, ckpt: dict, device: torch.device) -> nn.Module:
         model = CNNHead(in_ch=in_dim, c=int(params["channels"]), k=int(params["kernel"]),
                         layers=int(params["layers"]), dropout=dropout)
     elif model_name == "transformer":
-        d_model = (
-            params.get("d_model")
-            or params.get("hidden")
-            or params.get("hidden_dim")
-        )
+        # if transfer-learning ckpt omits arch params, infer from state_dict. special case for transformer_wt_log
+        d_model = params.get("d_model") or params.get("hidden") or params.get("hidden_dim")
+
         if d_model is None:
-            raise KeyError(
-                f"Transformer checkpoint missing d_model/hidden. "
-                f"Available keys: {list(params.keys())}"
+            d_model_i, layers_i, ff_i = _infer_transformer_arch_from_sd(sd)
+            nhead_i = _pick_nhead(d_model_i)
+            model = TransformerHead(
+                in_dim=in_dim,
+                d_model=int(d_model_i),
+                nhead=int(params.get("nhead", nhead_i)),
+                layers=int(params.get("layers", layers_i)),
+                ff=int(params.get("ff", ff_i)),
+                dropout=float(params.get("dropout", dropout)),
+            )
+        else:
+            d_model = int(d_model)
+            model = TransformerHead(
+                in_dim=in_dim,
+                d_model=d_model,
+                nhead=int(params.get("nhead", _pick_nhead(d_model))),
+                layers=int(params.get("layers", 2)),
+                ff=int(params.get("ff", 4 * d_model)),
+                dropout=dropout
             )
-
-        model = TransformerHead(
-            in_dim=in_dim,
-            d_model=int(d_model),
-            nhead=int(params["nhead"]),
-            layers=int(params["layers"]),
-            ff=int(params.get("ff", 4 * int(d_model))),
-            dropout=dropout
-        )
     else:
         raise ValueError(f"Unknown NN model_name={model_name}")
 
@@ -678,6 +718,12 @@ class PeptiVersePredictor:
         if d.exists():
             return d
 
+        # special handling for halflife transformer wt log folder
+        if prop_key == "halflife" and mode == "wt" and model_name == "transformer":
+            d = base / "transformer_wt_log"
+            if d.exists():
+                return d
+
         if prop_key == "halflife" and model_name == "xgb":
             d = base / ("xgb_wt_log" if mode == "wt" else "xgb_smiles")
             if d.exists():
@@ -920,7 +966,7 @@ class PeptiVersePredictor:
 if __name__ == "__main__":
     predictor = PeptiVersePredictor(
         manifest_path="best_models.txt",
-        classifier_weight_root="./Classifier_Weight"
+        classifier_weight_root="./"
     )
     print(predictor.predict_property("hemolysis", "wt", "GIGAVLKVLTTGLPALISWIKRKRQQ"))
     print(predictor.predict_binding_affinity("wt", target_seq="...", binder_str="..."))
````
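One behavioral note on the new import hunk: `lightning.pytorch.seed_everything(1986)` runs at module import and seeds Python's `random`, NumPy, and PyTorch, making the stochastic parts of inference reproducible across runs (full GPU determinism also depends on deterministic kernels). If you would rather not pull in the `lightning` dependency just for this, a rough stand-in sketch; the `seed_all` helper is hypothetical, not repo code:

```python
# Rough stand-in for lightning.pytorch.seed_everything(1986).
# A sketch; seed_all is a hypothetical helper, not part of this repo.
import os
import random

import numpy as np
import torch


def seed_all(seed: int = 1986) -> None:
    os.environ["PL_GLOBAL_SEED"] = str(seed)  # mirrors the env var lightning sets
    random.seed(seed)                          # Python stdlib RNG
    np.random.seed(seed)                       # NumPy global RNG
    torch.manual_seed(seed)                    # torch CPU RNG
    torch.cuda.manual_seed_all(seed)           # all CUDA device RNGs


seed_all(1986)
```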
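The architecture-inference helpers added in this commit can be exercised in isolation, which makes the shape conventions concrete: `proj.weight` is `[d_model, in_dim]`, and `enc.layers.0.linear1.weight` is `[ff, d_model]`. A small illustration on a synthetic `state_dict`; it assumes the helpers from `inference.py` are in scope, and the tensor shapes are made up for the example:

```python
# Illustration only: a synthetic state_dict shaped the way the new helpers expect.
import torch

sd = {
    "proj.weight": torch.zeros(256, 1280),                  # d_model=256, in_dim=1280
    "enc.layers.0.linear1.weight": torch.zeros(1024, 256),  # ff=1024
    "enc.layers.1.linear1.weight": torch.zeros(1024, 256),  # a second encoder layer
}

d_model, layers, ff = _infer_transformer_arch_from_sd(sd)
print(d_model, layers, ff)   # -> 256 2 1024
print(_pick_nhead(d_model))  # -> 8, the largest preferred head count dividing 256
```

Note the fallback chain in `build_torch_model_from_ckpt`: explicit `best_params` entries win, inferred values fill the gaps, and when `d_model` is present but `layers` is not, the code defaults to 2 encoder layers.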