File size: 6,780 Bytes
f34af6f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 | """Alignment helpers for overlaying generated backbone and folded structures."""
from __future__ import annotations

import io
import logging
import string
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, Iterable, List, Tuple

import numpy as np
from Bio.PDB import PDBIO, PDBParser
from Bio.PDB.Chain import Chain
from Bio.PDB.Structure import Structure

from utils import new_pdbUtils as du
LOGGER = logging.getLogger(__name__)
_CHAIN_POOL = string.ascii_uppercase + string.digits + string.ascii_lowercase
class AlignmentError(RuntimeError):
    """Signal that preparing an aligned-compare overlay could not be completed.

    Derives from ``RuntimeError``, so handlers written for ``RuntimeError``
    also catch it.
    """
@dataclass
class AlignedOverlayResult:
    """Summary of one aligned overlay built from a backbone/folded PDB pair."""

    # Path of the overlay PDB file that was written.
    output_path: str
    # Number of CA coordinate pairs used for the alignment.
    num_ca_pairs: int
    # RMSD over the paired CA coordinates before the transform was applied.
    rmsd_before: float
    # RMSD over the same pairs after the transform was applied.
    rmsd_after: float
    # Resolved path of the backbone input file.
    backbone_path: str
    # Resolved path of the folded input file.
    folded_path: str
def _resolve_existing_file(path_value: str, label: str) -> Path:
    """Turn ``path_value`` into an absolute path and confirm it is a regular file.

    The path is user-expanded (``~``) and resolved before the check.

    Args:
        path_value: Path string supplied by the caller.
        label: Human-readable name of the input, used in the error message.

    Returns:
        The expanded, resolved path.

    Raises:
        AlignmentError: If the resolved path is missing or is not a regular file.
    """
    resolved = Path(path_value).expanduser().resolve()
    if resolved.exists() and resolved.is_file():
        return resolved
    raise AlignmentError(f"{label} PDB does not exist: {resolved}")
def _extract_ca_coords(structure: Structure) -> np.ndarray:
    """Collect C-alpha coordinates from the first model of ``structure``.

    Chains and residues are visited in iteration order. Residues without an
    atom named ``CA`` are skipped. An atom named ``CA`` whose element is
    calcium is skipped as well: a calcium ion carries the same atom name as a
    C-alpha, and counting it would shift the order-based CA pairing done by
    the caller.

    Args:
        structure: Parsed structure exposing ``get_models()``.

    Returns:
        A float64 array with one row per retained CA atom. When the structure
        has no model, or no CA atom is found, an empty array of shape
        ``(0, 3)`` is returned.
    """
    model = next(structure.get_models(), None)
    if model is None:
        return np.zeros((0, 3), dtype=np.float64)

    ca_coords: List[np.ndarray] = []
    for chain in model:
        for residue in chain:
            if "CA" not in residue:
                continue
            atom = residue["CA"]
            # Atoms without element information are kept, matching the
            # previous behaviour; only an explicit calcium element is rejected.
            element = str(getattr(atom, "element", "") or "").strip().upper()
            if element == "CA":
                continue
            ca_coords.append(np.asarray(atom.coord, dtype=np.float64))

    if not ca_coords:
        return np.zeros((0, 3), dtype=np.float64)
    return np.stack(ca_coords, axis=0)
def _rmsd(a_coords: np.ndarray, b_coords: np.ndarray) -> float:
    """Return the root-mean-square deviation between two coordinate arrays.

    Squared differences are summed along the last axis, and the mean of
    those sums is square-rooted.

    Args:
        a_coords: First coordinate array.
        b_coords: Second coordinate array; must have the same shape.

    Returns:
        The RMSD as a Python ``float``.

    Raises:
        AlignmentError: If the two arrays differ in shape.
    """
    if a_coords.shape == b_coords.shape:
        diff = a_coords - b_coords
        squared_distances = np.sum(diff * diff, axis=-1)
        return float(np.sqrt(np.mean(squared_distances)))
    raise AlignmentError(f"RMSD shape mismatch: {a_coords.shape} vs {b_coords.shape}")
def _next_chain_id(used_ids: Iterable[str]) -> str:
    """Return the first identifier in ``_CHAIN_POOL`` that is not in ``used_ids``.

    Args:
        used_ids: Identifiers that must not be returned.

    Returns:
        The earliest pool entry that is still free.

    Raises:
        AlignmentError: If every identifier in the pool is already taken.
    """
    taken = set(used_ids)
    free_id = next((cid for cid in _CHAIN_POOL if cid not in taken), None)
    if free_id is None:
        raise AlignmentError("No available chain identifiers left for aligned overlay.")
    return free_id
