# RFDiffusion-3 / before_denoise.py
# Copyright 2025 Dhruv Nair. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple, Union
import torch
from diffusers.utils import logging
from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
logger = logging.get_logger(__name__)
def parse_contig_string(contig_str: str) -> Tuple[int, List[Tuple[int, int]]]:
    """
    Parse a contig specification string into a total length and motif ranges.

    Segments are separated by "/". Supported segment formats:
    - "100"     -> 100 residues to design
    - "50-100"  -> a designed region whose length is the midpoint of the
                   range, i.e. (50 + 100) // 2 = 75 residues (resolved
                   deterministically, not sampled randomly)
    - "A10-25"  -> motif taken from chain A, residues 10-25 (inclusive)

    Args:
        contig_str: Contig specification, e.g. "A10-25/50".

    Returns:
        total_length: Total protein length.
        motif_ranges: List of (start, end) half-open, 0-indexed spans
            marking motif (fixed) residues within the designed protein.
    """
    total_length = 0
    motif_ranges = []
    for part in contig_str.split("/"):
        part = part.strip()
        if not part:
            continue
        if part[0].isalpha():
            # Motif segment: the leading chain letter is only a label here;
            # just the residue span determines the motif length.
            residue_spec = part[1:]
            if "-" in residue_spec:
                start, end = map(int, residue_spec.split("-"))
            else:
                start = end = int(residue_spec)
            motif_len = end - start + 1  # inclusive residue numbering
            motif_ranges.append((total_length, total_length + motif_len))
            total_length += motif_len
        else:
            # Designed segment: fixed length, or the midpoint of a range.
            if "-" in part:
                min_len, max_len = map(int, part.split("-"))
                add_len = (min_len + max_len) // 2
            else:
                add_len = int(part)
            total_length += add_len
    return total_length, motif_ranges
class RFDiffusionInputStep(ModularPipelineBlocks):
    """
    Input processing step for RFDiffusion.

    Turns a contig specification into the total protein length, a boolean
    motif mask, and (optionally) motif coordinates for downstream blocks.
    """

    model_name = "rfdiffusion"

    @property
    def description(self) -> str:
        return (
            "Input processing step that:\n"
            " 1. Parses contig specification to determine protein length and design regions\n"
            " 2. Generates masks for motif positions\n"
        )

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "contigs",
                required=True,
                type_hint=Union[str, List[str]],
                description="Contig specification defining design regions (e.g., '100' or 'A10-25/50-100')",
            ),
            InputParam(
                "input_xyz",
                type_hint=torch.Tensor,
                description="Input coordinates for motif residues [N_motif, 3]",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "motif_mask",
                type_hint=torch.Tensor,
                description="Boolean mask for motif (fixed) positions",
            ),
            OutputParam(
                "motif_xyz",
                type_hint=torch.Tensor,
                description="Coordinates for motif residues",
            ),
            OutputParam(
                "L",
                type_hint=int,
                description="Total length of the protein being designed",
            ),
            OutputParam(
                "batch_size",
                type_hint=int,
                description="Batch size (typically 1 for RFDiffusion)",
            ),
            OutputParam(
                "dtype",
                type_hint=torch.dtype,
                description="Data type for tensors",
            ),
        ]

    def check_inputs(self, components, block_state):
        # Contigs are mandatory — without them there is no design target.
        if block_state.contigs is None:
            raise ValueError("`contigs` must be provided to specify protein design regions")

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        self.check_inputs(components, block_state)

        # A list of contig segments is equivalent to a single "/"-joined string.
        raw_contigs = block_state.contigs
        contig_str = "/".join(raw_contigs) if isinstance(raw_contigs, list) else raw_contigs

        total_len, motif_ranges = parse_contig_string(contig_str)

        # Flag every motif (fixed) residue position within the design.
        motif_mask = torch.zeros(total_len, dtype=torch.bool)
        for lo, hi in motif_ranges:
            motif_mask[lo:hi] = True

        block_state.motif_mask = motif_mask
        # None when the caller supplied no motif coordinates.
        block_state.motif_xyz = block_state.input_xyz
        block_state.L = total_len
        block_state.batch_size = 1
        block_state.dtype = torch.float32

        self.set_block_state(state, block_state)
        return components, state
