Joblib
File size: 5,055 Bytes
e66c5e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
#!/usr/bin/env python3
from __future__ import annotations

import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple

from huggingface_hub import snapshot_download
from inference import (
    PeptiVersePredictor,
    read_best_manifest_csv,
    canon_model,
)

# -----------------------------
# Config
# -----------------------------
MODEL_REPO = "ChatterjeeLab/PeptiVerse"  # Hugging Face Hub repo id holding the model weights
DEFAULT_ASSETS_DIR = Path("./")   # where downloaded models live
DEFAULT_MANIFEST = Path("./basic_models.txt")  # default best-models manifest consumed by read_best_manifest_csv

# Model labels that are never downloaded; when the manifest names one of
# these, the "xgb" fallback is fetched in its place.
BANNED_MODELS = {"svm", "enet", "svm_gpu", "enet_gpu"}


def _norm_prop_disk(prop_key: str) -> str:
    return "half_life" if prop_key == "halflife" else prop_key

def _resolve_expected_model_dir(prop_key: str, model_name: str, mode: str) -> str:
    disk_prop = _norm_prop_disk(prop_key)
    base = f"training_classifiers/{disk_prop}"

    # binding affinity is special: its label is pooled/unpooled and folder uses wt_<mode>_<pooled|unpooled>
    if prop_key == "binding_affinity":
        pooled_or_unpooled = model_name  # "pooled" or "unpooled"
        return f"{base}/wt_{mode}_{pooled_or_unpooled}"

    # halflife special folders
    if prop_key == "halflife":
        if model_name in {"xgb_wt_log", "xgb_smiles"}:
            return f"{base}/{model_name}"
        if mode == "wt" and model_name == "transformer":
            return f"{base}/transformer_wt_log"
        if model_name == "xgb":
            return f"{base}/{'xgb_wt_log' if mode == 'wt' else 'xgb_smiles'}"

    return f"{base}/{model_name}_{mode}"


def build_allow_patterns_from_manifest(manifest_path: Path) -> List[str]:
    """Derive Hub ``allow_patterns`` for every best model in the manifest.

    For each property, both the best WT and best SMILES model are looked
    up, banned model families are swapped for the "xgb" fallback, and only
    the "basic" artifact files of each model folder are requested.
    Returns patterns deduplicated in first-seen order.
    """
    best = read_best_manifest_csv(manifest_path)

    # Dict keys double as an order-preserving set.
    patterns: Dict[str, None] = {}

    for prop_key, row in best.items():
        for mode, label in (("wt", row.best_wt), ("smiles", row.best_smiles)):
            model = canon_model(label)
            if model is None:
                continue

            # Banned families fall back to the xgb artifacts.
            if model in BANNED_MODELS:
                model = "xgb"

            model_dir = _resolve_expected_model_dir(prop_key, model, mode)

            # Fetch only the "basic" artifacts, not everything in the folder.
            for artifact in (
                "best_model.json",
                "best_model.pt",
                "best_model*.joblib",
                "best_model*.json",
            ):
                patterns.setdefault(f"{model_dir}/{artifact}", None)

    return list(patterns)


def download_assets(
    repo_id: str,
    manifest_path: Path,
    out_dir: Path,
) -> Path:
    """Download the manifest's "basic" model artifacts into *out_dir*.

    The manifest is turned into Hub ``allow_patterns`` so only the best
    per-property artifacts are fetched. Returns the resolved output dir.
    """
    target = out_dir.resolve()
    target.mkdir(parents=True, exist_ok=True)

    snapshot_download(
        repo_id=repo_id,
        local_dir=str(target),
        # Materialize real files instead of cache symlinks.
        local_dir_use_symlinks=False,
        allow_patterns=build_allow_patterns_from_manifest(manifest_path),
    )
    return target


# -----------------------------
# Main
# -----------------------------
def main():
    """CLI entry point: download the best model assets for inference.

    Parses the command line, validates that the manifest exists, and
    fetches the manifest's "basic" artifacts from the Hub into the
    assets directory. The prediction step itself is kept as optional
    test code below.
    """
    import argparse

    ap = argparse.ArgumentParser(description="Lightweight PeptiVerse inference with on-demand model download.")
    ap.add_argument("--repo", default=MODEL_REPO, help="HF repo id containing weights/assets.")
    # NOTE: default is basic_models.txt; the old help text wrongly said best_models.txt.
    ap.add_argument("--manifest", default=str(DEFAULT_MANIFEST), help="Path to the best-models manifest (default: basic_models.txt)")
    ap.add_argument("--assets", default=str(DEFAULT_ASSETS_DIR), help="Where to store downloaded assets")
    ap.add_argument("--device", default=None, help="cuda / cpu / cuda:0, etc")

    ap.add_argument("--property", default="hemolysis", help="Property key (e.g. hemolysis, solubility, ...)")
    ap.add_argument("--mode", default="wt", choices=["wt", "smiles"], help="Input type: wt=AA sequence, smiles=SMILES")
    ap.add_argument("--input", default="GIGAVLKVLTTGLPALISWIKRKRQQ", help="Sequence or SMILES string")
    ap.add_argument("--target_seq", default=None, help="Target WT sequence for binding_affinity")
    ap.add_argument("--binder", default=None, help="Binder string (AA or SMILES) for binding_affinity")
    args = ap.parse_args()

    # Fail fast with a clear error before hitting the network.
    manifest_path = Path(args.manifest)
    if not manifest_path.exists():
        raise FileNotFoundError(f"Manifest not found: {manifest_path}")

    assets_dir = download_assets(args.repo, manifest_path=manifest_path, out_dir=Path(args.assets))

    """ OPTIONAL TEST CODE
    predictor = PeptiVersePredictor(
        manifest_path="basic_models.txt",  # use the downloaded copy to be consistent
        classifier_weight_root=str(assets_dir),
        device=args.device,
    )

    if args.property == "binding_affinity":
        if not args.target_seq or not args.binder:
            raise ValueError("For binding_affinity, provide --target_seq and --binder.")
        out = predictor.predict_binding_affinity(args.mode, target_seq=args.target_seq, binder_str=args.binder)
    else:
        out = predictor.predict_property(args.property, args.mode, args.input)

    print(out)
    """

if __name__ == "__main__":
    main()