# Copyright 2025 Dhruv Nair. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
from diffusers.utils import logging
from diffusers.modular_pipelines import (
    AutoPipelineBlocks,
    ModularPipeline,
    ModularPipelineBlocks,
    PipelineState,
    SequentialPipelineBlocks,
)
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, InsertableDict, OutputParam

from .before_denoise import (
    RFDiffusionInputStep,
    RFDiffusionPrepareLatentsStep,
    RFDiffusionSetTimestepsStep,
)
from .decoders import RFDiffusionDecodeStep
from .denoise import RFDiffusionDenoiseStep


logger = logging.get_logger(__name__)


# ─── Amino acid mappings (used by MPNN blocks) ─────────────────────────

# Three-letter residue name -> one-letter code. Insertion order defines the
# token index of each residue (see ``AA_NAMES``).
THREE_TO_ONE = {
    "ALA": "A",
    "ARG": "R",
    "ASN": "N",
    "ASP": "D",
    "CYS": "C",
    "GLN": "Q",
    "GLU": "E",
    "GLY": "G",
    "HIS": "H",
    "ILE": "I",
    "LEU": "L",
    "LYS": "K",
    "MET": "M",
    "PHE": "F",
    "PRO": "P",
    "SER": "S",
    "THR": "T",
    "TRP": "W",
    "TYR": "Y",
    "VAL": "V",
    "UNK": "X",
}

# Token index -> three-letter residue name.
AA_NAMES = list(THREE_TO_ONE.keys())

# Indices derived from the table above instead of hard-coded magic numbers.
_GLY_INDEX = AA_NAMES.index("GLY")
_UNK_INDEX = AA_NAMES.index("UNK")


def _residue_name(index: int) -> str:
    """Return the three-letter name for a token index; out-of-range indices map to ``"UNK"``."""
    if 0 <= index < len(AA_NAMES):
        return AA_NAMES[index]
    return AA_NAMES[_UNK_INDEX]


# ═══════════════════════════════════════════════════════════════════════════
# RFDiffusion blocks
# ═══════════════════════════════════════════════════════════════════════════


class RFDiffusionBeforeDenoiseStep(SequentialPipelineBlocks):
    """Sequential block for pre-denoising preparation."""

    block_classes = [
        RFDiffusionInputStep,
        RFDiffusionSetTimestepsStep,
        RFDiffusionPrepareLatentsStep,
    ]
    block_names = ["input", "set_timesteps", "prepare_latents"]

    @property
    def description(self):
        return (
            "Before denoise step that prepares the inputs for the denoise step.\n"
            "This is a sequential pipeline blocks:\n"
            " - `RFDiffusionInputStep` processes contigs and prepares input features\n"
            " - `RFDiffusionSetTimestepsStep` sets up the diffusion timesteps\n"
            " - `RFDiffusionPrepareLatentsStep` initializes noised coordinates\n"
        )


class RFDiffusionAutoBeforeDenoiseStep(AutoPipelineBlocks):
    """Auto-select before denoise step based on task."""

    block_classes = [RFDiffusionBeforeDenoiseStep]
    block_names = ["unconditional"]
    block_trigger_inputs = [None]

    @property
    def description(self):
        return (
            "Before denoise step that prepares the inputs for the denoise step.\n"
            "This is an auto pipeline block for protein structure generation.\n"
            " - `RFDiffusionBeforeDenoiseStep` (unconditional) is used.\n"
        )


class RFDiffusionAutoDenoiseStep(AutoPipelineBlocks):
    """Auto-select denoise step."""

    block_classes = [RFDiffusionDenoiseStep]
    block_names = ["denoise"]
    block_trigger_inputs = [None]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the protein structure. "
            "This is an auto pipeline block for protein structure generation. "
            " - `RFDiffusionDenoiseStep` (denoise) for structure generation."
        )


class RFDiffusionAutoDecodeStep(AutoPipelineBlocks):
    """Auto-select decode step."""

    block_classes = [RFDiffusionDecodeStep]
    block_names = ["decode"]
    block_trigger_inputs = [None]

    @property
    def description(self):
        return "Decode step that converts denoised coordinates to PDB output.\n - `RFDiffusionDecodeStep`"


# ═══════════════════════════════════════════════════════════════════════════
# MPNN blocks (defined before RFDiffusionAutoBlocks which references them)
# ═══════════════════════════════════════════════════════════════════════════


