Joblib
PeptiVerse / download_light.py
ynuozhang
add light install
e66c5e2
#!/usr/bin/env python3
from __future__ import annotations
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from huggingface_hub import snapshot_download
from inference import (
PeptiVersePredictor,
read_best_manifest_csv,
canon_model,
)
# -----------------------------
# Config
# -----------------------------
# Hugging Face repo id used as the default for the --repo CLI option.
MODEL_REPO = "ChatterjeeLab/PeptiVerse"
DEFAULT_ASSETS_DIR = Path("./")  # where downloaded models live (default for --assets)
# Manifest parsed by read_best_manifest_csv (default for --manifest).
DEFAULT_MANIFEST = Path("./basic_models.txt")
# Canonical model names that are swapped for "xgb" when building download patterns.
BANNED_MODELS = {"svm", "enet", "svm_gpu", "enet_gpu"}
def _norm_prop_disk(prop_key: str) -> str:
return "half_life" if prop_key == "halflife" else prop_key
def _resolve_expected_model_dir(prop_key: str, model_name: str, mode: str) -> str:
    """Return the repo-relative folder expected to hold a model's artifacts.

    Args:
        prop_key: Property key as it appears in the manifest.
        model_name: Canonical model name; for ``binding_affinity`` this is
            the "pooled" / "unpooled" label instead.
        mode: Input mode, ``"wt"`` or ``"smiles"``.

    Returns:
        A path of the form ``training_classifiers/<property>/<folder>``.
    """
    root = f"training_classifiers/{_norm_prop_disk(prop_key)}"

    # Binding affinity folders are named wt_<mode>_<pooled|unpooled>.
    if prop_key == "binding_affinity":
        return f"{root}/wt_{mode}_{model_name}"

    # Half-life models live in a handful of specially named folders.
    if prop_key == "halflife":
        if model_name in {"xgb_wt_log", "xgb_smiles"}:
            return f"{root}/{model_name}"
        if model_name == "transformer" and mode == "wt":
            return f"{root}/transformer_wt_log"
        if model_name == "xgb":
            folder = "xgb_wt_log" if mode == "wt" else "xgb_smiles"
            return f"{root}/{folder}"

    # Everything else follows the generic <model>_<mode> layout.
    return f"{root}/{model_name}_{mode}"
def build_allow_patterns_from_manifest(manifest_path: Path) -> List[str]:
    """Build the list of file patterns to download for the manifest's models.

    For every property in the manifest, the best "wt" and best "smiles"
    model labels are canonicalised with ``canon_model``; labels that
    canonicalise to ``None`` are skipped and names in ``BANNED_MODELS`` are
    replaced by ``"xgb"``. Four artifact patterns are emitted per resolved
    model folder.

    Args:
        manifest_path: Path handed to ``read_best_manifest_csv``.

    Returns:
        Patterns in first-seen order with duplicates removed.
    """
    manifest = read_best_manifest_csv(manifest_path)
    patterns: List[str] = []

    for prop_key, row in manifest.items():
        per_mode = (("wt", row.best_wt), ("smiles", row.best_smiles))
        for mode, label in per_mode:
            model = canon_model(label)
            if model is None:
                continue
            if model in BANNED_MODELS:
                model = "xgb"
            folder = _resolve_expected_model_dir(prop_key, model, mode)
            # Only the "basic" artifacts are requested, not the whole folder.
            patterns.extend(
                [
                    f"{folder}/best_model.json",
                    f"{folder}/best_model.pt",
                    f"{folder}/best_model*.joblib",
                    f"{folder}/best_model*.json",
                ]
            )

    # dict keys keep insertion order, so this drops repeats but keeps order.
    return list(dict.fromkeys(patterns))
