| from typing import Dict, Tuple, List |
| import torch |
| from pathlib import Path |
| import os |
|
|
|
|
# Backbone atom names, in the order they are stored per residue in the output tensor.
_BACKBONE_ATOMS = ("N", "CA", "C", "O")

# Three-letter residue code -> one-letter code. Standard amino acids first,
# followed by common modified residues mapped to their parent amino acid.
_THREE_TO_ONE = {
    "ALA": "A",
    "ARG": "R",
    "ASN": "N",
    "ASP": "D",
    "CYS": "C",
    "GLN": "Q",
    "GLU": "E",
    "GLY": "G",
    "HIS": "H",
    "ILE": "I",
    "LEU": "L",
    "LYS": "K",
    "MET": "M",
    "PHE": "F",
    "PRO": "P",
    "SER": "S",
    "THR": "T",
    "TRP": "W",
    "TYR": "Y",
    "VAL": "V",
    "SEC": "U",
    "PYL": "O",
    "UNK": "X",
    # Modified residues.
    "ABA": "A",
    "ALY": "K",
    "BFD": "D",
    "CAF": "C",
    "CAS": "C",
    "CGU": "E",
    "CME": "C",
    "CSD": "C",
    "CSO": "C",
    "CSS": "C",
    "CSX": "C",
    "CXM": "M",
    "DAL": "A",
    "DCY": "C",
    "DHA": "S",
    "DLE": "L",
    "DSN": "S",
    "FME": "M",
    "HIC": "H",
    "HYP": "P",
    "IAS": "D",
    "KCX": "K",
    "LLP": "K",
    "M3L": "K",
    "MDO": "A",
    "MEN": "N",
    "MEQ": "Q",
    "MHO": "M",
    "MLE": "L",
    "MLY": "K",
    "MLZ": "K",
    "MSE": "M",
    "MVA": "V",
    "NEP": "H",
    "NLE": "L",
    "OCS": "C",
    "PCA": "E",
    "PHD": "D",
    "PTR": "Y",
    "SAR": "G",
    "SCH": "C",
    "SCY": "C",
    "SEP": "S",
    "SMC": "C",
    "SME": "M",
    "SNC": "C",
    "TPO": "T",
    "TYS": "Y",
    "YCM": "C",
}


def _read_backbone_residues(
    pdb_path: str, allow_hetatm: bool
) -> List[Tuple[str, int, Dict[str, torch.Tensor]]]:
    """Collect backbone atoms per residue from a PDB file.

    Returns ``(res_type, res_num, atoms)`` tuples in order of first appearance,
    where ``atoms`` maps each of N/CA/C/O to a 3-element coordinate tensor.
    Only residues that have all four backbone atoms are returned.
    """
    records = ("ATOM", "HETATM") if allow_hetatm else ("ATOM",)
    # Keyed by (chain, residue number, insertion code) so atoms are grouped by
    # the residue they belong to, not by their position in the file.
    residues: Dict[Tuple[str, int, str], Tuple[str, int, Dict[str, torch.Tensor]]] = {}

    with open(pdb_path, "r") as f:
        for line in f:
            # Only the first model of a multi-model (e.g. NMR) file is used.
            if line.startswith("ENDMDL"):
                break
            if not line.startswith(records):
                continue

            atom = line[12:16].strip()
            res_type = line[17:20].strip()
            if res_type == "HOH" or atom not in _BACKBONE_ATOMS:
                continue

            res_num = int(line[22:26])
            key = (line[21:22], res_num, line[26:27])
            entry = residues.get(key)
            if entry is None:
                entry = residues[key] = (res_type, res_num, {})

            atoms = entry[2]
            # Alternate locations repeat an atom name: keep the first one seen.
            if atom in atoms:
                continue
            atoms[atom] = torch.tensor(
                [float(line[30:38]), float(line[38:46]), float(line[46:54])]
            )

    return [
        entry for entry in residues.values() if len(entry[2]) == len(_BACKBONE_ATOMS)
    ]


def _write_backbone_pdb(
    output_path: str, backbone: torch.Tensor, types: List[str]
) -> None:
    """Write backbone atoms as fixed-column PDB ATOM records (chain A, residues renumbered from 1)."""
    lines = []
    serial = 1
    for res_num, (res_type, residue) in enumerate(zip(types, backbone), start=1):
        for atom_name, (x, y, z) in zip(_BACKBONE_ATOMS, residue.tolist()):
            # Columns: record 1-6, serial 7-11, name 14-16, resName 18-20,
            # chain 22, resSeq 23-26, x/y/z 31-54, occupancy 55-60,
            # B-factor 61-66, element 77-78.
            lines.append(
                f"ATOM  {serial:5d}  {atom_name:<3s} {res_type:>3s} A{res_num:4d}    "
                f"{x:8.3f}{y:8.3f}{z:8.3f}{1.0:6.2f}{0.0:6.2f}          "
                f"{atom_name[0]:>2s}\n"
            )
            serial += 1
    lines.append("END\n")
    with open(output_path, "w") as f:
        f.writelines(lines)


def load_backbone_coordinates(
    pdb_path: str,
    allow_hetatm: bool = False,
    *,
    write_extracted: bool = True,
) -> Dict[str, object]:
    """Load N/CA/C/O backbone coordinates from a PDB file.

    Atoms are grouped by (chain, residue number, insertion code). Residues
    missing any backbone atom are skipped. For alternate locations the first
    occurrence of each atom wins. Only the first model is read, and water
    (HOH) is ignored.

    Args:
        pdb_path: Path to the PDB file.
        allow_hetatm: Also read HETATM records (e.g. MSE, modified residues).
        write_extracted: If True, write the parsed backbone to
            ``extracted_<file name>`` in the current working directory, unless
            that file already exists.

    Returns:
        A dict with these keys:

        - ``"bb"``: float tensor of shape ``(num_residues, 4, 3)`` (N, CA, C, O).
        - ``"seq"``: one-letter sequence, with ``"X"`` for unknown residue codes.
        - ``"residue_types"``: three-letter residue names.
        - ``"residue_numbers"``: residue numbers as read from the file.

    Raises:
        FileNotFoundError: If ``pdb_path`` does not exist.
        ValueError: If the file cannot be parsed, or contains no residue with
            a complete backbone.
    """
    try:
        parsed = _read_backbone_residues(pdb_path, allow_hetatm)
    except FileNotFoundError as e:
        raise FileNotFoundError(f"PDB file not found: {pdb_path}") from e
    except Exception as e:
        raise ValueError(f"Error parsing PDB file: {e}") from e

    if not parsed:
        raise ValueError("No valid backbone atoms found")

    types = [res_type for res_type, _, _ in parsed]
    numbers = [res_num for _, res_num, _ in parsed]
    backbone = torch.stack(
        [torch.stack([atoms[name] for name in _BACKBONE_ATOMS]) for _, _, atoms in parsed]
    )

    if write_extracted:
        output_path = f"extracted_{Path(pdb_path).name}"
        # Never overwrite an existing extracted file.
        if not os.path.exists(output_path):
            _write_backbone_pdb(output_path, backbone, types)

    return {
        "bb": backbone,
        "seq": "".join(_THREE_TO_ONE.get(t, "X") for t in types),
        "residue_types": types,
        "residue_numbers": numbers,
    }
|
|