| """Alignment helpers for overlaying generated backbone and folded structures."""
|
|
|
| from __future__ import annotations
|
|
|
| import io
|
| import logging
|
| import string
|
| from dataclasses import dataclass
|
| from datetime import datetime
|
| from pathlib import Path
|
| from typing import Dict, Iterable, List, Tuple
|
|
|
| import numpy as np
|
| from Bio.PDB import PDBIO, PDBParser
|
| from Bio.PDB.Chain import Chain
|
| from Bio.PDB.Structure import Structure
|
|
|
| from utils import new_pdbUtils as du
|
|
|
| LOGGER = logging.getLogger(__name__)
|
| _CHAIN_POOL = string.ascii_uppercase + string.digits + string.ascii_lowercase
|
|
|
|
|
class AlignmentError(RuntimeError):
    """Signal that building the aligned backbone/folded comparison could not proceed."""
|
|
|
|
|
@dataclass
class AlignedOverlayResult:
    """Outcome of one aligned-overlay run: where it was written and how well it fit."""

    # POSIX-style path of the overlay PDB that was written.
    output_path: str
    # Number of positionally paired C-alpha atoms used for the rigid fit.
    num_ca_pairs: int
    # C-alpha RMSD between the paired atoms before the fit.
    rmsd_before: float
    # C-alpha RMSD between the paired atoms after the fit.
    rmsd_after: float
    # POSIX-style resolved path of the backbone input.
    backbone_path: str
    # POSIX-style resolved path of the folded input.
    folded_path: str
|
|
|
|
|
| def _resolve_existing_file(path_value: str, label: str) -> Path:
|
| candidate = Path(path_value).expanduser().resolve()
|
| if not candidate.exists() or not candidate.is_file():
|
| raise AlignmentError(f"{label} PDB does not exist: {candidate}")
|
| return candidate
|
|
|
|
|
| def _extract_ca_coords(structure: Structure) -> np.ndarray:
|
| ca_coords: List[np.ndarray] = []
|
| model = next(structure.get_models(), None)
|
| if model is None:
|
| return np.zeros((0, 3), dtype=np.float64)
|
| for chain in model:
|
| for residue in chain:
|
| if "CA" not in residue:
|
| continue
|
| atom = residue["CA"]
|
| ca_coords.append(np.asarray(atom.coord, dtype=np.float64))
|
| if not ca_coords:
|
| return np.zeros((0, 3), dtype=np.float64)
|
| return np.stack(ca_coords, axis=0)
|
|
|
|
|
| def _rmsd(a_coords: np.ndarray, b_coords: np.ndarray) -> float:
|
| if a_coords.shape != b_coords.shape:
|
| raise AlignmentError(f"RMSD shape mismatch: {a_coords.shape} vs {b_coords.shape}")
|
| deltas = a_coords - b_coords
|
| return float(np.sqrt(np.mean(np.sum(deltas * deltas, axis=-1))))
|
|
|
|
|
def _next_chain_id(used_ids: Iterable[str]) -> str:
    """Return the first identifier in ``_CHAIN_POOL`` that is not in *used_ids*.

    Raises:
        AlignmentError: If every pooled identifier is already taken.
    """
    taken = set(used_ids)
    free_id = next((chain_id for chain_id in _CHAIN_POOL if chain_id not in taken), None)
    if free_id is None:
        raise AlignmentError("No available chain identifiers left for aligned overlay.")
    return free_id
|
|
|
|
|
def _reassign_folded_chain_ids(folded_structure: Structure, backbone_chain_ids: Iterable[str]) -> Dict[str, str]:
    """Rename folded chains in place so no id clashes with a backbone chain.

    A folded chain keeps its id unless that id is also a backbone chain id or
    is not exactly one character long. In those cases it receives the next
    free identifier from ``_next_chain_id``.

    Every original folded chain id is reserved before any renaming happens.
    Otherwise a freshly picked id could equal a folded chain that has not been
    visited yet. For example, with backbone ``[A]`` and folded ``[A, B]``,
    folded ``A`` would become ``B``. Biopython's ``Entity.id`` setter rejects
    an id already held by a sibling, and otherwise the overlay would contain
    duplicate chain ids.

    Args:
        folded_structure: Structure whose first model's chains are renamed in place.
        backbone_chain_ids: Chain ids that the folded chains must avoid.

    Returns:
        Mapping from each folded chain's original id to its (possibly unchanged) id.

    Raises:
        AlignmentError: If the folded structure has no model, or the identifier
            pool is exhausted.
    """
    model = next(folded_structure.get_models(), None)
    if model is None:
        raise AlignmentError("Folded structure has no model entries.")

    backbone_ids = set(backbone_chain_ids)
    chains = list(model)
    # Reserve backbone ids plus all current folded ids up front (see docstring).
    used_ids = backbone_ids | {str(chain.id) for chain in chains}

    remap: Dict[str, str] = {}
    for chain in chains:
        old_id = str(chain.id)
        new_id = old_id
        if old_id in backbone_ids or len(old_id) != 1:
            new_id = _next_chain_id(used_ids)
            chain.id = new_id
            used_ids.add(new_id)
        remap[old_id] = new_id
    return remap
|
|
|
|
|
def _serialize_structure(structure: Structure) -> str:
    """Render *structure* as PDB-format text using Biopython's ``PDBIO`` writer."""
    pdb_writer = PDBIO()
    pdb_writer.set_structure(structure)
    with io.StringIO() as sink:
        # Given a file-like object instead of a filename, PDBIO writes into it
        # and the text is read back before the buffer is closed.
        pdb_writer.save(sink)
        return sink.getvalue()
