# Copyright 2025 Dhruv Nair. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from dataclasses import dataclass
from typing import List, Optional, Tuple

import numpy as np
import torch
from atomworks.io.utils.io_utils import to_cif_file
from biotite.structure import AtomArray, AtomArrayStack, stack
from diffusers.utils import logging
from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import InputParam, OutputParam


logger = logging.get_logger(__name__)

AA_NAMES = [
    "ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
    "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL",
    "UNK",
]

# ``output_type`` values handled by RFDiffusionDecodeStep.
_CIF_OUTPUT_TYPES = ("cif", "cif.gz")
_SUPPORTED_OUTPUT_TYPES = ("tensor", "pdb") + _CIF_OUTPUT_TYPES

# Suffixes removed from ``output_path`` before it is handed to ``to_cif_file``
# (longest first so ".cif.gz" is not reduced to a dangling ".cif").
_CIF_SUFFIXES = (".cif.gz", ".cif")

# Atom names written per residue when coordinates carry an atom axis.
_BACKBONE_ATOM_NAMES = ("N", "CA", "C")


def _residue_name(index: int) -> str:
    """Map a sequence index to a three-letter residue name (``"UNK"`` if out of range)."""
    # The lower bound matters: a negative index would otherwise wrap around and
    # silently select a name from the end of AA_NAMES.
    if 0 <= index < len(AA_NAMES):
        return AA_NAMES[index]
    return "UNK"


def _first_in_batch(tensor: Optional[torch.Tensor], unbatched_ndim: int) -> Optional[torch.Tensor]:
    """
    Return the first batch element of ``tensor``.

    ``None`` and tensors whose rank is already ``unbatched_ndim`` (or lower) are
    returned unchanged, so an input without a batch axis keeps working.
    """
    if tensor is None or tensor.ndim <= unbatched_ndim:
        return tensor
    return tensor[0]


def _cif_output_base(output_path: str) -> str:
    """
    Strip the structure-file extension from ``output_path``.

    Only the file name is inspected, so dots in directory names (``"./out/model"``,
    ``"runs.v2/model"``) never truncate the path. A trailing ``.cif.gz`` or ``.cif``
    is removed as a whole; any other single extension is removed via
    ``os.path.splitext``.
    """
    lowered = output_path.lower()
    for suffix in _CIF_SUFFIXES:
        if lowered.endswith(suffix):
            stripped = output_path[: -len(suffix)]
            # Keep names such as ".cif" intact instead of producing an empty file name.
            if os.path.basename(stripped):
                return stripped
    return os.path.splitext(output_path)[0]


def _format_pdb_atom(
    serial: int,
    atom_name: str,
    res_name: str,
    res_seq: int,
    x: float,
    y: float,
    z: float,
) -> str:
    """
    Format one fixed-column (80 character) PDB ``ATOM`` record on chain ``A``.

    Column layout: record name 1-6, serial 7-11, atom name 13-16 (name starts in
    column 14), residue name 18-20, chain 22, residue number 23-26, x/y/z 31-54,
    occupancy 55-60, B-factor 61-66, element 77-78. Occupancy is fixed to 1.00 and
    the B-factor to 0.00; the element is the first letter of ``atom_name``.
    """
    return (
        f"ATOM  {serial:5d}  {atom_name:<3s} {res_name:3s} A{res_seq:4d}    "
        f"{x:8.3f}{y:8.3f}{z:8.3f}  1.00  0.00          "
        f"{atom_name[0]:>2s}  "
    )


def _build_atom_array(xyz: torch.Tensor, seq: Optional[torch.Tensor] = None) -> AtomArray:
    """
    Build a biotite AtomArray from CA coordinates and optional sequence.

    Every residue is represented by a single ``CA`` carbon atom on chain ``A``;
    residue ids run from 1 to ``L``.

    Args:
        xyz: Coordinates for a single structure [L, 3].
        seq: Sequence indices [L] (indexes into AA_NAMES). Indices outside
            ``[0, len(AA_NAMES))`` are named ``"UNK"``. When omitted, every
            residue is named ``"ALA"``.

    Returns:
        AtomArray holding ``L`` atoms.
    """
    xyz_np = xyz.detach().cpu().float().numpy()
    num_res = xyz_np.shape[0]

    arr = AtomArray(num_res)
    arr.coord = xyz_np
    arr.atom_name = np.full(num_res, "CA")
    arr.element = np.full(num_res, "C")
    arr.chain_id = np.full(num_res, "A")
    arr.res_id = np.arange(1, num_res + 1)

    if seq is not None:
        seq_np = seq.detach().cpu().numpy()
        arr.res_name = np.array([_residue_name(int(idx)) for idx in seq_np])
    else:
        arr.res_name = np.full(num_res, "ALA")

    return arr


