# Copyright 2025 Dhruv Nair. All rights reserved.
# Licensed under the Apache License, Version 2.0
"""
Denoising loop for RF3.

Same EDM stochastic sampling as RFD3, but conditioned on trunk
representations (single S_I, pair Z_II) from the recycling step.
"""

from typing import Callable, List

import torch

from diffusers.utils import logging
from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam

logger = logging.get_logger(__name__)


class RF3DenoiseStep(ModularPipelineBlocks):
    """
    Iterative denoising step for RF3.

    Uses trunk representations from the recycling step as conditioning for
    the diffusion module at each denoising step.
    """

    model_name = "rf3"

    @property
    def description(self) -> str:
        return "Iteratively denoise protein structure conditioned on sequence/MSA representations."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        # Both components are optional at runtime: __call__ degrades to a
        # pass-through / plain Euler step when either one is absent.
        return [
            ComponentSpec("transformer", description="RF3 transformer (provides diffusion_module)"),
            ComponentSpec("scheduler", description="RF3 EDM scheduler"),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("xyz", required=True, type_hint=torch.Tensor, description="Initial noised coords [D, L, 3]"),
            InputParam("noise_schedule", required=True, type_hint=torch.Tensor),
            InputParam("f", required=True, type_hint=dict, description="Feature dictionary"),
            InputParam("single", type_hint=torch.Tensor, description="Trunk single repr [I, c_s]"),
            InputParam("pair", type_hint=torch.Tensor, description="Trunk pair repr [I, I, c_z]"),
            InputParam("s_inputs", type_hint=torch.Tensor, description="Input embeddings [I, c_s_inputs]"),
            InputParam("callback", type_hint=Callable),
            InputParam("callback_steps", default=1, type_hint=int),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("xyz", type_hint=torch.Tensor, description="Denoised coords [D, L, 3]"),
            OutputParam("trajectory", type_hint=List[torch.Tensor]),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState):
        """
        Run the full EDM denoising loop over ``noise_schedule``.

        For each consecutive pair of noise levels (c_{t-1}, c_t):
          1. inject noise via the scheduler (stochastic churn), yielding the
             effective noise level ``t_hat``;
          2. predict the denoised structure with the diffusion module,
             conditioned on the trunk single/pair representations;
          3. take an Euler step from ``t_hat`` toward ``c_t``.

        When the transformer or scheduler component is missing, the
        corresponding stage degrades gracefully (identity model / manual
        Euler step) so the block remains runnable in isolation.

        Returns:
            ``(components, state)`` — NOTE: the original annotation said
            ``PipelineState`` but the code has always returned the tuple;
            the annotation has been removed to match actual behavior.
        """
        block_state = self.get_block_state(state)

        xyz = block_state.xyz
        noise_schedule = block_state.noise_schedule
        f = block_state.f
        single = block_state.single
        pair = block_state.pair
        s_inputs = block_state.s_inputs
        callback = block_state.callback
        # Guard against callback_steps being None or 0 (would crash the modulo).
        callback_steps = block_state.callback_steps or 1

        # Work on a copy so the caller's input tensor is never mutated.
        X_L = xyz.clone()
        D = X_L.shape[0]
        device = X_L.device
        noise_schedule = noise_schedule.to(device)

        trajectory = []

        has_transformer = hasattr(components, "transformer") and components.transformer is not None
        has_scheduler = hasattr(components, "scheduler") and components.scheduler is not None

        # One iteration per consecutive pair of noise levels.
        for step_num in range(len(noise_schedule) - 1):
            c_t_minus_1 = noise_schedule[step_num]
            c_t = noise_schedule[step_num + 1]

            # Noise injection (stochastic churn); without a scheduler the
            # coordinates pass through unchanged at the current noise level.
            if has_scheduler:
                X_noisy, t_hat = components.scheduler.add_noise(X_L, c_t_minus_1, c_t)
            else:
                X_noisy = X_L
                t_hat = c_t_minus_1

            # Model forward (diffusion module conditioned on trunk reprs).
            if has_transformer:
                # Broadcast the scalar noise level to one entry per diffusion
                # sample in the batch dimension D.
                t_batch = (t_hat.to(device).expand(D)
                           if isinstance(t_hat, torch.Tensor)
                           else torch.full((D,), t_hat, device=device))
                outs = components.transformer.diffusion_module(
                    X_noisy_L=X_noisy,
                    t=t_batch,
                    f=f,
                    S_inputs_I=s_inputs,
                    S_trunk_I=single,
                    Z_trunk_II=pair,
                )
                # The module may return a bare tensor or a dict keyed "X_L".
                X_denoised = outs if isinstance(outs, torch.Tensor) else outs.get("X_L", outs)
            else:
                X_denoised = X_noisy

            # Euler step from t_hat toward c_t.
            if has_scheduler:
                X_L = components.scheduler.step(
                    xyz_pred=X_denoised,
                    xyz_noisy=X_noisy,
                    c_t_minus_1=c_t_minus_1,
                    c_t=c_t,
                )
            else:
                # Manual EDM Euler update: d = (x - D(x)) / t_hat, then step
                # by (c_t - t_hat). The epsilon guards division at t_hat == 0.
                delta = (X_noisy - X_denoised) / (t_hat + 1e-8)
                d_t = c_t - t_hat
                X_L = X_noisy + d_t * delta

            # Record the model's denoised prediction (not the stepped coords)
            # so the trajectory reflects per-step x0 estimates.
            trajectory.append(X_denoised.clone())

            if callback is not None and step_num % callback_steps == 0:
                callback(step_num, c_t_minus_1, X_L)

        block_state.xyz = X_L
        block_state.trajectory = trajectory
        self.set_block_state(state, block_state)
        return components, state