#!/usr/bin/env python3
from __future__ import annotations
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from huggingface_hub import snapshot_download
from inference import (
PeptiVersePredictor,
read_best_manifest_csv,
canon_model,
)
# -----------------------------
# Config
# -----------------------------
# Default Hugging Face repo id offered by the --repo CLI option.
MODEL_REPO = "ChatterjeeLab/PeptiVerse"
# Default value of the --assets CLI option.
DEFAULT_ASSETS_DIR = Path("./") # where downloaded models live
# Default value of the --manifest CLI option.
DEFAULT_MANIFEST = Path("./basic_models.txt")
# Model names that build_allow_patterns_from_manifest replaces with "xgb"
# before resolving the folder to download.
BANNED_MODELS = {"svm", "enet", "svm_gpu", "enet_gpu"}
def _norm_prop_disk(prop_key: str) -> str:
return "half_life" if prop_key == "halflife" else prop_key
def _resolve_expected_model_dir(prop_key: str, model_name: str, mode: str) -> str:
    """Build the repo-relative folder path for one (property, model, mode).

    Args:
        prop_key: Property key as it appears in the manifest.
        model_name: Model name; for ``binding_affinity`` it is used as the
            trailing path component (the original note says it is
            "pooled" or "unpooled").
        mode: Input mode string ("wt" or "smiles" at the visible call site).

    Returns:
        A path of the form ``training_classifiers/<disk_prop>/<leaf>``.
    """
    root = f"training_classifiers/{_norm_prop_disk(prop_key)}"

    # binding affinity folders are named wt_<mode>_<pooled|unpooled>
    if prop_key == "binding_affinity":
        return f"{root}/wt_{mode}_{model_name}"

    # Generic layout; halflife overrides it for a few model names below.
    leaf = f"{model_name}_{mode}"
    if prop_key == "halflife":
        if model_name in {"xgb_wt_log", "xgb_smiles"}:
            # Already a full folder name: use it verbatim.
            leaf = model_name
        elif model_name == "transformer" and mode == "wt":
            leaf = "transformer_wt_log"
        elif model_name == "xgb":
            leaf = "xgb_wt_log" if mode == "wt" else "xgb_smiles"

    return f"{root}/{leaf}"
def build_allow_patterns_from_manifest(manifest_path: Path) -> List[str]:
    """Collect the download allow-patterns implied by a manifest.

    The manifest is loaded with ``read_best_manifest_csv``. For every
    property, the ``best_wt`` and ``best_smiles`` labels are passed through
    ``canon_model``; labels that come back as ``None`` are skipped, and any
    model listed in ``BANNED_MODELS`` is replaced by ``"xgb"``.

    Args:
        manifest_path: Path to the manifest file.

    Returns:
        Glob patterns (four per resolved model folder), de-duplicated while
        keeping first-seen order.
    """
    manifest = read_best_manifest_csv(manifest_path)
    patterns: List[str] = []

    # For each property, fetch best artifacts for wt + smiles
    for prop_key, row in manifest.items():
        for mode, label in (("wt", row.best_wt), ("smiles", row.best_smiles)):
            model = canon_model(label)
            if model is None:
                continue
            if model in BANNED_MODELS:
                model = "xgb"

            folder = _resolve_expected_model_dir(prop_key, model, mode)
            # fetch only "basic" artifacts, not everything in the folder
            patterns.extend(
                (
                    f"{folder}/best_model.json",
                    f"{folder}/best_model.pt",
                    f"{folder}/best_model*.joblib",
                    f"{folder}/best_model*.json",
                )
            )

    # dict keys keep insertion order, so this drops repeats but keeps order.
    return list(dict.fromkeys(patterns))
def download_assets(
    repo_id: str,
    manifest_path: Path,
    out_dir: Path,
) -> Path:
    """Download the artifacts selected by the manifest into ``out_dir``.

    Args:
        repo_id: Hugging Face repo id passed to ``snapshot_download``.
        manifest_path: Manifest used to build the allow-patterns.
        out_dir: Destination directory; created (with parents) if missing.

    Returns:
        The resolved destination directory.
    """
    target = out_dir.resolve()
    target.mkdir(parents=True, exist_ok=True)

    # NOTE(review): recent huggingface_hub releases document
    # local_dir_use_symlinks as deprecated -- confirm against the pinned
    # version before changing or relying on it.
    snapshot_download(
        repo_id=repo_id,
        local_dir=str(target),
        local_dir_use_symlinks=False,
        allow_patterns=build_allow_patterns_from_manifest(manifest_path),
    )
    return target
# -----------------------------
# Main
# -----------------------------
def main():
    """Parse CLI options and download the assets named by the manifest.

    Raises:
        FileNotFoundError: If the ``--manifest`` path does not exist.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Lightweight PeptiVerse inference with on-demand model download.",
    )

    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = [
        ("--repo", {"default": MODEL_REPO, "help": "HF repo id containing weights/assets."}),
        ("--manifest", {"default": str(DEFAULT_MANIFEST), "help": "Path to best_models.txt"}),
        ("--assets", {"default": str(DEFAULT_ASSETS_DIR), "help": "Where to store downloaded assets"}),
        ("--device", {"default": None, "help": "cuda / cpu / cuda:0, etc"}),
        ("--property", {"default": "hemolysis", "help": "Property key (e.g. hemolysis, solubility, ...)"}),
        (
            "--mode",
            {
                "default": "wt",
                "choices": ["wt", "smiles"],
                "help": "Input type: wt=AA sequence, smiles=SMILES",
            },
        ),
        ("--input", {"default": "GIGAVLKVLTTGLPALISWIKRKRQQ", "help": "Sequence or SMILES string"}),
        ("--target_seq", {"default": None, "help": "Target WT sequence for binding_affinity"}),
        ("--binder", {"default": None, "help": "Binder string (AA or SMILES) for binding_affinity"}),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    manifest_path = Path(args.manifest)
    if not manifest_path.exists():
        raise FileNotFoundError(f"Manifest not found: {manifest_path}")

    # Only referenced by the optional test code below, which is an inert
    # string literal and never runs.
    assets_dir = download_assets(args.repo, manifest_path=manifest_path, out_dir=Path(args.assets))

    """ OPTIONAL TEST CODE
    predictor = PeptiVersePredictor(
        manifest_path="basic_models.txt", # use the downloaded copy to be consistent
        classifier_weight_root=str(assets_dir),
        device=args.device,
    )
    if args.property == "binding_affinity":
        if not args.target_seq or not args.binder:
            raise ValueError("For binding_affinity, provide --target_seq and --binder.")
        out = predictor.predict_binding_affinity(args.mode, target_seq=args.target_seq, binder_str=args.binder)
    else:
        out = predictor.predict_property(args.property, args.mode, args.input)
    print(out)
    """
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|