# Copyright 2025 Dhruv Nair. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple, Union

import torch

from diffusers.utils import logging
from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam


logger = logging.get_logger(__name__)


def parse_contig_string(contig_str: str) -> Tuple[int, List[Tuple[int, int]]]:
    """Parse a contig specification string.

    Supported formats (segments separated by "/"):
        - "100"       -> 100 designed residues
        - "50-100"    -> a length range; the deterministic midpoint
                         ((min + max) // 2) is used, not a random draw
        - "A10-25/50" -> motif from chain A residues 10-25, plus 50 designed

    NOTE: the chain identifier (e.g. "A") is parsed but not currently used;
    only the residue span determines the motif length.

    Args:
        contig_str: Contig specification string.

    Returns:
        total_length: Total protein length.
        motif_ranges: List of (start, end) half-open ranges (0-indexed into the
            designed protein) occupied by motif residues.
    """
    parts = contig_str.split("/")
    total_length = 0
    motif_ranges = []

    for part in parts:
        part = part.strip()
        if not part:
            continue

        if part[0].isalpha():
            # Motif segment, e.g. "A10-25": chain letter followed by a residue span.
            residue_spec = part[1:]
            if "-" in residue_spec:
                start, end = map(int, residue_spec.split("-"))
            else:
                # Single residue, e.g. "A10".
                start = end = int(residue_spec)
            motif_len = end - start + 1
            # Record the half-open range this motif occupies in the output protein.
            motif_ranges.append((total_length, total_length + motif_len))
            total_length += motif_len
        else:
            # Designed (free) segment, e.g. "50" or "50-100".
            if "-" in part:
                min_len, max_len = map(int, part.split("-"))
                # Deterministic midpoint of the requested range.
                add_len = (min_len + max_len) // 2
            else:
                add_len = int(part)
            total_length += add_len

    return total_length, motif_ranges


class RFDiffusionInputStep(ModularPipelineBlocks):
    """
    Input processing step for RFDiffusion. Parses contigs to prepare features
    for structure generation.
    """

    model_name = "rfdiffusion"

    @property
    def description(self) -> str:
        return (
            "Input processing step that:\n"
            " 1. Parses contig specification to determine protein length and design regions\n"
            " 2. Generates masks for motif positions\n"
        )

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "contigs",
                required=True,
                type_hint=Union[str, List[str]],
                description="Contig specification defining design regions (e.g., '100' or 'A10-25/50-100')",
            ),
            InputParam(
                "input_xyz",
                type_hint=torch.Tensor,
                description="Input coordinates for motif residues [N_motif, 3]",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "motif_mask",
                type_hint=torch.Tensor,
                description="Boolean mask for motif (fixed) positions",
            ),
            OutputParam(
                "motif_xyz",
                type_hint=torch.Tensor,
                description="Coordinates for motif residues",
            ),
            OutputParam(
                "L",
                type_hint=int,
                description="Total length of the protein being designed",
            ),
            OutputParam(
                "batch_size",
                type_hint=int,
                description="Batch size (typically 1 for RFDiffusion)",
            ),
            OutputParam(
                "dtype",
                type_hint=torch.dtype,
                description="Data type for tensors",
            ),
        ]

    def check_inputs(self, components, block_state):
        # `contigs` drives everything downstream; fail fast if it is missing.
        if block_state.contigs is None:
            raise ValueError("`contigs` must be provided to specify protein design regions")

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        self.check_inputs(components, block_state)

        contigs = block_state.contigs
        input_xyz = block_state.input_xyz

        # A list of contig segments is equivalent to a single "/"-joined string.
        if isinstance(contigs, list):
            contig_str = "/".join(contigs)
        else:
            contig_str = contigs

        L, motif_ranges = parse_contig_string(contig_str)

        # Mark motif (fixed) positions; everything else is designed.
        motif_mask = torch.zeros(L, dtype=torch.bool)
        for start, end in motif_ranges:
            motif_mask[start:end] = True

        # Motif coordinates pass through unchanged when provided.
        motif_xyz = input_xyz if input_xyz is not None else None

        block_state.motif_mask = motif_mask
        block_state.motif_xyz = motif_xyz
        block_state.L = L
        block_state.batch_size = 1
        block_state.dtype = torch.float32

        self.set_block_state(state, block_state)
        return components, state


