| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import List, Tuple, Union |
|
|
| import torch |
|
|
| from diffusers.utils import logging |
| from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState |
| from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam |
|
|
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
def parse_contig_string(contig_str: str) -> Tuple[int, List[Tuple[int, int]]]:
    """
    Parse a contig specification string.

    Supports formats like:
    - "100"         -> 100 residues to design
    - "50-100"      -> length range; the midpoint ((min + max) // 2) is used
    - "A10-25/50"   -> motif from chain A residues 10-25, plus 50 designed residues

    Segments are separated by "/". A segment starting with a letter is a motif
    taken from an existing chain (the chain letter itself only tags the segment;
    it is not returned); a purely numeric segment adds designed residues.

    Args:
        contig_str: Contig specification, e.g. "A10-25/50".

    Returns:
        total_length: Total protein length (motif + designed residues).
        motif_ranges: List of half-open (start, end) index ranges for motif
            residues in the assembled protein (0-indexed, suitable for
            ``mask[start:end]`` slicing).
    """
    total_length = 0
    motif_ranges: List[Tuple[int, int]] = []

    for part in contig_str.split("/"):
        part = part.strip()
        if not part:
            continue

        if part[0].isalpha():
            # Motif segment: a chain letter followed by a residue index or range.
            # The chain letter is consumed but not used for length computation.
            residue_spec = part[1:]
            if "-" in residue_spec:
                start, end = map(int, residue_spec.split("-"))
            else:
                start = end = int(residue_spec)
            motif_len = end - start + 1
            motif_ranges.append((total_length, total_length + motif_len))
            total_length += motif_len
        else:
            # Designed segment: fixed length, or a range collapsed to its midpoint.
            if "-" in part:
                min_len, max_len = map(int, part.split("-"))
                add_len = (min_len + max_len) // 2
            else:
                add_len = int(part)
            total_length += add_len

    return total_length, motif_ranges
|
|
|
|
class RFDiffusionInputStep(ModularPipelineBlocks):
    """
    Input processing step for RFDiffusion.

    Parses contigs to prepare features for structure generation.
    """

    model_name = "rfdiffusion"

    @property
    def description(self) -> str:
        return (
            "Input processing step that:\n"
            " 1. Parses contig specification to determine protein length and design regions\n"
            " 2. Generates masks for motif positions\n"
        )

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "contigs",
                required=True,
                type_hint=Union[str, List[str]],
                description="Contig specification defining design regions (e.g., '100' or 'A10-25/50-100')",
            ),
            InputParam(
                "input_xyz",
                type_hint=torch.Tensor,
                description="Input coordinates for motif residues [N_motif, 3]",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "motif_mask",
                type_hint=torch.Tensor,
                description="Boolean mask for motif (fixed) positions",
            ),
            OutputParam(
                "motif_xyz",
                type_hint=torch.Tensor,
                description="Coordinates for motif residues",
            ),
            OutputParam(
                "L",
                type_hint=int,
                description="Total length of the protein being designed",
            ),
            OutputParam(
                "batch_size",
                type_hint=int,
                description="Batch size (typically 1 for RFDiffusion)",
            ),
            OutputParam(
                "dtype",
                type_hint=torch.dtype,
                description="Data type for tensors",
            ),
        ]

    def check_inputs(self, components, block_state):
        # A contig string is mandatory: without it there is nothing to design.
        if block_state.contigs is None:
            raise ValueError("`contigs` must be provided to specify protein design regions")

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        self.check_inputs(components, block_state)

        # A list of contig segments is joined into the canonical "/"-separated form.
        raw_contigs = block_state.contigs
        contig_spec = "/".join(raw_contigs) if isinstance(raw_contigs, list) else raw_contigs

        total_len, motif_ranges = parse_contig_string(contig_spec)

        # Mark every motif (fixed) position; ranges are half-open [lo, hi).
        mask = torch.zeros(total_len, dtype=torch.bool)
        for lo, hi in motif_ranges:
            mask[lo:hi] = True

        block_state.motif_mask = mask
        # Pass the motif coordinates through unchanged (None when not provided).
        block_state.motif_xyz = block_state.input_xyz
        block_state.L = total_len
        block_state.batch_size = 1
        block_state.dtype = torch.float32

        self.set_block_state(state, block_state)
        return components, state