@dataclass
class MPNNPipelineOutput:
    """Output from ProteinMPNN / LigandMPNN sequence design."""

    # One-letter sequence decoded from the first batch element only.
    designed_sequence: str
    sequence_indices: torch.Tensor  # [B, L] token indices
    sequence_logits: torch.Tensor  # [B, L, n_vocab] logits
    xyz: torch.Tensor  # [B, L, 3] input structure (passed through)
    pdb_string: Optional[str] = None  # PDB with designed sequence
    # Never populated by MPNNSequenceDesignStep in this module; stays None.
    sequence_recovery: Optional[float] = None


class MPNNSequenceDesignStep(ModularPipelineBlocks):
    """
    Design sequences for a protein backbone using ProteinMPNN / LigandMPNN.

    Takes ``xyz`` coordinates (typically from an upstream RFDiffusion
    denoise step) and runs the ``MPNNModel`` to produce amino acid
    sequences for the designable regions.

    When no ``mpnn`` component is loaded, falls back to using the
    sequence predictions from upstream RFDiffusion (or glycine everywhere).

    Only the first batch element is decoded into ``designed_sequence`` and
    ``pdb_string``; the tensors in the output keep the full batch.
    """

    model_name = "mpnn"

    @property
    def description(self) -> str:
        return (
            "Design amino acid sequences for protein backbones using "
            "ProteinMPNN or LigandMPNN. Accepts structure coordinates "
            "from an upstream diffusion step."
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("mpnn", description="MPNNModel (ProteinMPNN or LigandMPNN)"),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "xyz",
                required=True,
                type_hint=torch.Tensor,
                description="Protein backbone coordinates [B, L, 3] (CA atoms)",
            ),
            InputParam(
                "motif_mask",
                type_hint=torch.Tensor,
                description="Mask for fixed/motif positions [L]. True = fixed sequence.",
            ),
            InputParam(
                "sequence_indices",
                type_hint=torch.Tensor,
                description="Known sequence indices for motif positions [B, L] (from RFDiffusion)",
            ),
            InputParam(
                "temperature",
                default=0.1,
                type_hint=float,
                description="Sampling temperature (lower = more deterministic)",
            ),
            # NOTE: declared for API completeness; `__call__` currently
            # produces a single design and only warns when another value is set.
            InputParam(
                "num_designs",
                default=1,
                type_hint=int,
                description="Number of sequence designs to generate per structure",
            ),
            # NOTE: only "pdb" triggers PDB string generation in `__call__`;
            # every other value leaves `pdb_string` as None.
            InputParam(
                "output_type",
                default="tensor",
                type_hint=str,
                description="'tensor', 'pdb', or 'cif'",
            ),
            InputParam(
                "output_path",
                type_hint=str,
                description="Path to save designed PDB",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "mpnn_output",
                type_hint=MPNNPipelineOutput,
                description="MPNN sequence design output",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> Tuple[ModularPipeline, PipelineState]:
        """Run sequence design and store an ``MPNNPipelineOutput`` in the block state.

        Args:
            components: Pipeline components; ``components.mpnn`` is used when present
                and not ``None``.
            state: Pipeline state holding the inputs declared in ``inputs``.

        Returns:
            The ``(components, state)`` pair, with ``mpnn_output`` written to the state.

        Raises:
            ValueError: If ``xyz`` is not a 3-D ``[B, L, 3]`` tensor.
        """
        block_state = self.get_block_state(state)

        xyz = block_state.xyz
        motif_mask = block_state.motif_mask
        known_seq = block_state.sequence_indices
        # NOTE(review): `or` also maps an explicit temperature of 0 to 0.1;
        # kept as in the original — confirm whether 0 should be allowed.
        temperature = block_state.temperature or 0.1
        output_type = block_state.output_type or "tensor"
        output_path = block_state.output_path

        if xyz.dim() != 3:
            raise ValueError(f"`xyz` must have shape [B, L, 3] (CA atoms), got {tuple(xyz.shape)}.")
        B, L, _ = xyz.shape
        device = xyz.device

        num_designs = getattr(block_state, "num_designs", None)
        if num_designs not in (None, 1):
            logger.warning(
                "`num_designs=%s` was requested but this step generates a single design per structure.",
                num_designs,
            )

        has_mpnn = hasattr(components, "mpnn") and components.mpnn is not None

        if has_mpnn:
            sequence_logits, sequence_indices = self._run_mpnn(
                components.mpnn,
                xyz,
                motif_mask,
                known_seq,
                temperature,
            )
        else:
            logger.info("No `mpnn` component loaded; using upstream sequence indices or an all-glycine sequence.")
            if known_seq is not None:
                # scatter_ below needs int64 indices on the same device as the logits.
                sequence_indices = known_seq.to(device=device, dtype=torch.long)
            else:
                sequence_indices = torch.full((B, L), _GLY_INDEX, dtype=torch.long, device=device)  # GLY
            # One-hot pseudo-logits. Out-of-vocabulary indices (e.g. mask tokens)
            # are routed to UNK so scatter_ cannot index out of bounds, matching
            # how the sequence string is decoded below.
            out_of_range = (sequence_indices < 0) | (sequence_indices >= len(AA_NAMES))
            one_hot_index = sequence_indices.masked_fill(out_of_range, _UNK_INDEX)
            sequence_logits = torch.zeros(B, L, len(AA_NAMES), device=device)
            sequence_logits.scatter_(2, one_hot_index.unsqueeze(-1), 1.0)

        # Human-readable sequence for the first batch element only.
        first_sequence = sequence_indices[0].cpu().tolist()
        designed_sequence = "".join(THREE_TO_ONE.get(_residue_name(int(idx)), "X") for idx in first_sequence)

        pdb_string = None
        if output_type == "pdb":
            pdb_string = self._coords_to_pdb(xyz[0], sequence_indices[0])
            if output_path:
                os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
                with open(output_path, "w") as f:
                    f.write(pdb_string)
        elif output_path:
            logger.warning(
                "`output_path` was provided but `output_type=%r` does not produce a PDB; nothing was written.",
                output_type,
            )

        output = MPNNPipelineOutput(
            designed_sequence=designed_sequence,
            sequence_indices=sequence_indices,
            sequence_logits=sequence_logits,
            xyz=xyz,
            pdb_string=pdb_string,
        )

        block_state.mpnn_output = output
        self.set_block_state(state, block_state)
        return components, state

    def _run_mpnn(self, mpnn, xyz, motif_mask, known_seq, temperature):
        """Run the MPNNModel wrapper on backbone coordinates.

        Args:
            mpnn: Callable model accepting ``X``, ``S``, ``designed_residue_mask`` and
                ``temperature`` and returning an object with ``sequence_logits`` and
                ``sequence_indices`` attributes.
            xyz: CA coordinates, unpacked here as ``[B, L, 3]``.
            motif_mask: Optional ``[L]`` mask; truthy entries mark fixed positions.
            known_seq: Optional ``[B, L]`` indices used for the fixed positions.
            temperature: Sampling temperature forwarded to the model.

        Returns:
            Tuple ``(logits, indices)`` taken from the model output, with motif
            positions overwritten by ``known_seq`` when both are provided.
        """
        B = xyz.shape[0]
        device = xyz.device
        dtype = xyz.dtype

        # Build a 4-atom backbone (N, CA, C, O) by adding fixed offsets to each
        # CA position; the offsets are applied in the global frame.
        ca = xyz
        n_offset = torch.tensor([-1.458, 0.0, 0.0], device=device, dtype=dtype)
        c_offset = torch.tensor([0.550, 1.424, 0.0], device=device, dtype=dtype)
        o_offset = torch.tensor([0.550, 2.500, 0.0], device=device, dtype=dtype)

        X = torch.stack(
            [
                ca + n_offset,
                ca,
                ca + c_offset,
                ca + o_offset,
            ],
            dim=2,
        )

        if motif_mask is not None:
            # `~` and boolean-mask indexing are only correct on bool tensors
            # (on integer tensors `~` is a bitwise NOT), so coerce first.
            motif_mask = motif_mask.to(device=device, dtype=torch.bool)
            designed_mask = ~motif_mask.unsqueeze(0).expand(B, -1)
        else:
            designed_mask = None

        output = mpnn(
            X=X,
            S=known_seq,
            designed_residue_mask=designed_mask,
            temperature=temperature,
        )

        logits = output.sequence_logits
        indices = output.sequence_indices

        if motif_mask is not None and known_seq is not None:
            # Work on a copy so the model's own output tensor is not mutated.
            indices = indices.clone()
            fixed = motif_mask.to(indices.device)
            indices[:, fixed] = known_seq.to(device=indices.device, dtype=indices.dtype)[:, fixed]

        return logits, indices

    def _coords_to_pdb(self, xyz: torch.Tensor, seq: torch.Tensor) -> str:
        """Render CA-only coordinates as fixed-column PDB ``ATOM`` records.

        Args:
            xyz: ``[L, 3]`` coordinates of a single structure.
            seq: ``[L]`` residue token indices; out-of-range values are written as ``UNK``.

        Returns:
            PDB text with one ``ATOM`` line per residue (chain ``A``, residues
            numbered from 1) followed by ``END``.
        """
        # `.tolist()` avoids a NumPy round-trip, which does not support bfloat16.
        coords = xyz.detach().cpu().tolist()
        residues = seq.detach().cpu().tolist()

        lines = []
        for i, (x, y, z) in enumerate(coords):
            aa_name = _residue_name(int(residues[i]))
            serial = i + 1
            # Fixed-width layout: record 1-6, serial 7-11, atom name 13-16,
            # resName 18-20, chain 22, resSeq 23-26, x/y/z 31-54,
            # occupancy 55-60, B-factor 61-66, element 77-78.
            lines.append(
                f"ATOM  {serial:5d}  CA  {aa_name:3s} A{serial:4d}    "
                f"{x:8.3f}{y:8.3f}{z:8.3f}  1.00  0.00           C  "
            )
        lines.append("END")
        return "\n".join(lines)


class MPNNAutoDesignStep(AutoPipelineBlocks):
    """Auto-select MPNN design step."""

    block_classes = [MPNNSequenceDesignStep]
    block_names = ["sequence_design"]
    block_trigger_inputs = [None]

    @property
    def description(self) -> str:
        return "Sequence design using ProteinMPNN or LigandMPNN."


# ═══════════════════════════════════════════════════════════════════════════
# Top-level pipeline blocks
# ═══════════════════════════════════════════════════════════════════════════


class RFDiffusionAutoBlocks(SequentialPipelineBlocks):
    """
    Full protein design pipeline: RFDiffusion3 + optional ProteinMPNN/LigandMPNN.

    The active workflow is selected by trigger inputs:

    - ``contigs`` only → structure generation
    - ``contigs`` + ``temperature`` → structure + sequence design
    - ``contigs`` + ``input_xyz`` + ``temperature`` → motif-conditioned + sequence design

    The MPNN step is skipped when ``temperature`` is not provided or when
    no ``mpnn`` component is loaded.

    NOTE(review): ``MPNNAutoDesignStep`` declares ``block_trigger_inputs = [None]``,
    ``temperature`` has a default of 0.1, and ``MPNNSequenceDesignStep`` falls
    back to upstream/glycine sequences when ``mpnn`` is missing rather than
    skipping. The "skipped" statement above may therefore not match the
    runtime behaviour — confirm against the AutoPipelineBlocks trigger rules.
    """

    block_classes = [
        RFDiffusionAutoBeforeDenoiseStep,
        RFDiffusionAutoDenoiseStep,
        RFDiffusionAutoDecodeStep,
        MPNNAutoDesignStep,
    ]
    # NOTE(review): the third entry is "decoder" here but "decode" in
    # AUTO_BLOCKS below; left unchanged because callers may rely on the name.
    block_names = [
        "before_denoise",
        "denoise",
        "decoder",
        "sequence_design",
    ]

    # Workflow name -> inputs that must be provided to select it.
    _workflow_map = {
        "structure_only": {
            "contigs": True,
        },
        "structure_and_sequence": {
            "contigs": True,
            "temperature": True,
        },
        "motif_structure_and_sequence": {
            "contigs": True,
            "input_xyz": True,
            "temperature": True,
        },
    }

    @property
    def description(self):
        return (
            "Modular pipeline for protein design using RFDiffusion3.\n"
            "Workflows:\n"
            " - structure_only: backbone generation\n"
            " - structure_and_sequence: backbone + MPNN sequence design\n"
            " - motif_structure_and_sequence: motif-conditioned + MPNN\n"
        )


# ═══════════════════════════════════════════════════════════════════════════
# Block registries
# ═══════════════════════════════════════════════════════════════════════════

# Flat, step-by-step registry of the concrete (non-auto) blocks.
UNCONDITIONAL_BLOCKS = InsertableDict(
    [
        ("input", RFDiffusionInputStep),
        ("set_timesteps", RFDiffusionSetTimestepsStep),
        ("prepare_latents", RFDiffusionPrepareLatentsStep),
        ("denoise", RFDiffusionDenoiseStep),
        ("decode", RFDiffusionDecodeStep),
        ("sequence_design", MPNNSequenceDesignStep),
    ]
)

# Registry of the auto-selecting wrapper blocks.
AUTO_BLOCKS = InsertableDict(
    [
        ("before_denoise", RFDiffusionAutoBeforeDenoiseStep),
        ("denoise", RFDiffusionAutoDenoiseStep),
        ("decode", RFDiffusionAutoDecodeStep),
        ("sequence_design", MPNNAutoDesignStep),
    ]
)

ALL_BLOCKS = {
    "unconditional": UNCONDITIONAL_BLOCKS,
    "auto": AUTO_BLOCKS,
}