ynuozhang committed ba4d3fd (1 parent: 6778ebd)

fix path

Files changed: README.md (+2 −2), inference.py (+67 −21)
README.md CHANGED

@@ -435,8 +435,8 @@ huggingface-cli download ChatterjeeLab/PeptiVerse \
   --local-dir . \
   --local-dir-use-symlinks False
 ```
-### 
-
+### Trouble installing cuML
+For errors related to the CUDA library, reinstall `torch` after installing `cuML`.
 
 ## Citation
 
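The new troubleshooting note is terse, so here is a quick sanity check to run after the reinstall (an illustration, not part of the commit): it confirms the reinstalled `torch` can still reach its CUDA runtime once `cuML` has pulled in its own CUDA libraries.

```python
# Hypothetical post-reinstall check: if cuML's CUDA libraries shadowed the
# ones torch was built against, the import or the allocation below will
# typically fail, or cuda.is_available() will report False.
import torch

print("torch", torch.__version__, "| cuda available:", torch.cuda.is_available())
if torch.cuda.is_available():
    x = torch.ones(2, 2, device="cuda")  # a tiny allocation exercises the runtime
    print((x @ x).sum().item())          # 8.0 if the CUDA stack is healthy
```

If the check fails, forcing a reinstall of the CUDA-matched `torch` wheel, per the note above, is the suggested fix.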
inference.py CHANGED

@@ -1,4 +1,3 @@
-# peptiverse_infer.py
 from __future__ import annotations
 
 import csv, re, json

@@ -14,7 +13,8 @@ import xgboost as xgb
 
 from transformers import EsmModel, EsmTokenizer, AutoModelForMaskedLM
 from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
-
+from lightning.pytorch import seed_everything
+seed_everything(1986)
 
 # -----------------------------
 # Manifest

@@ -138,7 +138,7 @@ def load_artifact(model_dir: Path, device: torch.device) -> Tuple[str, Any, Path]:
 
     if art.suffix == ".json":
         booster = xgb.Booster()
-        print(str(art))
+        #print(str(art))
         booster.load_model(str(art))
         return "xgb", booster, art
 
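The import-time `seed_everything(1986)` pins the global RNGs, making any stochastic step in inference repeatable. A minimal sketch of the guarantee (an illustration, not part of the commit; assumes `lightning` is installed):

```python
from lightning.pytorch import seed_everything
import torch

# seed_everything seeds Python's random module, NumPy, and torch (CPU/CUDA),
# so re-seeding with the same value reproduces the same random draws.
seed_everything(1986)
a = torch.rand(3)
seed_everything(1986)
b = torch.rand(3)
assert torch.equal(a, b)  # identical tensors
```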
@@ -226,6 +226,41 @@ def _infer_in_dim_from_sd(sd: dict, model_name: str) -> int:
         return int(sd["proj.weight"].shape[1])
     raise ValueError(model_name)
 
+def _infer_num_layers_from_sd(sd: dict, prefix: str = "enc.layers.") -> int:
+    # enc.layers.0.*, enc.layers.1.*, ...
+    idxs = set()
+    for k in sd.keys():
+        if k.startswith(prefix):
+            rest = k[len(prefix):]
+            m = re.match(r"(\d+)\.", rest)
+            if m:
+                idxs.add(int(m.group(1)))
+    return (max(idxs) + 1) if idxs else 1
+
+def _infer_transformer_arch_from_sd(sd: dict) -> Tuple[int, int, int]:
+    """
+    Returns (d_model, layers, ff) inferred from weights.
+      - d_model from proj.weight (shape: [d_model, in_dim])
+      - layers from count of enc.layers.*
+      - ff from enc.layers.0.linear1.weight (shape: [ff, d_model])
+    """
+    if "proj.weight" not in sd:
+        raise KeyError("Missing proj.weight in state_dict; cannot infer transformer d_model.")
+    d_model = int(sd["proj.weight"].shape[0])
+    layers = _infer_num_layers_from_sd(sd, prefix="enc.layers.")
+    if "enc.layers.0.linear1.weight" in sd:
+        ff = int(sd["enc.layers.0.linear1.weight"].shape[0])
+    else:
+        ff = 4 * d_model
+    return d_model, layers, ff
+
+def _pick_nhead(d_model: int) -> int:
+    # prefer common head counts; must divide d_model
+    for h in (8, 6, 4, 3, 2, 1):
+        if d_model % h == 0:
+            return h
+    return 1
+
 def build_torch_model_from_ckpt(model_name: str, ckpt: dict, device: torch.device) -> nn.Module:
     params = ckpt["best_params"]
     sd = ckpt["state_dict"]
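To make the shape conventions behind the new helpers concrete, here is a small self-check (an illustration, not part of the commit). `_ToyHead` is a hypothetical stand-in that mirrors the attribute names the helpers assume, a `proj` Linear plus an `enc` TransformerEncoder; the snippet assumes `inference.py` and its dependencies are importable from the repo root:

```python
import torch.nn as nn
from inference import _infer_transformer_arch_from_sd, _pick_nhead

class _ToyHead(nn.Module):
    """Hypothetical stand-in matching the state_dict key layout the helpers expect."""
    def __init__(self, in_dim=1280, d_model=128, nhead=8, layers=3, ff=512):
        super().__init__()
        self.proj = nn.Linear(in_dim, d_model)  # proj.weight: [d_model, in_dim]
        self.enc = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=ff, batch_first=True),
            num_layers=layers,
        )

sd = _ToyHead().state_dict()
print(_infer_transformer_arch_from_sd(sd))  # (128, 3, 512): d_model, layers, ff
print(_pick_nhead(128))                     # 8: first preferred head count dividing 128
```

Because `_pick_nhead` tries common head counts first, typical widths such as 128 or 256 resolve to 8 heads.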
@@ -238,25 +273,30 @@ def build_torch_model_from_ckpt(model_name: str, ckpt: dict, device: torch.device) -> nn.Module:
         model = CNNHead(in_ch=in_dim, c=int(params["channels"]), k=int(params["kernel"]),
                         layers=int(params["layers"]), dropout=dropout)
     elif model_name == "transformer":
-        d_model = (
-            params.get("d_model")
-            or params.get("hidden")
-            or params.get("hidden_dim")
-        )
+        # if transfer-learning ckpt omits arch params, infer from state_dict; special case for transformer_wt_log
+        d_model = params.get("d_model") or params.get("hidden") or params.get("hidden_dim")
+
         if d_model is None:
-            ...
-            ...
-            ...
+            d_model_i, layers_i, ff_i = _infer_transformer_arch_from_sd(sd)
+            nhead_i = _pick_nhead(d_model_i)
+            model = TransformerHead(
+                in_dim=in_dim,
+                d_model=int(d_model_i),
+                nhead=int(params.get("nhead", nhead_i)),
+                layers=int(params.get("layers", layers_i)),
+                ff=int(params.get("ff", ff_i)),
+                dropout=float(params.get("dropout", dropout)),
+            )
+        else:
+            d_model = int(d_model)
+            model = TransformerHead(
+                in_dim=in_dim,
+                d_model=d_model,
+                nhead=int(params.get("nhead", _pick_nhead(d_model))),
+                layers=int(params.get("layers", 2)),
+                ff=int(params.get("ff", 4 * d_model)),
+                dropout=dropout
             )
-
-        model = TransformerHead(
-            in_dim=in_dim,
-            d_model=int(d_model),
-            nhead=int(params["nhead"]),
-            layers=int(params["layers"]),
-            ff=int(params.get("ff", 4 * int(d_model))),
-            dropout=dropout
-        )
     else:
         raise ValueError(f"Unknown NN model_name={model_name}")
 
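Continuing the `_ToyHead` sketch, a hypothetical end-to-end pass through the new fallback branch (illustration only; the `in_dim` and `dropout` values are assumed, and `TransformerHead` is imported from `inference.py`, where the diff constructs it):

```python
from inference import TransformerHead, _infer_transformer_arch_from_sd, _pick_nhead

# Infer the architecture from the weights, as the rebuilt branch does when
# best_params carries no d_model/hidden/hidden_dim.
d_model, layers, ff = _infer_transformer_arch_from_sd(sd)  # (128, 3, 512) for _ToyHead
model = TransformerHead(
    in_dim=1280,                 # the real code derives this via _infer_in_dim_from_sd
    d_model=d_model,
    nhead=_pick_nhead(d_model),  # 8
    layers=layers,
    ff=ff,
    dropout=0.1,                 # hypothetical; the real code reads params.get("dropout", ...)
)
```

Note that explicit `best_params` entries still win: each `params.get(key, inferred)` only falls back to the inferred value when the checkpoint omits that key.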
@@ -678,6 +718,12 @@ class PeptiVersePredictor:
             if d.exists():
                 return d
 
+        # special handling for halflife transformer wt log folder
+        if prop_key == "halflife" and mode == "wt" and model_name == "transformer":
+            d = base / "transformer_wt_log"
+            if d.exists():
+                return d
+
         if prop_key == "halflife" and model_name == "xgb":
             d = base / ("xgb_wt_log" if mode == "wt" else "xgb_smiles")
             if d.exists():
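The directory fall-through this adds can be pictured with a short sketch (the `base` path and the generic candidate name are assumptions for illustration; only `transformer_wt_log` appears in the commit):

```python
from pathlib import Path

base = Path("./halflife")  # hypothetical classifier-weight folder for prop_key="halflife"
candidates = [
    base / "transformer_wt",      # generic <model>_<mode> folder tried first (assumed)
    base / "transformer_wt_log",  # new special case for wt half-life weights
]
print(next((d for d in candidates if d.exists()), None))  # first existing folder wins
```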
@@ -920,7 +966,7 @@ class PeptiVersePredictor:
 if __name__ == "__main__":
     predictor = PeptiVersePredictor(
         manifest_path="best_models.txt",
-        classifier_weight_root="./
+        classifier_weight_root="./"
     )
     print(predictor.predict_property("hemolysis", "wt", "GIGAVLKVLTTGLPALISWIKRKRQQ"))
     print(predictor.predict_binding_affinity("wt", target_seq="...", binder_str="..."))