| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from dataclasses import dataclass |
| from typing import List, Optional |
|
|
| import torch |
|
|
| from diffusers.utils import logging |
| from diffusers.modular_pipelines import ( |
| AutoPipelineBlocks, |
| ModularPipeline, |
| ModularPipelineBlocks, |
| PipelineState, |
| SequentialPipelineBlocks, |
| ) |
| from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, InsertableDict, OutputParam |
|
|
| from .before_denoise import ( |
| RFDiffusionInputStep, |
| RFDiffusionPrepareLatentsStep, |
| RFDiffusionSetTimestepsStep, |
| ) |
| from .decoders import RFDiffusionDecodeStep |
| from .denoise import RFDiffusionDenoiseStep |
|
|
|
|
# Module-level logger obtained from the diffusers logging wrapper.
logger = logging.get_logger(name=__name__)  # pylint: disable=invalid-name
|
|
|
|
| |
|
|
# Residue vocabulary: three-letter codes paired positionally with their
# one-letter codes. Insertion order defines the integer index of each residue
# (e.g. "GLY" -> 7, the index the MPNN fallback uses), so keep the order.
THREE_TO_ONE = dict(
    zip(
        (
            "ALA", "ARG", "ASN", "ASP", "CYS",
            "GLN", "GLU", "GLY", "HIS", "ILE",
            "LEU", "LYS", "MET", "PHE", "PRO",
            "SER", "THR", "TRP", "TYR", "VAL",
            "UNK",
        ),
        "ARNDCQEGHILKMFPSTWYVX",
    )
)

# Index -> three-letter code; position i is the residue encoded by index i.
AA_NAMES = list(THREE_TO_ONE)
|
|
|
|
| |
| |
| |
|
|
|
|
class RFDiffusionBeforeDenoiseStep(SequentialPipelineBlocks):
    """Run the pre-denoising steps in order: input processing, timesteps, latents."""

    # Names and classes are matched positionally.
    block_names = ["input", "set_timesteps", "prepare_latents"]
    block_classes = [
        RFDiffusionInputStep,
        RFDiffusionSetTimestepsStep,
        RFDiffusionPrepareLatentsStep,
    ]

    @property
    def description(self):
        parts = (
            "Before denoise step that prepares the inputs for the denoise step.\n",
            "This is a sequential pipeline blocks:\n",
            " - `RFDiffusionInputStep` processes contigs and prepares input features\n",
            " - `RFDiffusionSetTimestepsStep` sets up the diffusion timesteps\n",
            " - `RFDiffusionPrepareLatentsStep` initializes noised coordinates\n",
        )
        return "".join(parts)
|
|
|
|
class RFDiffusionAutoBeforeDenoiseStep(AutoPipelineBlocks):
    """Auto-select the before-denoise step.

    Only one candidate is registered: ``RFDiffusionBeforeDenoiseStep``, under
    the name ``"unconditional"`` with a ``None`` trigger.
    """

    block_trigger_inputs = [None]
    block_names = ["unconditional"]
    block_classes = [RFDiffusionBeforeDenoiseStep]

    @property
    def description(self):
        parts = (
            "Before denoise step that prepares the inputs for the denoise step.\n",
            "This is an auto pipeline block for protein structure generation.\n",
            " - `RFDiffusionBeforeDenoiseStep` (unconditional) is used.\n",
        )
        return "".join(parts)
|
|
|
|
class RFDiffusionAutoDenoiseStep(AutoPipelineBlocks):
    """Auto-select the denoise step.

    Only one candidate is registered: ``RFDiffusionDenoiseStep``, under the
    name ``"denoise"`` with a ``None`` trigger (the default block).
    """

    block_classes = [RFDiffusionDenoiseStep]
    block_names = ["denoise"]
    block_trigger_inputs = [None]

    @property
    def description(self) -> str:
        # Newline-separated so the sub-block bullet renders on its own line,
        # in the same layout as the other auto-step descriptions in this module.
        return (
            "Denoise step that iteratively denoises the protein structure.\n"
            "This is an auto pipeline block for protein structure generation.\n"
            " - `RFDiffusionDenoiseStep` (denoise) for structure generation.\n"
        )
|
|
|
|
class RFDiffusionAutoDecodeStep(AutoPipelineBlocks):
    """Auto-select the decode step.

    Only one candidate is registered: ``RFDiffusionDecodeStep``, under the
    name ``"decode"`` with a ``None`` trigger.
    """

    block_trigger_inputs = [None]
    block_names = ["decode"]
    block_classes = [RFDiffusionDecodeStep]

    @property
    def description(self):
        headline = "Decode step that converts denoised coordinates to PDB output."
        return headline + "\n - `RFDiffusionDecodeStep`"
