| | |
| | from __future__ import annotations |
| |
|
| | import os |
| | from pathlib import Path |
| | from typing import Dict, List, Optional, Tuple |
| |
|
| | from huggingface_hub import snapshot_download |
| | from inference import ( |
| | PeptiVersePredictor, |
| | read_best_manifest_csv, |
| | canon_model, |
| | ) |
| |
|
| | |
| | |
| | |
# Hugging Face Hub repository that hosts the PeptiVerse weights/assets.
MODEL_REPO = "ChatterjeeLab/PeptiVerse"
# Default download destination for the assets (the current working directory).
DEFAULT_ASSETS_DIR = Path("./")
# Default manifest naming the best model per property for each input mode (wt / smiles).
DEFAULT_MANIFEST = Path("./basic_models.txt")

# Model families that are never downloaded; the manifest builder substitutes "xgb" for them.
BANNED_MODELS = {"enet", "enet_gpu", "svm", "svm_gpu"}
| |
|
| |
|
| | def _norm_prop_disk(prop_key: str) -> str: |
| | return "half_life" if prop_key == "halflife" else prop_key |
| |
|
def _resolve_expected_model_dir(prop_key: str, model_name: str, mode: str) -> str:
    """Return the repo-relative directory expected to hold one model's weights.

    Directories live under ``training_classifiers/<disk property>``, where the
    property key is first mapped through ``_norm_prop_disk``.

    Args:
        prop_key: Property key as it appears in the manifest.
        model_name: Canonical model name. For ``binding_affinity`` this is the
            pooled/unpooled variant rather than a model family.
        mode: Input type, ``"wt"`` or ``"smiles"``.

    Returns:
        The directory path as a forward-slash string (no trailing slash).
    """
    root = f"training_classifiers/{_norm_prop_disk(prop_key)}"

    # Binding affinity directories encode the input mode plus the
    # pooled/unpooled variant carried in model_name.
    if prop_key == "binding_affinity":
        return f"{root}/wt_{mode}_{model_name}"

    # Half-life uses a few irregular directory names.
    if prop_key == "halflife":
        if model_name in {"xgb_wt_log", "xgb_smiles"}:
            return f"{root}/{model_name}"
        if model_name == "transformer" and mode == "wt":
            return f"{root}/transformer_wt_log"
        if model_name == "xgb":
            leaf = "xgb_wt_log" if mode == "wt" else "xgb_smiles"
            return f"{root}/{leaf}"

    # Regular layout: <model>_<mode>.
    return f"{root}/{model_name}_{mode}"
| |
|
| |
|
def build_allow_patterns_from_manifest(manifest_path: Path) -> List[str]:
    """Build the ``allow_patterns`` glob list for the models named in a manifest.

    For every property in the manifest, the best wt-mode and smiles-mode model
    labels are canonicalized with ``canon_model``. Unrecognized labels (for which
    ``canon_model`` returns ``None``) are skipped, and models in ``BANNED_MODELS``
    are replaced by ``"xgb"``. Each remaining model contributes glob patterns for
    its ``best_model`` weight/config files.

    Args:
        manifest_path: Path to the manifest read by ``read_best_manifest_csv``.

    Returns:
        Unique glob patterns, in first-seen order.
    """
    weight_files = (
        "best_model.json",
        "best_model.pt",
        "best_model*.joblib",
        "best_model*.json",
    )

    patterns: List[str] = []
    for prop_key, row in read_best_manifest_csv(manifest_path).items():
        for mode, label in (("wt", row.best_wt), ("smiles", row.best_smiles)):
            model = canon_model(label)
            if model is None:
                continue
            if model in BANNED_MODELS:
                model = "xgb"
            model_dir = _resolve_expected_model_dir(prop_key, model, mode)
            patterns.extend(f"{model_dir}/{fname}" for fname in weight_files)

    # Order-preserving de-duplication: dict keys keep insertion order.
    return list(dict.fromkeys(patterns))