def download_assets(
    repo_id: str,
    manifest_path: Path,
    out_dir: Path,
) -> Path:
    """Download the artifacts selected by the manifest into ``out_dir``.

    Args:
        repo_id: Hugging Face repo id to download from.
        manifest_path: Manifest used to build the allow-list of file patterns.
        out_dir: Destination directory; created (with parents) if missing.

    Returns:
        The resolved (absolute) destination directory.
    """
    import inspect

    out_dir = out_dir.resolve()
    out_dir.mkdir(parents=True, exist_ok=True)
    allow_patterns = build_allow_patterns_from_manifest(manifest_path)

    # NOTE(review): ``local_dir_use_symlinks`` is deprecated in recent
    # huggingface_hub releases and may be absent from the newest ones --
    # confirm against the installed version. Pass it only when the installed
    # ``snapshot_download`` accepts it, so older versions keep the original
    # "real files, no symlinks" behaviour and newer ones do not raise
    # TypeError on an unknown keyword.
    extra_kwargs = {}
    try:
        params = inspect.signature(snapshot_download).parameters
    except (TypeError, ValueError):
        # Signature not introspectable: fall back to the original call shape.
        params = None
    if (
        params is None
        or "local_dir_use_symlinks" in params
        or any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())
    ):
        extra_kwargs["local_dir_use_symlinks"] = False

    snapshot_download(
        repo_id=repo_id,
        local_dir=str(out_dir),
        allow_patterns=allow_patterns,
        **extra_kwargs,
    )
    return out_dir
# -----------------------------
# Main
# -----------------------------
def main():
    """Command-line entry point: parse options and download the model assets.

    Raises:
        FileNotFoundError: If the manifest path given by ``--manifest`` does
            not exist.

    The inference-related options (``--device``, ``--property``, ``--mode``,
    ``--input``, ``--target_seq``, ``--binder``) are parsed, but they are
    only referenced by the disabled test snippet at the end of the function.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Lightweight PeptiVerse inference with on-demand model download."
    )

    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ("--repo", {"default": MODEL_REPO, "help": "HF repo id containing weights/assets."}),
        ("--manifest", {"default": str(DEFAULT_MANIFEST), "help": "Path to best_models.txt"}),
        ("--assets", {"default": str(DEFAULT_ASSETS_DIR), "help": "Where to store downloaded assets"}),
        ("--device", {"default": None, "help": "cuda / cpu / cuda:0, etc"}),
        ("--property", {"default": "hemolysis", "help": "Property key (e.g. hemolysis, solubility, ...)"}),
        (
            "--mode",
            {
                "default": "wt",
                "choices": ["wt", "smiles"],
                "help": "Input type: wt=AA sequence, smiles=SMILES",
            },
        ),
        ("--input", {"default": "GIGAVLKVLTTGLPALISWIKRKRQQ", "help": "Sequence or SMILES string"}),
        ("--target_seq", {"default": None, "help": "Target WT sequence for binding_affinity"}),
        ("--binder", {"default": None, "help": "Binder string (AA or SMILES) for binding_affinity"}),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()

    # Fail early, before any download is attempted, if the manifest is missing.
    manifest_path = Path(args.manifest)
    if not manifest_path.exists():
        raise FileNotFoundError(f"Manifest not found: {manifest_path}")

    assets_dir = download_assets(
        args.repo,
        manifest_path=manifest_path,
        out_dir=Path(args.assets),
    )

    # The bare string below is a no-op expression; it keeps the optional
    # smoke test around without running it.
    """ OPTIONAL TEST CODE
    predictor = PeptiVersePredictor(
    manifest_path="basic_models.txt", # use the downloaded copy to be consistent
    classifier_weight_root=str(assets_dir),
    device=args.device,
    )
    if args.property == "binding_affinity":
    if not args.target_seq or not args.binder:
    raise ValueError("For binding_affinity, provide --target_seq and --binder.")
    out = predictor.predict_binding_affinity(args.mode, target_seq=args.target_seq, binder_str=args.binder)
    else:
    out = predictor.predict_property(args.property, args.mode, args.input)
    print(out)
    """
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()