|
|
|
|
| |
| |
| |
|
|
|
|
@dataclass
class MPNNPipelineOutput:
    """Result of a ProteinMPNN / LigandMPNN sequence-design step.

    The first four fields are required; ``pdb_string`` and
    ``sequence_recovery`` default to ``None`` when not produced.
    """

    # One-letter amino-acid string of the designed chain.
    designed_sequence: str
    # Per-residue amino-acid indices (into the module's residue vocabulary).
    sequence_indices: torch.Tensor
    # Per-residue scores over the residue vocabulary.
    sequence_logits: torch.Tensor
    # Coordinates the sequence was designed for.
    xyz: torch.Tensor
    # PDB text, only when PDB output was requested.
    pdb_string: Optional[str] = None
    # Optional recovery score; not filled in by the MPNN step in this module.
    sequence_recovery: Optional[float] = None
|
|
|
|
class MPNNSequenceDesignStep(ModularPipelineBlocks):
    """
    Design sequences for a protein backbone using ProteinMPNN / LigandMPNN.

    Takes ``xyz`` coordinates (typically from an upstream RFDiffusion denoise
    step) and runs the ``MPNNModel`` to produce amino acid sequences for
    the designable regions.

    When no ``mpnn`` component is loaded, falls back to using the sequence
    predictions from upstream RFDiffusion (or glycine everywhere).

    Only the first structure of the batch is turned into ``designed_sequence``
    and (for ``output_type="pdb"``) into a PDB string; the tensors stored in
    the output keep the full batch.
    """

    model_name = "mpnn"

    @property
    def description(self) -> str:
        return (
            "Design amino acid sequences for protein backbones using "
            "ProteinMPNN or LigandMPNN. Accepts structure coordinates "
            "from an upstream diffusion step."
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("mpnn", description="MPNNModel (ProteinMPNN or LigandMPNN)"),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "xyz", required=True, type_hint=torch.Tensor,
                description="Protein backbone coordinates [B, L, 3] (CA atoms)",
            ),
            InputParam(
                "motif_mask", type_hint=torch.Tensor,
                description="Mask for fixed/motif positions [L]. True = fixed sequence.",
            ),
            InputParam(
                "sequence_indices", type_hint=torch.Tensor,
                description="Known sequence indices for motif positions [B, L] (from RFDiffusion)",
            ),
            InputParam(
                "temperature", default=0.1, type_hint=float,
                description="Sampling temperature (lower = more deterministic)",
            ),
            # NOTE(review): declared for API compatibility but never read by
            # __call__, which makes a single MPNN call regardless of its value.
            InputParam(
                "num_designs", default=1, type_hint=int,
                description="Number of sequence designs to generate per structure",
            ),
            # NOTE(review): only 'tensor' and 'pdb' are implemented; 'cif' logs a
            # warning and writes nothing.
            InputParam(
                "output_type", default="tensor", type_hint=str,
                description="'tensor', 'pdb', or 'cif'",
            ),
            InputParam(
                "output_path", type_hint=str,
                description="Path to save designed PDB",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "mpnn_output", type_hint=MPNNPipelineOutput,
                description="MPNN sequence design output",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        """Design a sequence and store an ``MPNNPipelineOutput`` as ``mpnn_output``.

        Returns ``(components, state)``.
        """
        block_state = self.get_block_state(state)

        xyz = block_state.xyz
        batch_size, seq_len, _ = xyz.shape
        device = xyz.device

        # `is None` rather than `or`: an explicit temperature of 0.0 is falsy
        # but legitimate, and must not be silently replaced by the default.
        temperature = block_state.temperature
        if temperature is None:
            temperature = 0.1
        output_type = block_state.output_type or "tensor"
        output_path = block_state.output_path

        # `~` and boolean indexing need a bool mask (an integer mask would be
        # bitwise-negated and used as gather indices); `scatter_` needs int64.
        motif_mask = block_state.motif_mask
        if motif_mask is not None:
            motif_mask = motif_mask.to(device=device, dtype=torch.bool)
        known_seq = block_state.sequence_indices
        if known_seq is not None:
            known_seq = known_seq.to(device=device, dtype=torch.long)

        mpnn = getattr(components, "mpnn", None)
        if mpnn is not None:
            sequence_logits, sequence_indices = self._run_mpnn(
                mpnn, xyz, motif_mask, known_seq, temperature,
            )
        else:
            sequence_logits, sequence_indices = self._fallback_sequence(
                known_seq, batch_size, seq_len, device,
            )

        designed_sequence = "".join(
            THREE_TO_ONE.get(self._residue_name(idx), "X")
            for idx in sequence_indices[0].cpu().tolist()
        )

        pdb_string = None
        if output_type == "pdb":
            pdb_string = self._coords_to_pdb(xyz[0], sequence_indices[0])
            if output_path:
                self._write_text(output_path, pdb_string)
        elif output_type != "tensor":
            logger.warning(
                "Unsupported output_type %r in %s; only 'tensor' and 'pdb' are implemented, "
                "no structure text is produced.",
                output_type,
                type(self).__name__,
            )

        block_state.mpnn_output = MPNNPipelineOutput(
            designed_sequence=designed_sequence,
            sequence_indices=sequence_indices,
            sequence_logits=sequence_logits,
            xyz=xyz,
            pdb_string=pdb_string,
        )
        self.set_block_state(state, block_state)
        return components, state

    @staticmethod
    def _residue_name(idx) -> str:
        """Three-letter name for an index into ``AA_NAMES``; out-of-range (incl. negative) -> ``"UNK"``."""
        idx = int(idx)
        return AA_NAMES[idx] if 0 <= idx < len(AA_NAMES) else "UNK"

    @staticmethod
    def _fallback_sequence(known_seq, batch_size, seq_len, device):
        """Build one-hot ``(logits, indices)`` without an MPNN model.

        Uses ``known_seq`` (int64 ``[B, L]``) when given, otherwise poly-glycine.
        """
        if known_seq is not None:
            indices = known_seq
        else:
            indices = torch.full(
                (batch_size, seq_len), AA_NAMES.index("GLY"), dtype=torch.long, device=device,
            )
        vocab = len(AA_NAMES)
        # Out-of-range indices are one-hot encoded as the last entry ("UNK")
        # instead of making `scatter_` raise.
        in_range = (indices >= 0) & (indices < vocab)
        safe = torch.where(in_range, indices, torch.full_like(indices, vocab - 1))
        logits = torch.zeros(batch_size, seq_len, vocab, device=device)
        logits.scatter_(2, safe.unsqueeze(-1), 1.0)
        return logits, indices

    def _run_mpnn(self, mpnn, xyz, motif_mask, known_seq, temperature):
        """Run the MPNNModel wrapper on backbone coordinates.

        Builds an approximate N/CA/C/O backbone from the CA positions, calls
        ``mpnn`` and re-imposes ``known_seq`` at motif positions.

        Returns:
            ``(sequence_logits, sequence_indices)`` taken from the ``mpnn`` output.
        """
        batch_size = xyz.shape[0]
        device, dtype = xyz.device, xyz.dtype

        # Placeholder backbone: the same fixed offsets are added to every CA, so
        # all residues share one orientation; real N/C/O positions are not
        # reconstructed.
        n_offset = torch.tensor([-1.458, 0.0, 0.0], device=device, dtype=dtype)
        c_offset = torch.tensor([0.550, 1.424, 0.0], device=device, dtype=dtype)
        o_offset = torch.tensor([0.550, 2.500, 0.0], device=device, dtype=dtype)
        X = torch.stack([xyz + n_offset, xyz, xyz + c_offset, xyz + o_offset], dim=2)

        designed_mask = None
        if motif_mask is not None:
            motif_mask = motif_mask.to(device=device, dtype=torch.bool)
            # Residues outside the motif are passed as designable, per batch row.
            designed_mask = ~motif_mask.unsqueeze(0).expand(batch_size, -1)

        output = mpnn(
            X=X, S=known_seq, designed_residue_mask=designed_mask, temperature=temperature,
        )
        logits = output.sequence_logits
        indices = output.sequence_indices

        if motif_mask is not None and known_seq is not None:
            # Keep the known residues at fixed positions; cast so index_put_ does
            # not fail on a dtype mismatch with the model's indices.
            indices[:, motif_mask] = known_seq[:, motif_mask].to(
                device=indices.device, dtype=indices.dtype,
            )
        return logits, indices

    def _coords_to_pdb(self, xyz: torch.Tensor, seq: torch.Tensor) -> str:
        """Render CA-only coordinates as fixed-column PDB ``ATOM`` records on chain A.

        Args:
            xyz: per-residue CA coordinates, ``[L, 3]``.
            seq: per-residue indices into ``AA_NAMES``, ``[L]``.

        Returns:
            PDB text: one ``ATOM`` line per residue, then ``END`` (no trailing newline).
        """
        coords = xyz.detach().cpu().tolist()
        residues = seq.detach().cpu().tolist()
        lines = []
        for serial, ((x, y, z), aa_idx) in enumerate(zip(coords, residues), start=1):
            res_name = self._residue_name(aa_idx)
            # wwPDB ATOM columns: 1-6 record, 7-11 serial, 13-16 atom name,
            # 18-20 residue name, 22 chain, 23-26 residue number, 31-54 x/y/z,
            # 55-60 occupancy, 61-66 B-factor, 77-78 element.
            lines.append(
                f"ATOM  {serial:5d}  CA  {res_name:3s} A{serial:4d}    "
                f"{x:8.3f}{y:8.3f}{z:8.3f}{1.0:6.2f}{0.0:6.2f}           C"
            )
        lines.append("END")
        return "\n".join(lines)

    @staticmethod
    def _write_text(path: str, text: str) -> None:
        """Write ``text`` to ``path``, creating the parent directory if needed."""
        import os

        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        with open(path, "w", encoding="utf-8") as handle:
            handle.write(text)
|
|
|
|
class MPNNAutoDesignStep(AutoPipelineBlocks):
    """Auto-select the MPNN sequence-design step.

    Only one candidate is registered: ``MPNNSequenceDesignStep``, under the
    name ``"sequence_design"`` with a ``None`` trigger.
    """

    block_trigger_inputs = [None]
    block_names = ["sequence_design"]
    block_classes = [MPNNSequenceDesignStep]

    @property
    def description(self) -> str:
        summary = "Sequence design using ProteinMPNN or LigandMPNN."
        return summary
|
|
|
|
| |
| |
| |
|
|
|
|
class RFDiffusionAutoBlocks(SequentialPipelineBlocks):
    """
    Full protein design pipeline: RFDiffusion3 + optional ProteinMPNN/LigandMPNN.

    Runs four auto blocks in order: before-denoise, denoise, decode and
    sequence design.

    ``_workflow_map`` names three workflows by the inputs that identify them:
      - ``contigs`` only -> structure generation
      - ``contigs`` + ``temperature`` -> structure + sequence design
      - ``contigs`` + ``input_xyz`` + ``temperature`` -> motif-conditioned + sequence design

    The sequence-design block (``MPNNAutoDesignStep``) is registered with a
    ``None`` trigger, so it runs in every workflow; ``temperature`` only tunes
    MPNN sampling and defaults to 0.1 inside that step. When no ``mpnn``
    component is loaded, the step is not skipped: it falls back to the
    upstream RFDiffusion sequence (or poly-glycine).

    NOTE(review): no block defined in this module declares ``input_xyz`` as
    an input -- confirm the before-denoise steps consume it.
    """

    block_classes = [
        RFDiffusionAutoBeforeDenoiseStep,
        RFDiffusionAutoDenoiseStep,
        RFDiffusionAutoDecodeStep,
        MPNNAutoDesignStep,
    ]
    # NOTE(review): AUTO_BLOCKS keys the decode step "decode"; "decoder" is kept
    # here because callers may look sub-blocks up by this name.
    block_names = [
        "before_denoise",
        "denoise",
        "decoder",
        "sequence_design",
    ]

    # Workflow presets keyed by the inputs that identify them (presumably read
    # by the diffusers workflow machinery -- confirm).
    _workflow_map = {
        "structure_only": {
            "contigs": True,
        },
        "structure_and_sequence": {
            "contigs": True,
            "temperature": True,
        },
        "motif_structure_and_sequence": {
            "contigs": True,
            "input_xyz": True,
            "temperature": True,
        },
    }

    @property
    def description(self):
        return (
            "Modular pipeline for protein design using RFDiffusion3.\n"
            "Workflows:\n"
            " - structure_only: backbone generation\n"
            " - structure_and_sequence: backbone + MPNN sequence design\n"
            " - motif_structure_and_sequence: motif-conditioned + MPNN\n"
        )
|
|
|
|
| |
| |
| |
|
|
|
|
# Flat registry: every step as its own concrete block, in execution order.
UNCONDITIONAL_BLOCKS = InsertableDict(
    (
        ("input", RFDiffusionInputStep),
        ("set_timesteps", RFDiffusionSetTimestepsStep),
        ("prepare_latents", RFDiffusionPrepareLatentsStep),
        ("denoise", RFDiffusionDenoiseStep),
        ("decode", RFDiffusionDecodeStep),
        ("sequence_design", MPNNSequenceDesignStep),
    )
)

# Auto registry: the same pipeline expressed through the auto wrapper blocks.
# NOTE(review): the decode step is keyed "decode" here but "decoder" in
# RFDiffusionAutoBlocks.block_names.
AUTO_BLOCKS = InsertableDict(
    (
        ("before_denoise", RFDiffusionAutoBeforeDenoiseStep),
        ("denoise", RFDiffusionAutoDenoiseStep),
        ("decode", RFDiffusionAutoDecodeStep),
        ("sequence_design", MPNNAutoDesignStep),
    )
)

# Registries by preset name.
ALL_BLOCKS = dict(unconditional=UNCONDITIONAL_BLOCKS, auto=AUTO_BLOCKS)
|
|