| |
|
| |
|
def download_assets(
    repo_id: str,
    manifest_path: Path,
    out_dir: Path,
) -> Path:
    """Download only the manifest-selected model files from a Hugging Face repo.

    The output directory is resolved to an absolute path and created (along
    with any parents) before the manifest is read.

    Args:
        repo_id: Hugging Face Hub repository id.
        manifest_path: Manifest passed to ``build_allow_patterns_from_manifest``.
        out_dir: Local directory that receives the files.

    Returns:
        The resolved absolute ``out_dir``.
    """
    target = out_dir.resolve()
    target.mkdir(parents=True, exist_ok=True)

    # Restrict the snapshot to the patterns derived from the manifest; files
    # are copied into target rather than symlinked into the HF cache.
    snapshot_download(
        repo_id=repo_id,
        local_dir=str(target),
        local_dir_use_symlinks=False,
        allow_patterns=build_allow_patterns_from_manifest(manifest_path),
    )
    return target
| |
|
| |
|
| | |
| | |
| | |
def main():
    """Command-line entry point.

    Always downloads the model files selected by ``--manifest`` from ``--repo``
    into ``--assets``. With ``--predict``, it also loads a ``PeptiVersePredictor``
    from those assets and prints one prediction. Without ``--predict``, it only
    downloads.

    Raises:
        FileNotFoundError: if the manifest file does not exist.
    """
    import argparse

    ap = argparse.ArgumentParser(description="Lightweight PeptiVerse inference with on-demand model download.")
    ap.add_argument("--repo", default=MODEL_REPO, help="HF repo id containing weights/assets.")
    # The help text previously named "best_models.txt", which disagrees with the
    # actual default; show the real default instead.
    ap.add_argument(
        "--manifest",
        default=str(DEFAULT_MANIFEST),
        help="Path to the best-model manifest (default: %(default)s)",
    )
    ap.add_argument("--assets", default=str(DEFAULT_ASSETS_DIR), help="Where to store downloaded assets")
    ap.add_argument("--device", default=None, help="cuda / cpu / cuda:0, etc")

    ap.add_argument("--property", default="hemolysis", help="Property key (e.g. hemolysis, solubility, ...)")
    ap.add_argument("--mode", default="wt", choices=["wt", "smiles"], help="Input type: wt=AA sequence, smiles=SMILES")
    ap.add_argument("--input", default="GIGAVLKVLTTGLPALISWIKRKRQQ", help="Sequence or SMILES string")
    ap.add_argument("--target_seq", default=None, help="Target WT sequence for binding_affinity")
    ap.add_argument("--binder", default=None, help="Binder string (AA or SMILES) for binding_affinity")
    # Opt-in replacement for the former string-literal "OPTIONAL TEST CODE":
    # the default behavior stays download-only.
    ap.add_argument(
        "--predict",
        action="store_true",
        help="After downloading, run one prediction with the downloaded models and print it",
    )
    args = ap.parse_args()

    # Validate prediction arguments up front so a bad invocation fails before
    # any download happens.
    if args.predict and args.property == "binding_affinity" and not (args.target_seq and args.binder):
        ap.error("For binding_affinity, provide --target_seq and --binder.")

    manifest_path = Path(args.manifest)
    if not manifest_path.exists():
        raise FileNotFoundError(f"Manifest not found: {manifest_path}")

    assets_dir = download_assets(args.repo, manifest_path=manifest_path, out_dir=Path(args.assets))

    if not args.predict:
        return

    # Use the same manifest that drove the download, so the predictor looks for
    # exactly the models that were fetched (previously hard-coded to
    # "basic_models.txt", ignoring --manifest).
    predictor = PeptiVersePredictor(
        manifest_path=str(manifest_path),
        classifier_weight_root=str(assets_dir),
        device=args.device,
    )

    if args.property == "binding_affinity":
        out = predictor.predict_binding_affinity(args.mode, target_seq=args.target_seq, binder_str=args.binder)
    else:
        out = predictor.predict_property(args.property, args.mode, args.input)

    print(out)
| |
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
| |
|