Petimot / app /utils /inference.py
Valmbd's picture
Fix: lazy torch import in inference.py — page loads without torch installed
189f8d9 verified
"""PETIMOT inference utilities for custom proteins."""
import os, sys
import numpy as np
from pathlib import Path
# Repository root, two directories above the one containing this file; it is
# put on sys.path so that the `petimot` package can be imported from here.
PETIMOT_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..")
)
if PETIMOT_ROOT not in sys.path:
    sys.path.insert(0, PETIMOT_ROOT)

# Integer value associated with each supported embedding model name.
EMBEDDING_DIM_MAP = {
    "prostt5": 1024,
    "esmc_300m": 960,
    "esmc_600m": 1152,
}
def run_inference(pdb_path: str, weights_path: str, config_path: str = None,
output_dir: str = "/tmp/petimot_pred") -> dict:
"""Run PETIMOT inference on a single PDB file.
Args:
pdb_path: Path to input PDB file
weights_path: Path to model weights .pt
config_path: Path to config YAML (default: configs/default.yaml)
output_dir: Where to save predictions
Returns:
dict with modes, ca_coords, seq, etc.
"""
try:
import torch
except ImportError:
raise ImportError("PyTorch is required to run inference. Install it with: pip install torch")
from petimot.infer.infer import infer
from petimot.data.pdb_utils import load_backbone_coordinates
if config_path is None:
config_path = os.path.join(PETIMOT_ROOT, "configs", "default.yaml")
os.makedirs(output_dir, exist_ok=True)
# Run inference
infer(model_path=weights_path, config_file=config_path,
input_list=[pdb_path], output_path=output_dir)
# Collect results
stem = os.path.splitext(os.path.basename(weights_path))[0]
pred_subdir = os.path.join(output_dir, stem)
basename = os.path.splitext(os.path.basename(pdb_path))[0]
# Load structure
bb_data = load_backbone_coordinates(pdb_path, allow_hetatm=True)
ca = bb_data["bb"][:, 1].numpy()
seq = bb_data.get("seq", "X" * len(ca))
if not isinstance(seq, str):
seq = "X" * len(ca)
# Load predicted modes
modes = {}
for k in range(10):
for pfx in [f"extracted_{basename}", basename]:
mf = os.path.join(pred_subdir, f"{pfx}_mode_{k}.txt")
if os.path.exists(mf):
modes[k] = np.loadtxt(mf)
break
with open(pdb_path) as f:
pdb_text = f.read()
return {
"name": basename,
"ca_coords": ca,
"seq": seq,
"modes": modes,
"pdb_text": pdb_text,
"pred_dir": pred_subdir,
"n_res": len(ca),
}
def download_pdb(pdb_id: str, output_dir: str = "/tmp/petimot_pdbs") -> str | None:
"""Download PDB from RCSB."""
import requests
os.makedirs(output_dir, exist_ok=True)
code4 = pdb_id[:4].lower()
chain = pdb_id[4:].upper() if len(pdb_id) > 4 else ""
out_path = os.path.join(output_dir, f"{pdb_id}.pdb")
if os.path.exists(out_path):
return out_path
r = requests.get(f"https://files.rcsb.org/download/{code4}.pdb", timeout=30)
if not r.ok:
return None
lines = r.text.split("\n")
if chain:
lines = [l for l in lines
if (l.startswith("ATOM") and len(l) > 21 and l[21] == chain)
or not l.startswith(("ATOM", "HETATM"))]
with open(out_path, "w") as f:
f.write("\n".join(lines))
return out_path