# Copyright 2025 Dhruv Nair. All rights reserved.
# Licensed under the Apache License, Version 2.0
"""
Decode step for RF3 — converts denoised coordinates to output structures.

Supports tensor, PDB, and CIF (via AtomWorks) output formats.
"""

import os
from dataclasses import dataclass
from typing import List, Optional, Tuple

import numpy as np
import torch
from atomworks.io.utils.io_utils import to_cif_file
from biotite.structure import AtomArray, AtomArrayStack, stack

from diffusers.utils import logging
from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import InputParam, OutputParam

logger = logging.get_logger(__name__)

AA_ORDER = "ARNDCQEGHILKMFPSTWYV"
AA_NAMES_3 = [
    "ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
    "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL",
    "UNK",
]

# One-letter -> three-letter lookup; zip() stops at the 20 letters of AA_ORDER,
# so the trailing "UNK" entry of AA_NAMES_3 is only reachable as the fallback.
_AA_1_TO_3 = dict(zip(AA_ORDER, AA_NAMES_3))
_UNKNOWN_RES_NAME = "UNK"
# Residue name used when no sequence letter is available for a position.
_DEFAULT_RES_NAME = "ALA"

_CIF_OUTPUT_TYPES = ("cif", "cif.gz")
_SUPPORTED_OUTPUT_TYPES = ("tensor", "pdb") + _CIF_OUTPUT_TYPES


def _to_numpy(xyz: torch.Tensor) -> np.ndarray:
    """Return ``xyz`` as a float32 NumPy array on the CPU.

    ``detach()`` makes this safe for tensors that track gradients and
    ``float()`` is required for dtypes NumPy cannot represent (e.g. bfloat16).
    """
    return xyz.detach().cpu().float().numpy()


def _residue_names(sequence: Optional[str], length: int) -> List[str]:
    """Return ``length`` three-letter residue names for ``sequence``.

    Letters not in ``AA_ORDER`` map to ``"UNK"``. Positions without a sequence
    letter (no sequence, or a sequence shorter than ``length``) are named
    ``"ALA"``; letters beyond ``length`` are ignored. A warning is logged when
    the sequence length does not match ``length``.
    """
    if not sequence:
        return [_DEFAULT_RES_NAME] * length
    if len(sequence) != length:
        logger.warning(
            "Sequence length (%d) does not match the number of coordinates (%d); "
            "residue names are padded/truncated.",
            len(sequence),
            length,
        )
    names = [_AA_1_TO_3.get(aa, _UNKNOWN_RES_NAME) for aa in sequence[:length]]
    names.extend([_DEFAULT_RES_NAME] * (length - len(names)))
    return names


def _strip_output_extension(output_path: str) -> str:
    """Return ``output_path`` without the extension of its final component.

    Only the file name is inspected, so dots in directory names (``"./out/x"``,
    ``"run.v2/x"``) are left alone, and ``".cif.gz"`` is removed as a single
    two-part extension.
    """
    directory, filename = os.path.split(output_path)
    if filename.endswith(".cif.gz"):
        stem = filename[: -len(".cif.gz")]
    else:
        stem = os.path.splitext(filename)[0]
    return os.path.join(directory, stem)


def _build_atom_array(xyz: torch.Tensor, sequence: Optional[str] = None) -> AtomArray:
    """Build a single-chain, CA-only ``AtomArray`` from one set of coordinates.

    Args:
        xyz: Coordinates with one row per residue (first dimension is the
            number of atoms written).
        sequence: Optional one-letter sequence used for residue names.

    Returns:
        An ``AtomArray`` with chain ``"A"``, atom name ``"CA"``, element
        ``"C"`` and residue ids starting at 1.
    """
    coords = _to_numpy(xyz)
    n_res = coords.shape[0]

    arr = AtomArray(n_res)
    arr.coord = coords
    arr.atom_name = np.full(n_res, "CA")
    arr.element = np.full(n_res, "C")
    arr.chain_id = np.full(n_res, "A")
    arr.res_id = np.arange(1, n_res + 1)
    arr.res_name = np.array(_residue_names(sequence, n_res))
    return arr


def _build_atom_array_stack(xyz: torch.Tensor, sequence: Optional[str] = None) -> AtomArrayStack:
    """Build an ``AtomArrayStack`` with one model per entry along ``xyz``'s first axis.

    All models share the annotations of the first entry; only the coordinates
    differ.
    """
    template = _build_atom_array(xyz[0], sequence)
    arr_stack = stack([template for _ in range(xyz.shape[0])])
    arr_stack.coord = _to_numpy(xyz)
    return arr_stack


@dataclass
class RF3PipelineOutput:
    """Output class for RF3 pipeline."""

    # Predicted coordinates exactly as received by the decode step.
    xyz: torch.Tensor
    # CA-only structure of the first batch entry (CIF output types only).
    atom_array: Optional[AtomArray] = None
    # All batch entries as models (CIF output types, more than one entry).
    atom_array_stack: Optional[AtomArrayStack] = None
    # First batch entry of every trajectory step (CIF output types only).
    trajectory_stack: Optional[AtomArrayStack] = None
    distogram: Optional[torch.Tensor] = None
    sequence: Optional[str] = None
    # PDB text of the first batch entry (``output_type="pdb"`` only).
    pdb_string: Optional[str] = None
    trajectory: Optional[List[torch.Tensor]] = None


