File size: 3,214 Bytes
b47954d
189f8d9
b47954d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189f8d9
 
 
 
 
b47954d
 
 
189f8d9
b47954d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""PETIMOT inference utilities for custom proteins."""
import os, sys
import numpy as np
from pathlib import Path

# Ensure PETIMOT is importable: take the directory three levels above this
# file (presumably the repository root -- verify against the package layout)
# and put it at the front of sys.path so the `petimot.*` imports resolve to
# this checkout. PETIMOT_ROOT is also the base of the default config path.
PETIMOT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if PETIMOT_ROOT not in sys.path:
    # Only insert once, so repeated imports/reloads do not grow sys.path.
    sys.path.insert(0, PETIMOT_ROOT)

# Presumably the embedding dimension of each supported embedding model, keyed
# by model name. Not referenced in this part of the file -- likely consumed by
# callers elsewhere; confirm before changing values.
EMBEDDING_DIM_MAP = {"prostt5": 1024, "esmc_300m": 960, "esmc_600m": 1152}


def _load_predicted_modes(pred_dir: str, basename: str, max_modes: int) -> dict:
    """Load ``<prefix>_mode_<k>.txt`` files from *pred_dir* for k < max_modes.

    For each mode index the prefix ``extracted_<basename>`` is tried before the
    plain ``<basename>``; the first existing file wins. Indices with no file are
    simply absent from the returned dict.
    """
    prefixes = (f"extracted_{basename}", basename)
    modes = {}
    for k in range(max_modes):
        for prefix in prefixes:
            mode_file = os.path.join(pred_dir, f"{prefix}_mode_{k}.txt")
            if os.path.exists(mode_file):
                modes[k] = np.loadtxt(mode_file)
                break
    return modes


def run_inference(pdb_path: str, weights_path: str, config_path: str | None = None,
                  output_dir: str = "/tmp/petimot_pred", *, max_modes: int = 10) -> dict:
    """Run PETIMOT inference on a single PDB file and collect the results.

    Args:
        pdb_path: Path to the input PDB file.
        weights_path: Path to the model weights (.pt). Its file stem names the
            sub-directory of ``output_dir`` in which predictions are looked up.
        config_path: Path to the config YAML; defaults to
            ``<PETIMOT_ROOT>/configs/default.yaml``.
        output_dir: Output directory handed to ``infer``; created if missing.
        max_modes: Number of mode indices (0 .. max_modes-1) to look for among
            the prediction files. Defaults to 10.

    Returns:
        dict with keys:
            ``name``: stem of ``pdb_path``.
            ``ca_coords``: NumPy array from ``bb_data["bb"][:, 1]`` (assumed to
                be the CA atom of an N-CA-C... backbone ordering).
            ``seq``: sequence string from the loader, or ``"X" * n_res`` when
                the loader provides none / a non-string value.
            ``modes``: mode index -> array read with ``np.loadtxt``; indices
                with no prediction file are absent.
            ``pdb_text``: raw contents of ``pdb_path``.
            ``pred_dir``: directory searched for prediction files.
            ``n_res``: number of residues (length of ``ca_coords``).

    Raises:
        ImportError: if PyTorch is not installed.
    """
    try:
        import torch  # noqa: F401  -- imported only to fail early with a clear message
    except ImportError as err:
        raise ImportError(
            "PyTorch is required to run inference. Install it with: pip install torch"
        ) from err

    from petimot.infer.infer import infer
    from petimot.data.pdb_utils import load_backbone_coordinates

    if config_path is None:
        config_path = os.path.join(PETIMOT_ROOT, "configs", "default.yaml")

    os.makedirs(output_dir, exist_ok=True)

    infer(model_path=weights_path, config_file=config_path,
          input_list=[pdb_path], output_path=output_dir)

    # Prediction files are looked up under <output_dir>/<weights file stem>/.
    stem = os.path.splitext(os.path.basename(weights_path))[0]
    pred_subdir = os.path.join(output_dir, stem)
    basename = os.path.splitext(os.path.basename(pdb_path))[0]

    # Structure: column 1 of the backbone array is used as CA (assumed N, CA, C
    # ordering); `.numpy()` presumes a CPU torch tensor -- TODO confirm.
    bb_data = load_backbone_coordinates(pdb_path, allow_hetatm=True)
    ca = bb_data["bb"][:, 1].numpy()
    seq = bb_data.get("seq")
    if not isinstance(seq, str):
        # Missing or non-string sequence: fall back to unknown residues.
        seq = "X" * len(ca)

    modes = _load_predicted_modes(pred_subdir, basename, max_modes)

    with open(pdb_path) as f:
        pdb_text = f.read()

    return {
        "name": basename,
        "ca_coords": ca,
        "seq": seq,
        "modes": modes,
        "pdb_text": pdb_text,
        "pred_dir": pred_subdir,
        "n_res": len(ca),
    }


def download_pdb(pdb_id: str, output_dir: str = "/tmp/petimot_pdbs") -> str | None:
    """Download PDB from RCSB."""
    import requests

    os.makedirs(output_dir, exist_ok=True)
    code4 = pdb_id[:4].lower()
    chain = pdb_id[4:].upper() if len(pdb_id) > 4 else ""
    out_path = os.path.join(output_dir, f"{pdb_id}.pdb")

    if os.path.exists(out_path):
        return out_path

    r = requests.get(f"https://files.rcsb.org/download/{code4}.pdb", timeout=30)
    if not r.ok:
        return None

    lines = r.text.split("\n")
    if chain:
        lines = [l for l in lines
                 if (l.startswith("ATOM") and len(l) > 21 and l[21] == chain)
                 or not l.startswith(("ATOM", "HETATM"))]

    with open(out_path, "w") as f:
        f.write("\n".join(lines))
    return out_path