RFDiffusion-3 / denoise.py
dn6's picture
dn6 HF Staff
Upload folder using huggingface_hub
4900749 verified
# Copyright 2025 Dhruv Nair. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Denoising loop for RFDiffusion3.
Implements the iterative denoising procedure from the original
inference_sampler.py (SampleDiffusionWithMotif.sample_diffusion_like_af3).
The loop iterates over consecutive pairs (c_t_minus_1, c_t) in the noise schedule:
1. Inject stochastic noise: t_hat = c_t_minus_1 * (gamma + 1), X_noisy = X + epsilon with epsilon ~ N(0, sigma^2 I) and standard deviation sigma = noise_scale * sqrt(t_hat^2 - c_t_minus_1^2)
2. Call model: X_denoised = model(X_noisy, t_hat)
3. Euler update: X = X_noisy + step_scale * (c_t - t_hat) * (X_noisy - X_denoised) / t_hat
"""
from typing import Callable, List, Tuple, Union

import torch
from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from diffusers.utils import logging
logger = logging.get_logger(__name__)
class RFDiffusionDenoiseStep(ModularPipelineBlocks):
    """
    Iterative denoising step for RFDiffusion3.

    Implements the EDM stochastic sampling loop matching the original
    SampleDiffusionWithMotif.sample_diffusion_like_af3. For every consecutive
    pair ``(c_t_minus_1, c_t)`` in ``noise_schedule`` the block:

    1. injects noise with ``scheduler.add_noise`` (which also returns ``t_hat``),
    2. predicts denoised coordinates with ``transformer``,
    3. advances the coordinates with ``scheduler.step``.

    If the ``scheduler`` component is missing, no noise is injected
    (``t_hat = c_t_minus_1``) and a plain Euler step without step scaling is
    used. If the ``transformer`` component is missing, the prediction is the
    noisy input itself.
    """

    model_name = "rfdiffusion"

    @property
    def description(self) -> str:
        return (
            "Iteratively denoise protein structure through reverse diffusion. "
            "Uses EDM stochastic sampling with gamma noise injection and step scaling."
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("transformer", description="RFDiffusion transformer for structure prediction"),
            ComponentSpec("scheduler", description="Scheduler for noise injection and stepping"),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "n_recycle",
                default=None,
                type_hint=int,
                description="Number of recycling iterations (None uses model default)",
            ),
            InputParam(
                "callback",
                type_hint=Callable,
                description="Optional callback function called at each step",
            ),
            InputParam(
                "callback_steps",
                default=1,
                type_hint=int,
                description="Frequency of callback invocation",
            ),
            InputParam("xyz", required=True, type_hint=torch.Tensor, description="Initial noised coordinates [D, L, 3]"),
            InputParam("noise_schedule", required=True, type_hint=torch.Tensor, description="EDM noise schedule"),
            InputParam("motif_mask", required=True, type_hint=torch.Tensor, description="Mask for fixed motif positions"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("xyz", type_hint=torch.Tensor, description="Denoised coordinates [D, L, 3]"),
            OutputParam("single", type_hint=torch.Tensor, description="Single representation"),
            OutputParam("pair", type_hint=torch.Tensor, description="Pair representation"),
            OutputParam("sequence_logits", type_hint=torch.Tensor, description="Predicted sequence logits"),
            OutputParam("sequence_indices", type_hint=torch.Tensor, description="Predicted sequence indices"),
            OutputParam("trajectory", type_hint=List[torch.Tensor], description="Denoising trajectory"),
        ]

    @staticmethod
    def _broadcast_timestep(
        t_hat: Union[torch.Tensor, float], batch_size: int, device: torch.device
    ) -> torch.Tensor:
        """Tile the noise level ``t_hat`` to a ``[batch_size]`` tensor on ``device``.

        A tensor ``t_hat`` is moved to ``device`` and expanded (so it is expected
        to be a scalar / single-element tensor, or already ``[batch_size]``);
        a Python number is materialised with ``torch.full``.
        """
        if isinstance(t_hat, torch.Tensor):
            return t_hat.to(device).expand(batch_size)
        return torch.full((batch_size,), t_hat, device=device)

    @staticmethod
    def _fallback_euler_step(
        x_noisy: torch.Tensor,
        x_denoised: torch.Tensor,
        t_hat: Union[torch.Tensor, float],
        c_t: Union[torch.Tensor, float],
    ) -> torch.Tensor:
        """Plain Euler update used when no scheduler is available (no step scaling).

        Computes ``x_noisy + (c_t - t_hat) * (x_noisy - x_denoised) / t_hat``.
        """
        # The 1e-8 epsilon guards against division by zero when t_hat == 0.
        delta = (x_noisy - x_denoised) / (t_hat + 1e-8)
        return x_noisy + (c_t - t_hat) * delta

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> Tuple[ModularPipeline, PipelineState]:
        """Run the full denoising loop and write the results back into ``state``.

        Reads ``xyz``, ``noise_schedule``, ``motif_mask``, ``n_recycle``,
        ``callback`` and ``callback_steps`` from the block state. Writes
        ``xyz`` (final coordinates), ``single``, ``pair``, ``sequence_logits``
        and ``sequence_indices`` (from the last transformer call, or ``None``
        if no transformer ran), plus ``trajectory`` (one cloned denoised
        prediction per step).

        If a ``callback`` is given, it is invoked as
        ``callback(step_num, c_t_minus_1, xyz)`` every ``callback_steps`` steps.

        Returns:
            The ``(components, state)`` tuple, as expected by the modular
            pipeline runner.
        """
        block_state = self.get_block_state(state)

        n_recycle = block_state.n_recycle
        callback = block_state.callback
        # `or 1` maps a missing/zero frequency to "every step" and avoids `% 0`.
        callback_steps = block_state.callback_steps or 1

        X_L = block_state.xyz.clone()
        # Leading dimension of xyz; the timestep is tiled to this size.
        D = X_L.shape[0]
        device = X_L.device

        # Ensure all tensors are on the same device as xyz.
        noise_schedule = block_state.noise_schedule.to(device)
        motif_mask = block_state.motif_mask
        if motif_mask is not None:
            motif_mask = motif_mask.to(device)

        transformer = getattr(components, "transformer", None)
        scheduler = getattr(components, "scheduler", None)

        single = pair = sequence_logits = sequence_indices = None
        X_denoised_L_traj: List[torch.Tensor] = []

        # noise_schedule goes from high noise to low noise; walk consecutive
        # (c_t_minus_1, c_t) pairs. A schedule with < 2 entries runs no steps.
        schedule_pairs = zip(noise_schedule[:-1], noise_schedule[1:])
        for step_num, (c_t_minus_1, c_t) in enumerate(schedule_pairs):
            # Step 1: inject stochastic noise (matching original sampler).
            if scheduler is not None:
                X_noisy_L, t_hat = scheduler.add_noise(X_L, c_t_minus_1, c_t, motif_mask=motif_mask)
            else:
                X_noisy_L, t_hat = X_L, c_t_minus_1

            # Step 2: model forward pass (identity prediction without a transformer).
            if transformer is not None:
                output = transformer(
                    xyz_noisy=X_noisy_L,
                    t=self._broadcast_timestep(t_hat, D, device),
                    motif_mask=motif_mask,
                    n_recycle=n_recycle,
                )
                X_denoised_L = output.xyz
                single = output.single
                pair = output.pair
                sequence_logits = output.sequence_logits
                sequence_indices = output.sequence_indices
            else:
                X_denoised_L = X_noisy_L

            # Step 3: Euler update with step_scale (matching original sampler).
            if scheduler is not None:
                X_L = scheduler.step(
                    xyz_pred=X_denoised_L,
                    xyz_noisy=X_noisy_L,
                    c_t_minus_1=c_t_minus_1,
                    c_t=c_t,
                    motif_mask=motif_mask,
                )
            else:
                X_L = self._fallback_euler_step(X_noisy_L, X_denoised_L, t_hat, c_t)

            # Clone so later steps cannot alias the recorded prediction.
            X_denoised_L_traj.append(X_denoised_L.clone())

            if callback is not None and step_num % callback_steps == 0:
                callback(step_num, c_t_minus_1, X_L)

        block_state.xyz = X_L
        block_state.single = single
        block_state.pair = pair
        block_state.sequence_logits = sequence_logits
        block_state.sequence_indices = sequence_indices
        block_state.trajectory = X_denoised_L_traj
        self.set_block_state(state, block_state)
        return components, state