Petimot / petimot /data /pdb_utils.py
Valmbd's picture
Initial commit
474aa21
from typing import Dict, Tuple, List
import torch
from pathlib import Path
import os
def load_backbone_coordinates(
pdb_path: str,
allow_hetatm: bool = False,
) -> Tuple[torch.Tensor, ...]:
THREE_TO_ONE = {
"ALA": "A",
"ARG": "R",
"ASN": "N",
"ASP": "D",
"CYS": "C",
"GLN": "Q",
"GLU": "E",
"GLY": "G",
"HIS": "H",
"ILE": "I",
"LEU": "L",
"LYS": "K",
"MET": "M",
"PHE": "F",
"PRO": "P",
"SER": "S",
"THR": "T",
"TRP": "W",
"TYR": "Y",
"VAL": "V",
"SEC": "U",
"PYL": "O",
"UNK": "X",
# Non-standard and modified amino acids
"ABA": "A",
"ALY": "K",
"BFD": "D",
"CAF": "C",
"CAS": "C",
"CGU": "E",
"CME": "C",
"CSD": "C",
"CSO": "C",
"CSS": "C",
"CSX": "C",
"CXM": "M",
"DAL": "A",
"DCY": "C",
"DHA": "S",
"DLE": "L",
"DSN": "S",
"FME": "M",
"HIC": "H",
"HYP": "P",
"IAS": "D",
"KCX": "K",
"LLP": "K",
"M3L": "K",
"MDO": "A",
"MEN": "N",
"MEQ": "Q",
"MHO": "M",
"MLE": "L",
"MLY": "K",
"MLZ": "K",
"MSE": "M",
"MVA": "V",
"NEP": "H",
"NLE": "L",
"OCS": "C",
"PCA": "E",
"PHD": "D",
"PTR": "Y",
"SAR": "G",
"SCH": "C",
"SCY": "C",
"SEP": "S",
"SMC": "C",
"SME": "M",
"SNC": "C",
"TPO": "T",
"TYS": "Y",
"YCM": "C",
}
residues, types, numbers = [], [], []
current = {"coords": [], "type": None, "num": None}
try:
with open(pdb_path, "r") as f:
for line in f:
if not (
line.startswith("ATOM")
or (allow_hetatm and line.startswith("HETATM"))
):
continue
atom = line[12:16].strip()
res_type = line[17:20].strip()
if res_type == "HOH" or atom not in ["N", "CA", "C", "O"]:
continue
coords = torch.tensor(
[float(line[30:38]), float(line[38:46]), float(line[46:54])]
)
current["coords"].append(coords)
if not current["type"]:
current["type"] = res_type
current["num"] = int(line[22:26])
if len(current["coords"]) == 4:
residues.append(torch.stack(current["coords"]))
types.append(current["type"])
numbers.append(current["num"])
current = {"coords": [], "type": None, "num": None}
except FileNotFoundError:
raise FileNotFoundError(f"PDB file not found: {pdb_path}")
except Exception as e:
raise ValueError(f"Error parsing PDB file: {e}")
if not residues:
raise ValueError("No valid backbone atoms found")
backbone = torch.stack(residues)
output_path = f"extracted_{Path(pdb_path).name}"
if not os.path.exists(output_path):
with open(output_path, "w") as f:
atom_num = 1
res_num = 1
for res_idx, residue in enumerate(residues):
res_type = types[res_idx]
for atom_idx, (atom_name, coords) in enumerate(
zip(["N", "CA", "C", "O"], residue)
):
# Standard PDB format
f.write(
f"ATOM {atom_num:5d} {atom_name:<3s} {res_type:3s} A{res_num:4d} "
f"{coords[0]:8.3f}{coords[1]:8.3f}{coords[2]:8.3f}"
f" 1.00 0.00 {atom_name[0]:>2s}\n"
)
atom_num += 1
res_num += 1
f.write("END\n")
outputs = {"bb": backbone}
seq = "".join(THREE_TO_ONE.get(t, "X") for t in types)
outputs["seq"] = seq
outputs["residue_types"] = types
outputs["residue_numbers"] = numbers
return outputs