def _build_atom_array_stack(
    xyz: torch.Tensor,
    seq: Optional[torch.Tensor] = None,
) -> AtomArrayStack:
    """
    Build an AtomArrayStack from batched coordinates [B, L, 3].

    Matches foundry ``build_stack_from_atom_array_and_batched_coords``.

    The annotations (including residue names) are taken from the first batch
    element only and shared by all ``B`` models; only the coordinates differ.

    Args:
        xyz: Batched coordinates [B, L, 3].
        seq: Optional batched sequence indices; only ``seq[0]`` is used.

    Returns:
        AtomArrayStack with ``B`` models.
    """
    template = _build_atom_array(xyz[0], seq[0] if seq is not None else None)
    batch_size = xyz.shape[0]
    arr_stack = stack([template] * batch_size)
    # ``stack`` copied the template coordinates B times; overwrite with the real batch.
    arr_stack.coord = xyz.detach().cpu().float().numpy()
    return arr_stack


def _build_trajectory_stack(
    trajectory: List[torch.Tensor],
    seq: Optional[torch.Tensor] = None,
) -> AtomArrayStack:
    """
    Build an AtomArrayStack from a denoising trajectory.

    Each entry is [B, L, 3]; takes the first batch element per step.

    Args:
        trajectory: Non-empty list of per-step coordinate tensors.
        seq: Optional batched sequence indices; only ``seq[0]`` is used.

    Returns:
        AtomArrayStack with one model per trajectory step.
    """
    coords = torch.stack([step[0] for step in trajectory])  # [N_steps, L, 3]
    template = _build_atom_array(coords[0], seq[0] if seq is not None else None)
    arr_stack = stack([template] * coords.shape[0])
    arr_stack.coord = coords.detach().cpu().float().numpy()
    return arr_stack


@dataclass
class RFDiffusionPipelineOutput:
    """
    Output class for RFDiffusion pipeline.

    Attributes:
        xyz: Denoised coordinates as received by the decode step.
        atom_array: AtomArray of the first batch element (CIF output types only).
        atom_array_stack: AtomArrayStack of the whole batch (CIF output types with
            more than one batch element only).
        trajectory_stack: AtomArrayStack of the denoising trajectory (CIF output
            types with a non-empty trajectory only).
        sequence_indices: Predicted sequence indices, passed through unchanged.
        sequence_logits: Sequence logits, passed through unchanged.
        single: Single representation, passed through unchanged.
        pair: Pair representation, passed through unchanged.
        pdb_string: PDB text of the first batch element (``"pdb"`` output type only).
        trajectory: Denoising trajectory, passed through unchanged.
    """

    xyz: torch.Tensor
    atom_array: Optional[AtomArray] = None
    atom_array_stack: Optional[AtomArrayStack] = None
    trajectory_stack: Optional[AtomArrayStack] = None
    sequence_indices: Optional[torch.Tensor] = None
    sequence_logits: Optional[torch.Tensor] = None
    single: Optional[torch.Tensor] = None
    pair: Optional[torch.Tensor] = None
    pdb_string: Optional[str] = None
    trajectory: Optional[List[torch.Tensor]] = None