|
|
|
|
class RFDiffusionSetTimestepsStep(ModularPipelineBlocks):
    """
    Set up the EDM noise schedule for RFDiffusion3.
    """

    model_name = "rfdiffusion"

    @property
    def description(self) -> str:
        return "Sets up the EDM noise schedule matching the original inference sampler."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", description="RFDiffusion3 EDM scheduler"),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "num_inference_steps",
                default=None,
                type_hint=int,
                description="Number of denoising steps (default: use scheduler config)",
            ),
            # NOTE(review): `L` is declared as a required input but is not read in
            # __call__ — presumably it enforces ordering after the input step; confirm.
            InputParam("L", required=True, type_hint=int, description="Protein length"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "noise_schedule",
                type_hint=torch.Tensor,
                description="EDM noise schedule [num_timesteps] from high to low noise",
            ),
            OutputParam(
                "num_inference_steps",
                type_hint=int,
                description="Number of inference steps",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        """Build the noise schedule from the scheduler, or fall back to a linear ramp.

        Fix: the declared `num_inference_steps` input was previously ignored; the
        fallback schedule now honors it (default 200, matching the old behavior).
        """
        block_state = self.get_block_state(state)

        scheduler = getattr(components, "scheduler", None)
        if scheduler is not None:
            noise_schedule = scheduler.get_noise_schedule()
        else:
            # Fallback linear schedule from high noise (160 * 16) down to
            # near-zero (4e-4 * 16), using the requested step count when given.
            num_steps = block_state.num_inference_steps or 200
            noise_schedule = torch.linspace(160.0 * 16.0, 4e-4 * 16.0, num_steps)

        block_state.noise_schedule = noise_schedule
        # Always derive the step count from the schedule actually produced.
        block_state.num_inference_steps = len(noise_schedule)

        self.set_block_state(state, block_state)
        return components, state
|
|
|
|
class RFDiffusionPrepareLatentsStep(ModularPipelineBlocks):
    """
    Prepare initial noised coordinates for RFDiffusion3.

    Matches the original _get_initial_structure:
        noise = c0 * randn(D, L, 3)
        noise[..., is_motif, :] = 0
        X_L = noise + coord_motif
    """

    model_name = "rfdiffusion"

    @property
    def description(self) -> str:
        return (
            "Prepares initial coordinates by sampling Gaussian noise scaled by "
            "the first noise schedule value, matching the original sampler."
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", description="RFDiffusion3 EDM scheduler"),
            ComponentSpec("transformer", description="RFDiffusion transformer model"),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("generator", type_hint=torch.Generator, description="Random generator for reproducibility"),
            InputParam("diffusion_batch_size", default=1, type_hint=int, description="Number of samples to generate in parallel"),
            InputParam("L", required=True, type_hint=int, description="Protein length"),
            InputParam("motif_mask", required=True, type_hint=torch.Tensor),
            InputParam("motif_xyz", type_hint=torch.Tensor),
            InputParam("noise_schedule", required=True, type_hint=torch.Tensor),
            InputParam("dtype", type_hint=torch.dtype),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("xyz", type_hint=torch.Tensor, description="Initial noised coordinates [D, L, 3]"),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        """Sample the initial structure: scaled Gaussian noise plus fixed motif coords.

        Produces `xyz` of shape [D, L, 3], where noise is zeroed at motif
        positions so the motif coordinates are kept exactly at t=0.
        """
        block_state = self.get_block_state(state)

        L = block_state.L
        motif_mask = block_state.motif_mask
        motif_xyz = block_state.motif_xyz
        noise_schedule = block_state.noise_schedule
        dtype = block_state.dtype or torch.float32
        generator = block_state.generator
        D = block_state.diffusion_batch_size or 1

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Scale N(0, I) by the first (largest) noise level of the schedule.
        # NOTE(review): a CPU `generator` combined with a CUDA `device` will raise
        # in torch.randn — callers presumably pass a matching generator; confirm.
        c0 = noise_schedule[0]
        noise = c0 * torch.randn((D, L, 3), dtype=dtype, device=device, generator=generator)

        # Motif positions receive no noise so they stay at their fixed coordinates.
        if motif_mask is not None:
            noise[:, motif_mask] = 0.0

        # Scatter the provided motif coordinates into a full-length [D, L, 3] tensor.
        coord_motif = torch.zeros((D, L, 3), dtype=dtype, device=device)
        if motif_xyz is not None and motif_mask is not None:
            motif_indices = motif_mask.nonzero(as_tuple=True)[0]
            n_mask = motif_indices.numel()
            n_xyz = motif_xyz.shape[0]
            if n_xyz != n_mask:
                # Previously this mismatch was silently truncated; keep the
                # truncating behavior but make it visible.
                logger.warning(
                    "motif_xyz has %d coordinates but motif_mask marks %d positions; "
                    "using the first %d.",
                    n_xyz,
                    n_mask,
                    min(n_xyz, n_mask),
                )
            n = min(n_xyz, n_mask)
            if n > 0:
                # Vectorized replacement of the original per-index Python loop:
                # broadcasts [n, 3] across the batch dimension D.
                coord_motif[:, motif_indices[:n]] = motif_xyz[:n].to(dtype=dtype, device=device)

        xyz = noise + coord_motif

        block_state.xyz = xyz

        self.set_block_state(state, block_state)
        return components, state
|
|