def _reassign_folded_chain_ids(folded_structure: Structure, backbone_chain_ids: Iterable[str]) -> Dict[str, str]:
    """Rename folded chains so they do not clash with the backbone's chain ids.

    Only the first model of ``folded_structure`` is touched. A chain is renamed
    when its id is already used by the backbone or is not exactly one character
    long; every other chain keeps its id.

    Replacement ids avoid the backbone ids, ids already handed out, and the
    ids of every chain in the folded model. The last point matters: picking an
    id that a later folded chain still holds would produce a duplicate sibling
    id (Bio.PDB rejects such an assignment with ``ValueError``).

    Args:
        folded_structure: Structure whose chain ids are modified in place.
        backbone_chain_ids: Chain ids already occupied by the backbone.

    Returns:
        Mapping from each original folded chain id to its final id.

    Raises:
        AlignmentError: If the structure has no model, or no free chain
            identifier is left.
    """
    model = next(folded_structure.get_models(), None)
    if model is None:
        raise AlignmentError("Folded structure has no model entries.")

    # Snapshot the chains so renaming cannot disturb the iteration.
    chains = list(model)
    backbone_ids = set(backbone_chain_ids)
    # Everything a freshly assigned id must stay clear of.
    unavailable = backbone_ids | {str(chain.id) for chain in chains}

    remap: Dict[str, str] = {}
    for chain in chains:
        old_id = str(chain.id)
        if old_id in backbone_ids or len(old_id) != 1:
            new_id = _next_chain_id(unavailable)
            chain.id = new_id
            unavailable.add(new_id)
        else:
            new_id = old_id
        remap[old_id] = new_id
    return remap
def _serialize_structure(structure: Structure) -> str:
    """Render ``structure`` as text using Bio.PDB's ``PDBIO`` writer.

    The writer saves into an in-memory buffer, so nothing touches the disk.

    Args:
        structure: Structure handed to ``PDBIO.set_structure``.

    Returns:
        Everything the writer emitted, as a single string.
    """
    text_sink = io.StringIO()
    pdb_writer = PDBIO()
    pdb_writer.set_structure(structure)
    pdb_writer.save(text_sink)
    return text_sink.getvalue()
def _strip_pdb_terminal_lines(pdb_text: str) -> List[str]:
    """Split ``pdb_text`` into lines and drop every line that begins with ``END``.

    The check is a plain prefix match, so ``ENDMDL`` lines are removed along
    with ``END`` lines. Line terminators are not kept in the result.

    Args:
        pdb_text: PDB-format text.

    Returns:
        The remaining lines, in their original order.
    """
    kept: List[str] = []
    for record in pdb_text.splitlines():
        if record.startswith("END"):
            continue
        kept.append(record)
    return kept
def _apply_transform(structure: Structure, rotation: np.ndarray, translation: np.ndarray) -> None:
    """Move every atom of the first model by ``rotation @ coord + translation``.

    Each coordinate is treated as a float64 ``(3, 1)`` column vector. The
    result is flattened back to three values and stored on the atom as
    float32. Only the first model is modified; the update happens in place.

    Args:
        structure: Structure whose first model's atoms are updated.
        rotation: Matrix multiplied onto each column vector.
        translation: Offset added after the rotation; it has to combine with a
            ``(3, 1)`` column into a result of exactly three values.

    Raises:
        AlignmentError: If the structure contains no model.
    """
    first_model = next(structure.get_models(), None)
    if first_model is None:
        raise AlignmentError("Cannot transform structure without models.")

    for atom in first_model.get_atoms():
        column = np.asarray(atom.coord, dtype=np.float64).reshape(3, 1)
        moved = np.matmul(rotation, column) + translation
        atom.coord = moved.reshape(3).astype(np.float32)