|
|
|
|
|
| def _strip_pdb_terminal_lines(pdb_text: str) -> List[str]:
|
| return [line for line in pdb_text.splitlines() if not line.startswith("END")]
|
|
|
|
|
| def _apply_transform(structure: Structure, rotation: np.ndarray, translation: np.ndarray) -> None:
|
| model = next(structure.get_models(), None)
|
| if model is None:
|
| raise AlignmentError("Cannot transform structure without models.")
|
| for atom in model.get_atoms():
|
| source = np.asarray(atom.coord, dtype=np.float64).reshape(3, 1)
|
| transformed = (rotation @ source + translation).reshape(3)
|
| atom.coord = transformed.astype(np.float32)
|
|
|
|
|
def align_folded_to_backbone_overlay(
    backbone_sample_path: str,
    folded_sample_path: str,
) -> AlignedOverlayResult:
    """Superimpose a folded structure onto a backbone sample and write an overlay PDB.

    C-alpha atoms are paired positionally. The first ``min(n_backbone, n_folded)``
    C-alphas of each structure are matched index by index, and trailing extras
    in the longer structure are ignored. The folded structure is rigidly fitted
    onto the backbone with ``du.rigid_transform_3D``. Its chains are then renamed
    so they do not clash with backbone chain ids. Both structures are written to
    a single timestamped PDB under ``<backbone dir>/viewer_alignment/``, and
    that directory is created if needed.

    Args:
        backbone_sample_path: Reference PDB path (``~`` is expanded).
        folded_sample_path: PDB path of the structure moved onto the backbone.

    Returns:
        An :class:`AlignedOverlayResult` with the overlay path, the number of
        C-alpha pairs used, and the RMSD over those pairs before and after the fit.

    Raises:
        AlignmentError: If either input is not an existing file, fewer than
            three C-alpha pairs are available, or a structure has no model.
    """
    # Local import: the module-level import only brings in ``datetime``.
    from datetime import timezone

    backbone_path = _resolve_existing_file(backbone_sample_path, label="Backbone")
    folded_path = _resolve_existing_file(folded_sample_path, label="Folded")

    parser = PDBParser(QUIET=True)
    backbone_structure = parser.get_structure("backbone", str(backbone_path))
    folded_structure = parser.get_structure("folded", str(folded_path))

    backbone_ca = _extract_ca_coords(backbone_structure)
    folded_ca = _extract_ca_coords(folded_structure)
    pair_count = min(backbone_ca.shape[0], folded_ca.shape[0])
    if pair_count < 3:
        raise AlignmentError(
            "Need at least 3 CA pairs for alignment. "
            f"backbone={backbone_ca.shape[0]}, folded={folded_ca.shape[0]}."
        )

    backbone_subset = backbone_ca[:pair_count]
    folded_subset = folded_ca[:pair_count]
    rmsd_before = _rmsd(folded_subset, backbone_subset)
    # Presumed contract of du.rigid_transform_3D(A, B): returns (A fitted onto B,
    # rotation, translation, extra). TODO confirm against utils.new_pdbUtils.
    aligned_subset, rotation, translation, _ = du.rigid_transform_3D(folded_subset, backbone_subset)
    rmsd_after = _rmsd(aligned_subset, backbone_subset)
    _apply_transform(folded_structure, rotation=rotation, translation=translation)

    backbone_model = next(backbone_structure.get_models(), None)
    if backbone_model is None:
        raise AlignmentError("Backbone structure has no model entries.")
    _reassign_folded_chain_ids(
        folded_structure=folded_structure,
        backbone_chain_ids=[str(chain.id) for chain in backbone_model],
    )

    overlay_dir = backbone_path.parent / "viewer_alignment"
    overlay_dir.mkdir(parents=True, exist_ok=True)
    # Timezone-aware UTC timestamp. ``datetime.utcnow()`` is deprecated since
    # Python 3.12; the formatted string is unchanged.
    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    overlay_name = f"{backbone_path.stem}__vs__{folded_path.stem}__aligned_{timestamp}.pdb"
    overlay_path = overlay_dir / overlay_name

    # Each part's own END record is dropped so the file has a single END at the bottom.
    backbone_lines = _strip_pdb_terminal_lines(_serialize_structure(backbone_structure))
    folded_lines = _strip_pdb_terminal_lines(_serialize_structure(folded_structure))
    output_lines = [
        "REMARK 900 FlowProt aligned compare overlay",
        f"REMARK 900 backbone_source={backbone_path.as_posix()}",
        f"REMARK 900 folded_source={folded_path.as_posix()}",
        f"REMARK 900 ca_pairs={pair_count}",
        f"REMARK 900 rmsd_before={rmsd_before:.6f}",
        f"REMARK 900 rmsd_after={rmsd_after:.6f}",
        *backbone_lines,
        "TER",
        *folded_lines,
        "END",
    ]
    overlay_path.write_text("\n".join(output_lines) + "\n", encoding="utf-8")
    LOGGER.info("Created aligned overlay file at %s", overlay_path)

    return AlignedOverlayResult(
        output_path=overlay_path.as_posix(),
        num_ca_pairs=pair_count,
        rmsd_before=rmsd_before,
        rmsd_after=rmsd_after,
        backbone_path=backbone_path.as_posix(),
        folded_path=folded_path.as_posix(),
    )
|
|
|