class RFDiffusionSetTimestepsStep(ModularPipelineBlocks):
    """
    Set up the EDM noise schedule for RFDiffusion3.
    """

    model_name = "rfdiffusion"

    @property
    def description(self) -> str:
        return "Sets up the EDM noise schedule matching the original inference sampler."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", description="RFDiffusion3 EDM scheduler"),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "num_inference_steps",
                default=None,
                type_hint=int,
                description="Number of denoising steps (default: use scheduler config)",
            ),
            InputParam("L", required=True, type_hint=int, description="Protein length"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "noise_schedule",
                type_hint=torch.Tensor,
                description="EDM noise schedule [num_timesteps] from high to low noise",
            ),
            OutputParam(
                "num_inference_steps",
                type_hint=int,
                description="Number of inference steps",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        if hasattr(components, "scheduler") and components.scheduler is not None:
            # Scheduler owns the schedule; its length defines num_inference_steps.
            noise_schedule = components.scheduler.get_noise_schedule()
        else:
            # Fallback: simple linear schedule from high to low noise.
            # Honor a user-supplied `num_inference_steps`; default to 200 steps.
            num_steps = block_state.num_inference_steps or 200
            noise_schedule = torch.linspace(160.0 * 16.0, 4e-4 * 16.0, num_steps)

        block_state.noise_schedule = noise_schedule
        # Report the actual schedule length, which may differ from the request.
        block_state.num_inference_steps = len(noise_schedule)

        self.set_block_state(state, block_state)
        return components, state


class RFDiffusionPrepareLatentsStep(ModularPipelineBlocks):
    """
    Prepare initial noised coordinates for RFDiffusion3.

    Matches the original _get_initial_structure:
        noise = c0 * randn(D, L, 3)
        noise[..., is_motif, :] = 0
        X_L = noise + coord_motif
    """

    model_name = "rfdiffusion"

    @property
    def description(self) -> str:
        return (
            "Prepares initial coordinates by sampling Gaussian noise scaled by "
            "the first noise schedule value, matching the original sampler."
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", description="RFDiffusion3 EDM scheduler"),
            ComponentSpec("transformer", description="RFDiffusion transformer model"),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("generator", type_hint=torch.Generator, description="Random generator for reproducibility"),
            InputParam(
                "diffusion_batch_size",
                default=1,
                type_hint=int,
                description="Number of samples to generate in parallel",
            ),
            InputParam("L", required=True, type_hint=int, description="Protein length"),
            InputParam("motif_mask", required=True, type_hint=torch.Tensor),
            InputParam("motif_xyz", type_hint=torch.Tensor),
            InputParam("noise_schedule", required=True, type_hint=torch.Tensor),
            InputParam("dtype", type_hint=torch.dtype),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("xyz", type_hint=torch.Tensor, description="Initial noised coordinates [D, L, 3]"),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        L = block_state.L
        motif_mask = block_state.motif_mask
        motif_xyz = block_state.motif_xyz
        noise_schedule = block_state.noise_schedule
        dtype = block_state.dtype or torch.float32
        generator = block_state.generator
        D = block_state.diffusion_batch_size or 1

        # NOTE(review): device is chosen here rather than read from a component;
        # presumably the transformer lives on the same device — confirm.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Initial noise scaled by first noise level (c0), matching original:
        #   noise = c0 * randn(D, L, 3)
        c0 = noise_schedule[0]
        noise = c0 * torch.randn((D, L, 3), dtype=dtype, device=device, generator=generator)

        # Motif atoms start exactly at their given coordinates: zero their noise.
        if motif_mask is not None:
            noise[:, motif_mask] = 0.0

        # Build initial coordinates: motif coords + noise.
        coord_motif = torch.zeros((D, L, 3), dtype=dtype, device=device)
        if motif_xyz is not None and motif_mask is not None:
            motif_indices = motif_mask.nonzero(as_tuple=True)[0]
            # Copy at most as many coordinates as were provided; extra motif
            # positions (if any) keep zeros, matching the original per-index loop.
            n = min(motif_indices.shape[0], motif_xyz.shape[0])
            if n > 0:
                coord_motif[:, motif_indices[:n]] = motif_xyz[:n].to(dtype=dtype, device=device)

        xyz = noise + coord_motif

        block_state.xyz = xyz
        self.set_block_state(state, block_state)
        return components, state