class RFDiffusionDecodeStep(ModularPipelineBlocks):
    """
    Decode step for RFDiffusion.

    Converts denoised coordinates to final output format.

    Supported ``output_type`` values:
        - ``"tensor"``  — raw tensors only
        - ``"pdb"``     — tensors + PDB format string
        - ``"cif"``     — tensors + AtomArray via AtomWorks, writes ``.cif``
        - ``"cif.gz"``  — same as ``"cif"`` but compressed
    """

    model_name = "rfdiffusion"

    @property
    def description(self) -> str:
        return (
            "Decode step that converts denoised coordinates to final output, "
            "supporting tensor, PDB, and CIF (via AtomWorks) formats."
        )

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "output_type",
                default="tensor",
                type_hint=str,
                description="Output format: 'tensor', 'pdb', 'cif', or 'cif.gz'",
            ),
            InputParam(
                "output_path",
                type_hint=str,
                description="Path to save output structure",
            ),
            InputParam("xyz", required=True, type_hint=torch.Tensor, description="Denoised coordinates [B, L, 3]"),
            InputParam("sequence_indices", type_hint=torch.Tensor, description="Predicted sequence [B, L]"),
            InputParam("sequence_logits", type_hint=torch.Tensor, description="Sequence logits [B, L, n_aa]"),
            InputParam("single", type_hint=torch.Tensor, description="Single representation"),
            InputParam("pair", type_hint=torch.Tensor, description="Pair representation"),
            InputParam("trajectory", type_hint=List[torch.Tensor], description="Denoising trajectory"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("output", type_hint=RFDiffusionPipelineOutput, description="Final pipeline output"),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> Tuple[ModularPipeline, PipelineState]:
        """
        Decode the block state into an ``RFDiffusionPipelineOutput``.

        Depending on ``output_type`` this builds biotite structures (CIF types) or a
        PDB string, optionally writes them below ``output_path``, and stores the
        result in ``block_state.output``.

        Args:
            components: Pipeline components; returned unchanged.
            state: Pipeline state holding the inputs declared in ``inputs``.

        Returns:
            The ``(components, state)`` pair, with ``output`` set on the block state.
        """
        block_state = self.get_block_state(state)

        xyz = block_state.xyz
        sequence_indices = block_state.sequence_indices
        sequence_logits = block_state.sequence_logits
        single = block_state.single
        pair = block_state.pair
        trajectory = block_state.trajectory
        output_type = block_state.output_type or "tensor"
        output_path = block_state.output_path

        if output_type not in _SUPPORTED_OUTPUT_TYPES:
            # Unknown types fall through to tensor-only output; say so instead of
            # silently producing no structure.
            logger.warning(
                "Unrecognized output_type %r; expected one of %s. Returning tensors only.",
                output_type,
                _SUPPORTED_OUTPUT_TYPES,
            )

        pdb_string = None
        atom_array = None
        atom_array_stack = None
        trajectory_stack = None
        is_cif = output_type in _CIF_OUTPUT_TYPES

        # Build AtomArray for CIF output types
        if is_cif:
            atom_array = _build_atom_array(
                xyz[0],
                sequence_indices[0] if sequence_indices is not None else None,
            )
            if xyz.shape[0] > 1:
                atom_array_stack = _build_atom_array_stack(xyz, sequence_indices)
            if trajectory:
                trajectory_stack = _build_trajectory_stack(trajectory, sequence_indices)

        # Build PDB string
        if output_type == "pdb":
            if xyz.ndim >= 3 and xyz.shape[0] > 1:
                logger.warning(
                    "PDB output holds a single structure; writing only the first of %d batch elements.",
                    xyz.shape[0],
                )
            # Index the batch axis explicitly: ``squeeze(0)`` is a no-op for B > 1 and
            # would let [B, L, 3] be misread as per-residue backbone atoms [L, n_atoms, 3].
            pdb_string = self._coords_to_pdb(
                _first_in_batch(xyz, 2),
                _first_in_batch(sequence_indices, 1),
            )

        # Write to disk
        if output_path is not None:
            output_path = os.fspath(output_path)
            os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

            if is_cif:
                to_write = atom_array_stack if atom_array_stack is not None else atom_array
                # ``to_cif_file`` is given an extension-less base plus ``file_type``.
                base = _cif_output_base(output_path)
                to_cif_file(to_write, base, file_type=output_type, include_entity_poly=False)
                if trajectory_stack is not None:
                    to_cif_file(
                        trajectory_stack,
                        base + "_trajectory",
                        file_type=output_type,
                        include_entity_poly=False,
                    )
            elif output_type == "pdb" and pdb_string is not None:
                with open(output_path, "w", encoding="utf-8") as f:
                    f.write(pdb_string)

        output = RFDiffusionPipelineOutput(
            xyz=xyz,
            atom_array=atom_array,
            atom_array_stack=atom_array_stack,
            trajectory_stack=trajectory_stack,
            sequence_indices=sequence_indices,
            sequence_logits=sequence_logits,
            single=single,
            pair=pair,
            pdb_string=pdb_string,
            trajectory=trajectory,
        )

        block_state.output = output
        self.set_block_state(state, block_state)
        return components, state

    def _coords_to_pdb(
        self,
        xyz: torch.Tensor,
        seq: Optional[torch.Tensor] = None,
    ) -> str:
        """
        Convert coordinates to PDB format string.

        Args:
            xyz: Coordinates of one structure, either [L, 3] (written as one ``CA``
                atom per residue) or [L, n_atoms, 3] (the first up to three atoms of
                each residue are written as ``N``, ``CA``, ``C``).
            seq: Optional sequence indices [L] into AA_NAMES. Out-of-range indices
                are named ``"UNK"``; without ``seq`` every residue is ``"ALA"``.

        Returns:
            Newline-joined ``ATOM`` records followed by ``END``.

        Raises:
            ValueError: If ``xyz`` is neither 2- nor 3-dimensional.
        """
        # detach() allows tensors that require grad; double() is a lossless upcast that
        # also makes half/bfloat16 inputs convertible to NumPy.
        xyz_np = xyz.detach().cpu().double().numpy()
        if xyz_np.ndim not in (2, 3):
            raise ValueError(
                f"Expected coordinates of shape [L, 3] or [L, n_atoms, 3], got {tuple(xyz_np.shape)}"
            )
        seq_np = seq.detach().cpu().numpy() if seq is not None else None
        ca_only = xyz_np.ndim == 2

        lines = []
        serial = 1
        for i in range(xyz_np.shape[0]):
            res_name = _residue_name(int(seq_np[i])) if seq_np is not None else "ALA"
            if ca_only:
                atoms = (("CA", xyz_np[i]),)
            else:
                # zip stops at the shorter input, so residues with fewer than three
                # atoms only get the leading backbone names.
                atoms = zip(_BACKBONE_ATOM_NAMES, xyz_np[i])
            for atom_name, (x, y, z) in atoms:
                lines.append(_format_pdb_atom(serial, atom_name, res_name, i + 1, x, y, z))
                serial += 1

        lines.append("END")
        return "\n".join(lines)