def align_folded_to_backbone_overlay(
    backbone_sample_path: str,
    folded_sample_path: str,
) -> AlignedOverlayResult:
    """Align a folded structure onto a backbone and write both into one PDB file.

    Steps:
      1. Resolve and parse both PDB files.
      2. Pair CA coordinates by order, truncating to the shorter list.
      3. Compute a transform with ``du.rigid_transform_3D`` (presumably the
         best-fit rigid superposition -- confirm against that helper) and
         apply it to the folded structure.
      4. Rename folded chains that clash with backbone chain ids.
      5. Write ``<backbone dir>/viewer_alignment/<name>.pdb`` containing
         ``REMARK 900`` metadata, the backbone records, a ``TER`` line, the
         transformed folded records, and a final ``END`` line.

    Args:
        backbone_sample_path: Path to the backbone PDB file.
        folded_sample_path: Path to the folded PDB file.

    Returns:
        An ``AlignedOverlayResult`` with the overlay path, the number of CA
        pairs used, the RMSD before and after alignment, and the resolved
        input paths.

    Raises:
        AlignmentError: If an input file is missing, fewer than three CA pairs
            are available, or the backbone structure has no model.
    """
    backbone_path = _resolve_existing_file(backbone_sample_path, label="Backbone")
    folded_path = _resolve_existing_file(folded_sample_path, label="Folded")

    parser = PDBParser(QUIET=True)
    backbone_structure = parser.get_structure("backbone", str(backbone_path))
    folded_structure = parser.get_structure("folded", str(folded_path))

    backbone_ca = _extract_ca_coords(backbone_structure)
    folded_ca = _extract_ca_coords(folded_structure)
    backbone_count = backbone_ca.shape[0]
    folded_count = folded_ca.shape[0]
    pair_count = min(backbone_count, folded_count)
    if pair_count < 3:
        raise AlignmentError(
            "Need at least 3 CA pairs for alignment. "
            f"backbone={backbone_count}, folded={folded_count}."
        )
    if backbone_count != folded_count:
        # Pairing is purely positional, so a length mismatch means the tail of
        # the longer structure is ignored; make that visible in the logs.
        LOGGER.warning(
            "CA count mismatch (backbone=%d, folded=%d); aligning on the first %d pairs.",
            backbone_count,
            folded_count,
            pair_count,
        )

    backbone_subset = backbone_ca[:pair_count]
    folded_subset = folded_ca[:pair_count]
    rmsd_before = _rmsd(folded_subset, backbone_subset)
    aligned_subset, rotation, translation, _ = du.rigid_transform_3D(folded_subset, backbone_subset)
    rmsd_after = _rmsd(aligned_subset, backbone_subset)
    _apply_transform(folded_structure, rotation=rotation, translation=translation)

    backbone_model = next(backbone_structure.get_models(), None)
    if backbone_model is None:
        raise AlignmentError("Backbone structure has no model entries.")
    _reassign_folded_chain_ids(
        folded_structure=folded_structure,
        backbone_chain_ids=[str(chain.id) for chain in backbone_model],
    )

    overlay_dir = backbone_path.parent / "viewer_alignment"
    overlay_dir.mkdir(parents=True, exist_ok=True)
    # Timezone-aware replacement for the deprecated ``datetime.utcnow()``;
    # the formatted UTC timestamp is unchanged.
    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    overlay_name = f"{backbone_path.stem}__vs__{folded_path.stem}__aligned_{timestamp}.pdb"
    overlay_path = overlay_dir / overlay_name

    # Each serialized structure has its END-prefixed lines removed so the
    # combined file can close with the single END line added below.
    backbone_lines = _strip_pdb_terminal_lines(_serialize_structure(backbone_structure))
    folded_lines = _strip_pdb_terminal_lines(_serialize_structure(folded_structure))
    output_lines = [
        "REMARK 900 FlowProt aligned compare overlay",
        f"REMARK 900 backbone_source={backbone_path.as_posix()}",
        f"REMARK 900 folded_source={folded_path.as_posix()}",
        f"REMARK 900 ca_pairs={pair_count}",
        f"REMARK 900 rmsd_before={rmsd_before:.6f}",
        f"REMARK 900 rmsd_after={rmsd_after:.6f}",
        *backbone_lines,
        "TER",
        *folded_lines,
        "END",
    ]
    overlay_path.write_text("\n".join(output_lines) + "\n", encoding="utf-8")
    LOGGER.info("Created aligned overlay file at %s", overlay_path)

    return AlignedOverlayResult(
        output_path=overlay_path.as_posix(),
        num_ca_pairs=pair_count,
        rmsd_before=rmsd_before,
        rmsd_after=rmsd_after,
        backbone_path=backbone_path.as_posix(),
        folded_path=folded_path.as_posix(),
    )
|