"""PETIMOT inference utilities for custom proteins.""" import os, sys import numpy as np from pathlib import Path # Ensure PETIMOT is importable PETIMOT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) if PETIMOT_ROOT not in sys.path: sys.path.insert(0, PETIMOT_ROOT) EMBEDDING_DIM_MAP = {"prostt5": 1024, "esmc_300m": 960, "esmc_600m": 1152} def run_inference(pdb_path: str, weights_path: str, config_path: str = None, output_dir: str = "/tmp/petimot_pred") -> dict: """Run PETIMOT inference on a single PDB file. Args: pdb_path: Path to input PDB file weights_path: Path to model weights .pt config_path: Path to config YAML (default: configs/default.yaml) output_dir: Where to save predictions Returns: dict with modes, ca_coords, seq, etc. """ try: import torch except ImportError: raise ImportError("PyTorch is required to run inference. Install it with: pip install torch") from petimot.infer.infer import infer from petimot.data.pdb_utils import load_backbone_coordinates if config_path is None: config_path = os.path.join(PETIMOT_ROOT, "configs", "default.yaml") os.makedirs(output_dir, exist_ok=True) # Run inference infer(model_path=weights_path, config_file=config_path, input_list=[pdb_path], output_path=output_dir) # Collect results stem = os.path.splitext(os.path.basename(weights_path))[0] pred_subdir = os.path.join(output_dir, stem) basename = os.path.splitext(os.path.basename(pdb_path))[0] # Load structure bb_data = load_backbone_coordinates(pdb_path, allow_hetatm=True) ca = bb_data["bb"][:, 1].numpy() seq = bb_data.get("seq", "X" * len(ca)) if not isinstance(seq, str): seq = "X" * len(ca) # Load predicted modes modes = {} for k in range(10): for pfx in [f"extracted_{basename}", basename]: mf = os.path.join(pred_subdir, f"{pfx}_mode_{k}.txt") if os.path.exists(mf): modes[k] = np.loadtxt(mf) break with open(pdb_path) as f: pdb_text = f.read() return { "name": basename, "ca_coords": ca, "seq": seq, "modes": modes, "pdb_text": pdb_text, "pred_dir": pred_subdir, "n_res": len(ca), } def download_pdb(pdb_id: str, output_dir: str = "/tmp/petimot_pdbs") -> str | None: """Download PDB from RCSB.""" import requests os.makedirs(output_dir, exist_ok=True) code4 = pdb_id[:4].lower() chain = pdb_id[4:].upper() if len(pdb_id) > 4 else "" out_path = os.path.join(output_dir, f"{pdb_id}.pdb") if os.path.exists(out_path): return out_path r = requests.get(f"https://files.rcsb.org/download/{code4}.pdb", timeout=30) if not r.ok: return None lines = r.text.split("\n") if chain: lines = [l for l in lines if (l.startswith("ATOM") and len(l) > 21 and l[21] == chain) or not l.startswith(("ATOM", "HETATM"))] with open(out_path, "w") as f: f.write("\n".join(lines)) return out_path