# Copyright 2025 Dhruv Nair. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Denoising loop for RFDiffusion3.

Implements the iterative denoising procedure from the original inference_sampler.py
(SampleDiffusionWithMotif.sample_diffusion_like_af3).

The loop iterates over consecutive pairs (c_t_minus_1, c_t) in the noise schedule:
1. Inject stochastic noise: t_hat = c_t_minus_1 * (gamma + 1),
   epsilon ~ N(0, noise_scale * sqrt(t_hat^2 - c_t_minus_1^2))
2. Call model: X_denoised = model(X_noisy, t_hat)
3. Euler update: X = X_noisy + step_scale * (c_t - t_hat) * (X_noisy - X_denoised) / t_hat
"""

from typing import Callable, List

import torch

from diffusers.utils import logging
from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam


logger = logging.get_logger(__name__)


class RFDiffusionDenoiseStep(ModularPipelineBlocks):
    """
    Pipeline block that runs the RFDiffusion3 reverse-diffusion loop.

    For every consecutive pair of noise levels in the schedule the block
    (1) lets the scheduler perturb the current coordinates, (2) asks the
    transformer for a denoised estimate, and (3) advances the coordinates
    with the scheduler's step (or a plain Euler step when no scheduler is
    available). Per the module docstring this mirrors the original
    SampleDiffusionWithMotif.sample_diffusion_like_af3 sampler.
    """

    model_name = "rfdiffusion"

    @property
    def description(self) -> str:
        """Human-readable summary of what this block does."""
        summary = (
            "Iteratively denoise protein structure through reverse diffusion. "
            "Uses EDM stochastic sampling with gamma noise injection and step scaling."
        )
        return summary

    @property
    def expected_components(self) -> List[ComponentSpec]:
        """Components this block looks up on the pipeline: transformer and scheduler."""
        transformer_spec = ComponentSpec(
            "transformer", description="RFDiffusion transformer for structure prediction"
        )
        scheduler_spec = ComponentSpec(
            "scheduler", description="Scheduler for noise injection and stepping"
        )
        return [transformer_spec, scheduler_spec]

    @property
    def inputs(self) -> List[InputParam]:
        """Values read from the pipeline state: three optional knobs, then three required tensors."""
        optional_params = [
            InputParam(
                "n_recycle",
                default=None,
                type_hint=int,
                description="Number of recycling iterations (None uses model default)",
            ),
            InputParam(
                "callback",
                type_hint=Callable,
                description="Optional callback function called at each step",
            ),
            InputParam(
                "callback_steps",
                default=1,
                type_hint=int,
                description="Frequency of callback invocation",
            ),
        ]
        required_params = [
            InputParam(
                "xyz",
                required=True,
                type_hint=torch.Tensor,
                description="Initial noised coordinates [D, L, 3]",
            ),
            InputParam(
                "noise_schedule",
                required=True,
                type_hint=torch.Tensor,
                description="EDM noise schedule",
            ),
            InputParam(
                "motif_mask",
                required=True,
                type_hint=torch.Tensor,
                description="Mask for fixed motif positions",
            ),
        ]
        return optional_params + required_params

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """Values this block writes back to the pipeline state."""
        return [
            OutputParam("xyz", type_hint=torch.Tensor, description="Denoised coordinates [D, L, 3]"),
            OutputParam("single", type_hint=torch.Tensor, description="Single representation"),
            OutputParam("pair", type_hint=torch.Tensor, description="Pair representation"),
            OutputParam("sequence_logits", type_hint=torch.Tensor, description="Predicted sequence logits"),
            OutputParam("sequence_indices", type_hint=torch.Tensor, description="Predicted sequence indices"),
            OutputParam("trajectory", type_hint=List[torch.Tensor], description="Denoising trajectory"),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        """
        Run the full denoising loop and store the results on the pipeline state.

        Args:
            components: Pipeline object; its ``transformer`` and ``scheduler``
                attributes are used when present and not ``None``.
            state: Pipeline state providing ``xyz``, ``noise_schedule``,
                ``motif_mask``, ``n_recycle``, ``callback`` and ``callback_steps``.

        Returns:
            The ``(components, state)`` pair, with ``xyz``, ``single``, ``pair``,
            ``sequence_logits``, ``sequence_indices`` and ``trajectory`` written
            to the block state. The representation/sequence outputs stay ``None``
            if the transformer is never called.
        """
        block_state = self.get_block_state(state)

        initial_xyz = block_state.xyz
        schedule = block_state.noise_schedule
        motif_mask = block_state.motif_mask
        n_recycle = block_state.n_recycle
        callback = block_state.callback
        # A falsy callback_steps (None or 0) falls back to calling back every step.
        callback_every = block_state.callback_steps or 1

        # Work on a copy so the caller's input tensor is left untouched.
        coords = initial_xyz.clone()
        batch_size = coords.shape[0]
        device = coords.device

        # Everything is moved to the device of the coordinates.
        schedule = schedule.to(device)
        if motif_mask is not None:
            motif_mask = motif_mask.to(device)

        # Either component may be absent; each has a pass-through fallback below.
        transformer = getattr(components, "transformer", None)
        scheduler = getattr(components, "scheduler", None)

        single = pair = sequence_logits = sequence_indices = None
        denoised_history = []

        # Walk consecutive (c_t_minus_1, c_t) pairs of the schedule.
        num_steps = len(schedule) - 1
        for step_num in range(num_steps):
            c_t_minus_1, c_t = schedule[step_num], schedule[step_num + 1]

            # Step 1: stochastic noise injection (delegated to the scheduler).
            if scheduler is None:
                noisy, t_hat = coords, c_t_minus_1
            else:
                noisy, t_hat = scheduler.add_noise(
                    coords, c_t_minus_1, c_t, motif_mask=motif_mask
                )

            # Step 2: model prediction of the denoised coordinates.
            if transformer is None:
                denoised = noisy
            else:
                # Tile the noise level so there is one entry per leading-dim element.
                if isinstance(t_hat, torch.Tensor):
                    t_batch = t_hat.to(device).expand(batch_size)
                else:
                    t_batch = torch.full((batch_size,), t_hat, device=device)

                output = transformer(
                    xyz_noisy=noisy,
                    t=t_batch,
                    motif_mask=motif_mask,
                    n_recycle=n_recycle,
                )
                denoised = output.xyz
                single, pair = output.single, output.pair
                sequence_logits = output.sequence_logits
                sequence_indices = output.sequence_indices

            # Step 3: move the coordinates towards the next noise level.
            if scheduler is None:
                # Plain Euler step; the epsilon guards against division by zero.
                slope = (noisy - denoised) / (t_hat + 1e-8)
                coords = noisy + (c_t - t_hat) * slope
            else:
                coords = scheduler.step(
                    xyz_pred=denoised,
                    xyz_noisy=noisy,
                    c_t_minus_1=c_t_minus_1,
                    c_t=c_t,
                    motif_mask=motif_mask,
                )

            denoised_history.append(denoised.clone())

            if callback is not None and step_num % callback_every == 0:
                callback(step_num, c_t_minus_1, coords)

        block_state.xyz = coords
        block_state.single = single
        block_state.pair = pair
        block_state.sequence_logits = sequence_logits
        block_state.sequence_indices = sequence_indices
        block_state.trajectory = denoised_history

        self.set_block_state(state, block_state)
        return components, state