class RF3DecodeStep(ModularPipelineBlocks):
    """
    Decode step for RF3.

    Supported ``output_type`` values: ``"tensor"``, ``"pdb"``, ``"cif"``, ``"cif.gz"``.
    """

    model_name = "rf3"

    @property
    def description(self) -> str:
        return "Convert predicted coordinates to output format (tensor/PDB/CIF)."

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("output_type", default="tensor", type_hint=str),
            InputParam("output_path", type_hint=str),
            InputParam("xyz", required=True, type_hint=torch.Tensor),
            InputParam("sequence", type_hint=str),
            InputParam("distogram", type_hint=torch.Tensor),
            InputParam("trajectory", type_hint=List[torch.Tensor]),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("output", type_hint=RF3PipelineOutput),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> Tuple[ModularPipeline, PipelineState]:
        """Build the requested output representation and store it in the state.

        ``xyz`` is indexed along its first axis as a batch: entry 0 is used for
        the single structure / PDB text, and all entries for the CIF stack.
        When ``output_path`` is set, CIF and PDB outputs are also written to
        disk.

        Returns:
            The ``(components, state)`` pair, with ``output`` set on the block
            state to an ``RF3PipelineOutput``.
        """
        block_state = self.get_block_state(state)

        xyz = block_state.xyz
        sequence = block_state.sequence
        distogram = block_state.distogram
        trajectory = block_state.trajectory
        output_type = block_state.output_type or "tensor"
        output_path = block_state.output_path

        if output_type not in _SUPPORTED_OUTPUT_TYPES:
            logger.warning(
                "Unsupported output_type %r; returning tensors only. Supported values: %s.",
                output_type,
                ", ".join(_SUPPORTED_OUTPUT_TYPES),
            )

        is_cif = output_type in _CIF_OUTPUT_TYPES
        pdb_string = None
        atom_array = None
        atom_array_stack = None
        trajectory_stack = None

        if is_cif:
            atom_array = _build_atom_array(xyz[0], sequence)
            if xyz.shape[0] > 1:
                atom_array_stack = _build_atom_array_stack(xyz, sequence)
            if trajectory:
                # Only the first batch entry of each step goes into the trajectory.
                traj_coords = torch.stack([step[0] for step in trajectory])
                trajectory_stack = _build_atom_array_stack(traj_coords, sequence)

        if output_type == "pdb":
            pdb_string = self._coords_to_pdb(xyz[0], sequence)

        if output_path is not None:
            output_path = os.fspath(output_path)
            os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
            if is_cif:
                to_write = atom_array_stack if atom_array_stack is not None else atom_array
                # Pass an extension-free base path together with ``file_type``.
                base = _strip_output_extension(output_path)
                to_cif_file(to_write, base, file_type=output_type, include_entity_poly=False)
            elif output_type == "pdb" and pdb_string:
                with open(output_path, "w", encoding="utf-8") as f:
                    f.write(pdb_string)

        block_state.output = RF3PipelineOutput(
            xyz=xyz,
            atom_array=atom_array,
            atom_array_stack=atom_array_stack,
            trajectory_stack=trajectory_stack,
            distogram=distogram,
            sequence=sequence,
            pdb_string=pdb_string,
            trajectory=trajectory,
        )
        self.set_block_state(state, block_state)
        return components, state

    def _coords_to_pdb(self, xyz: torch.Tensor, sequence: Optional[str] = None) -> str:
        """Render coordinates as CA-only PDB text on chain ``A``.

        Args:
            xyz: Coordinates with one ``(x, y, z)`` row per residue.
            sequence: Optional one-letter sequence; missing positions are
                written as ``ALA`` and unknown letters as ``UNK``.

        Returns:
            Fixed-column ``ATOM`` records followed by ``END``, joined with
            newlines (no trailing newline). Serial and residue numbers start
            at 1 and are not wrapped if they exceed their column width.
        """
        coords = _to_numpy(xyz)
        names = _residue_names(sequence, coords.shape[0])

        lines = [
            # Columns: 1-6 record, 7-11 serial, 13-16 atom name, 18-20 residue,
            # 22 chain, 23-26 residue number, 31-54 x/y/z, 55-60 occupancy,
            # 61-66 B-factor, 77-78 element.
            f"ATOM  {idx:5d}  CA  {res_name:3s} A{idx:4d}    "
            f"{x:8.3f}{y:8.3f}{z:8.3f}{1.0:6.2f}{0.0:6.2f}{'':10s}{'C':>2s}"
            for idx, (res_name, (x, y, z)) in enumerate(zip(names, coords), start=1)
        ]
        lines.append("END")
        return "\n".join(lines)