"""Alignment helpers for overlaying generated backbone and folded structures.""" from __future__ import annotations import io import logging import string from dataclasses import dataclass from datetime import datetime from pathlib import Path from typing import Dict, Iterable, List, Tuple import numpy as np from Bio.PDB import PDBIO, PDBParser from Bio.PDB.Chain import Chain from Bio.PDB.Structure import Structure from utils import new_pdbUtils as du LOGGER = logging.getLogger(__name__) _CHAIN_POOL = string.ascii_uppercase + string.digits + string.ascii_lowercase class AlignmentError(RuntimeError): """Raised when aligned-compare preparation fails.""" @dataclass class AlignedOverlayResult: output_path: str num_ca_pairs: int rmsd_before: float rmsd_after: float backbone_path: str folded_path: str def _resolve_existing_file(path_value: str, label: str) -> Path: candidate = Path(path_value).expanduser().resolve() if not candidate.exists() or not candidate.is_file(): raise AlignmentError(f"{label} PDB does not exist: {candidate}") return candidate def _extract_ca_coords(structure: Structure) -> np.ndarray: ca_coords: List[np.ndarray] = [] model = next(structure.get_models(), None) if model is None: return np.zeros((0, 3), dtype=np.float64) for chain in model: for residue in chain: if "CA" not in residue: continue atom = residue["CA"] ca_coords.append(np.asarray(atom.coord, dtype=np.float64)) if not ca_coords: return np.zeros((0, 3), dtype=np.float64) return np.stack(ca_coords, axis=0) def _rmsd(a_coords: np.ndarray, b_coords: np.ndarray) -> float: if a_coords.shape != b_coords.shape: raise AlignmentError(f"RMSD shape mismatch: {a_coords.shape} vs {b_coords.shape}") deltas = a_coords - b_coords return float(np.sqrt(np.mean(np.sum(deltas * deltas, axis=-1)))) def _next_chain_id(used_ids: Iterable[str]) -> str: used = set(used_ids) for candidate in _CHAIN_POOL: if candidate not in used: return candidate raise AlignmentError("No available chain identifiers left for aligned overlay.") def _reassign_folded_chain_ids(folded_structure: Structure, backbone_chain_ids: Iterable[str]) -> Dict[str, str]: model = next(folded_structure.get_models(), None) if model is None: raise AlignmentError("Folded structure has no model entries.") used_ids = set(backbone_chain_ids) remap: Dict[str, str] = {} for chain in model: chain_obj = chain # type: Chain old_id = str(chain_obj.id) new_id = old_id if old_id in used_ids or len(old_id) != 1: new_id = _next_chain_id(used_ids) chain_obj.id = new_id used_ids.add(new_id) remap[old_id] = new_id return remap def _serialize_structure(structure: Structure) -> str: writer = PDBIO() writer.set_structure(structure) buffer = io.StringIO() writer.save(buffer) return buffer.getvalue() def _strip_pdb_terminal_lines(pdb_text: str) -> List[str]: return [line for line in pdb_text.splitlines() if not line.startswith("END")] def _apply_transform(structure: Structure, rotation: np.ndarray, translation: np.ndarray) -> None: model = next(structure.get_models(), None) if model is None: raise AlignmentError("Cannot transform structure without models.") for atom in model.get_atoms(): source = np.asarray(atom.coord, dtype=np.float64).reshape(3, 1) transformed = (rotation @ source + translation).reshape(3) atom.coord = transformed.astype(np.float32) def align_folded_to_backbone_overlay( backbone_sample_path: str, folded_sample_path: str, ) -> AlignedOverlayResult: backbone_path = _resolve_existing_file(backbone_sample_path, label="Backbone") folded_path = _resolve_existing_file(folded_sample_path, label="Folded") parser = PDBParser(QUIET=True) backbone_structure = parser.get_structure("backbone", str(backbone_path)) folded_structure = parser.get_structure("folded", str(folded_path)) backbone_ca = _extract_ca_coords(backbone_structure) folded_ca = _extract_ca_coords(folded_structure) pair_count = min(backbone_ca.shape[0], folded_ca.shape[0]) if pair_count < 3: raise AlignmentError( "Need at least 3 CA pairs for alignment. " f"backbone={backbone_ca.shape[0]}, folded={folded_ca.shape[0]}." ) backbone_subset = backbone_ca[:pair_count] folded_subset = folded_ca[:pair_count] rmsd_before = _rmsd(folded_subset, backbone_subset) aligned_subset, rotation, translation, _ = du.rigid_transform_3D(folded_subset, backbone_subset) rmsd_after = _rmsd(aligned_subset, backbone_subset) _apply_transform(folded_structure, rotation=rotation, translation=translation) backbone_model = next(backbone_structure.get_models(), None) if backbone_model is None: raise AlignmentError("Backbone structure has no model entries.") _reassign_folded_chain_ids( folded_structure=folded_structure, backbone_chain_ids=[str(chain.id) for chain in backbone_model], ) overlay_dir = backbone_path.parent / "viewer_alignment" overlay_dir.mkdir(parents=True, exist_ok=True) overlay_name = ( f"{backbone_path.stem}__vs__{folded_path.stem}__aligned_" f"{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}.pdb" ) overlay_path = overlay_dir / overlay_name backbone_lines = _strip_pdb_terminal_lines(_serialize_structure(backbone_structure)) folded_lines = _strip_pdb_terminal_lines(_serialize_structure(folded_structure)) output_lines = [ "REMARK 900 FlowProt aligned compare overlay", f"REMARK 900 backbone_source={backbone_path.as_posix()}", f"REMARK 900 folded_source={folded_path.as_posix()}", f"REMARK 900 ca_pairs={pair_count}", f"REMARK 900 rmsd_before={rmsd_before:.6f}", f"REMARK 900 rmsd_after={rmsd_after:.6f}", *backbone_lines, "TER", *folded_lines, "END", ] overlay_path.write_text("\n".join(output_lines) + "\n", encoding="utf-8") LOGGER.info("Created aligned overlay file at %s", overlay_path) return AlignedOverlayResult( output_path=overlay_path.as_posix(), num_ca_pairs=pair_count, rmsd_before=rmsd_before, rmsd_after=rmsd_after, backbone_path=backbone_path.as_posix(), folded_path=folded_path.as_posix(), )