class RFDiffusionSetTimestepsStep(ModularPipelineBlocks):
    """
    Set up the EDM noise schedule for RFDiffusion3.

    Uses the attached scheduler's schedule when available; otherwise falls
    back to a linear schedule whose length honors `num_inference_steps`.
    """

    model_name = "rfdiffusion"

    @property
    def description(self) -> str:
        return "Sets up the EDM noise schedule matching the original inference sampler."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", description="RFDiffusion3 EDM scheduler"),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "num_inference_steps",
                default=None,
                type_hint=int,
                description="Number of denoising steps (default: use scheduler config)",
            ),
            # NOTE(review): `L` is declared but not read by this block; kept
            # for interface compatibility with pipeline composition.
            InputParam("L", required=True, type_hint=int, description="Protein length"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "noise_schedule",
                type_hint=torch.Tensor,
                description="EDM noise schedule [num_timesteps] from high to low noise",
            ),
            OutputParam(
                "num_inference_steps",
                type_hint=int,
                description="Number of inference steps",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        requested_steps = getattr(block_state, "num_inference_steps", None)
        if hasattr(components, "scheduler") and components.scheduler is not None:
            # NOTE(review): get_noise_schedule() is called without a step
            # count, so a user-supplied `num_inference_steps` currently only
            # affects the fallback path — confirm whether the scheduler API
            # accepts a step count and forward it here if so.
            noise_schedule = components.scheduler.get_noise_schedule()
        else:
            # Fallback: simple linear schedule from high to low noise.
            # Previously the step count was hard-coded to 200, silently
            # ignoring the declared `num_inference_steps` input.
            num_steps = requested_steps if requested_steps is not None else 200
            noise_schedule = torch.linspace(160.0 * 16.0, 4e-4 * 16.0, num_steps)
        block_state.noise_schedule = noise_schedule
        # Report the actual schedule length so downstream loops iterate it fully.
        block_state.num_inference_steps = len(noise_schedule)
        self.set_block_state(state, block_state)
        return components, state
class RFDiffusionPrepareLatentsStep(ModularPipelineBlocks):
    """
    Prepare initial noised coordinates for RFDiffusion3.

    Mirrors the original _get_initial_structure:
        noise = c0 * randn(D, L, 3)
        noise[..., is_motif, :] = 0
        X_L = noise + coord_motif
    """

    model_name = "rfdiffusion"

    @property
    def description(self) -> str:
        return (
            "Prepares initial coordinates by sampling Gaussian noise scaled by "
            "the first noise schedule value, matching the original sampler."
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", description="RFDiffusion3 EDM scheduler"),
            ComponentSpec("transformer", description="RFDiffusion transformer model"),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("generator", type_hint=torch.Generator, description="Random generator for reproducibility"),
            InputParam("diffusion_batch_size", default=1, type_hint=int, description="Number of samples to generate in parallel"),
            InputParam("L", required=True, type_hint=int, description="Protein length"),
            InputParam("motif_mask", required=True, type_hint=torch.Tensor),
            InputParam("motif_xyz", type_hint=torch.Tensor),
            InputParam("noise_schedule", required=True, type_hint=torch.Tensor),
            InputParam("dtype", type_hint=torch.dtype),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("xyz", type_hint=torch.Tensor, description="Initial noised coordinates [D, L, 3]"),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        seq_len = block_state.L
        fixed_mask = block_state.motif_mask
        fixed_xyz = block_state.motif_xyz
        schedule = block_state.noise_schedule
        dtype = block_state.dtype or torch.float32
        generator = block_state.generator
        n_samples = block_state.diffusion_batch_size or 1

        # NOTE(review): device is chosen locally rather than read from the
        # pipeline's components — assumes CUDA when available.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Gaussian noise scaled by the first (highest) noise level c0,
        # matching the original sampler: noise = c0 * randn(D, L, 3).
        highest_sigma = schedule[0]
        noise = highest_sigma * torch.randn(
            (n_samples, seq_len, 3), dtype=dtype, device=device, generator=generator
        )

        # Motif residues stay noise-free so their coordinates are kept fixed.
        if fixed_mask is not None:
            noise[:, fixed_mask] = 0.0

        # Scatter the supplied motif coordinates into an otherwise-zero frame.
        coord_motif = torch.zeros((n_samples, seq_len, 3), dtype=dtype, device=device)
        if fixed_xyz is not None and fixed_mask is not None:
            fixed_positions = fixed_mask.nonzero(as_tuple=True)[0]
            n_available = fixed_xyz.shape[0]
            for row, pos in enumerate(fixed_positions):
                if row >= n_available:
                    # Fewer coordinate rows than motif positions: leave the
                    # remainder at zero, as the original loop did.
                    break
                coord_motif[:, pos] = fixed_xyz[row].to(dtype=dtype, device=device)

        block_state.xyz = noise + coord_motif
        self.set_block_state(state, block_state)
        return components, state