| """ |
| Denoising loop for RFDiffusion3. |
| |
| Implements the iterative denoising procedure from the original |
| inference_sampler.py (SampleDiffusionWithMotif.sample_diffusion_like_af3). |
| |
| The loop iterates over consecutive pairs (c_t_minus_1, c_t) in the noise schedule: |
| 1. Inject stochastic noise: t_hat = c_t_minus_1 * (gamma + 1), epsilon ~ N(0, noise_scale * sqrt(t_hat^2 - c_t_minus_1^2)) |
| 2. Call model: X_denoised = model(X_noisy, t_hat) |
| 3. Euler update: X = X_noisy + step_scale * (c_t - t_hat) * (X_noisy - X_denoised) / t_hat |
| """ |

from typing import Callable, List

import torch

from diffusers.utils import logging
from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam


logger = logging.get_logger(__name__)
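

# A minimal, self-contained sketch of the sampling loop described in the module
# docstring, kept for reference. `denoise_fn` stands in for the RFDiffusion3
# model; `gamma`, `noise_scale`, and `step_scale` are illustrative assumptions
# and are not used by the pipeline block below.
def edm_sample_reference(
    denoise_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    xyz: torch.Tensor,
    noise_schedule: torch.Tensor,
    gamma: float = 0.0,
    noise_scale: float = 1.0,
    step_scale: float = 1.0,
) -> torch.Tensor:
    X = xyz.clone()
    for i in range(len(noise_schedule) - 1):
        c_t_minus_1, c_t = noise_schedule[i], noise_schedule[i + 1]
        # 1. Churn: lift the noise level from c_t_minus_1 to t_hat and add matching noise.
        t_hat = c_t_minus_1 * (gamma + 1)
        std = noise_scale * torch.sqrt(torch.clamp(t_hat**2 - c_t_minus_1**2, min=0.0))
        X_noisy = X + std * torch.randn_like(X)
        # 2. Predict the clean structure at noise level t_hat.
        X_denoised = denoise_fn(X_noisy, t_hat)
        # 3. Euler step from t_hat towards c_t.
        X = X_noisy + step_scale * (c_t - t_hat) * (X_noisy - X_denoised) / (t_hat + 1e-8)
    return X

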
class RFDiffusionDenoiseStep(ModularPipelineBlocks):
    """
    Iterative denoising step for RFDiffusion3.

    Implements the EDM stochastic sampling loop matching the original
    SampleDiffusionWithMotif.sample_diffusion_like_af3.
    """

    model_name = "rfdiffusion"

    @property
    def description(self) -> str:
        return (
            "Iteratively denoise protein structure through reverse diffusion. "
            "Uses EDM stochastic sampling with gamma noise injection and step scaling."
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("transformer", description="RFDiffusion transformer for structure prediction"),
            ComponentSpec("scheduler", description="Scheduler for noise injection and stepping"),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "n_recycle",
                default=None,
                type_hint=int,
                description="Number of recycling iterations (None uses the model default)",
            ),
            InputParam(
                "callback",
                type_hint=Callable,
                description="Optional callback invoked every `callback_steps` denoising steps",
            ),
            InputParam(
                "callback_steps",
                default=1,
                type_hint=int,
                description="Frequency of callback invocation, in denoising steps",
            ),
            InputParam("xyz", required=True, type_hint=torch.Tensor, description="Initial noised coordinates [D, L, 3]"),
            InputParam("noise_schedule", required=True, type_hint=torch.Tensor, description="EDM noise schedule"),
            InputParam("motif_mask", required=True, type_hint=torch.Tensor, description="Mask for fixed motif positions"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("xyz", type_hint=torch.Tensor, description="Denoised coordinates [D, L, 3]"),
            OutputParam("single", type_hint=torch.Tensor, description="Single representation"),
            OutputParam("pair", type_hint=torch.Tensor, description="Pair representation"),
            OutputParam("sequence_logits", type_hint=torch.Tensor, description="Predicted sequence logits"),
            OutputParam("sequence_indices", type_hint=torch.Tensor, description="Predicted sequence indices"),
            OutputParam("trajectory", type_hint=List[torch.Tensor], description="Denoising trajectory (per-step predictions)"),
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        xyz = block_state.xyz
        noise_schedule = block_state.noise_schedule
        motif_mask = block_state.motif_mask

        n_recycle = block_state.n_recycle
        callback = block_state.callback
        callback_steps = block_state.callback_steps or 1

        X_denoised_L_traj = []
        X_L = xyz.clone()
        D = X_L.shape[0]
        device = X_L.device

        noise_schedule = noise_schedule.to(device)
        if motif_mask is not None:
            motif_mask = motif_mask.to(device)

        single = None
        pair = None
        sequence_logits = None
        sequence_indices = None

        # Missing components degrade gracefully to identity behaviour, so the
        # block can run standalone (e.g. in tests).
        has_transformer = hasattr(components, "transformer") and components.transformer is not None
        has_scheduler = hasattr(components, "scheduler") and components.scheduler is not None

        # Iterate over consecutive noise levels (c_t_minus_1 -> c_t).
        for step_num in range(len(noise_schedule) - 1):
            c_t_minus_1 = noise_schedule[step_num]
            c_t = noise_schedule[step_num + 1]

            # 1. Inject stochastic noise, lifting the noise level to t_hat.
            if has_scheduler:
                X_noisy_L, t_hat = components.scheduler.add_noise(
                    X_L, c_t_minus_1, c_t, motif_mask=motif_mask
                )
            else:
                # No scheduler: skip the churn step (equivalent to gamma = 0).
                X_noisy_L = X_L
                t_hat = c_t_minus_1
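
            # For reference, `scheduler.add_noise` is expected to implement the
            # churn step from the module docstring (sketch of the assumed
            # interface, not the actual scheduler):
            #     t_hat = c_t_minus_1 * (gamma + 1)
            #     std = noise_scale * sqrt(t_hat**2 - c_t_minus_1**2)
            #     X_noisy = X + std * randn_like(X), with motif positions held fixed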

            # 2. Predict the denoised structure at noise level t_hat.
            if has_transformer:
                # Broadcast the scalar noise level across the diffusion batch.
                t_batch = (
                    t_hat.to(device).expand(D)
                    if isinstance(t_hat, torch.Tensor)
                    else torch.full((D,), t_hat, device=device)
                )

                output = components.transformer(
                    xyz_noisy=X_noisy_L,
                    t=t_batch,
                    motif_mask=motif_mask,
                    n_recycle=n_recycle,
                )

                X_denoised_L = output.xyz
                single = output.single
                pair = output.pair
                sequence_logits = output.sequence_logits
                sequence_indices = output.sequence_indices
            else:
                X_denoised_L = X_noisy_L

            # 3. Euler step from t_hat back towards c_t.
            if has_scheduler:
                X_L = components.scheduler.step(
                    xyz_pred=X_denoised_L,
                    xyz_noisy=X_noisy_L,
                    c_t_minus_1=c_t_minus_1,
                    c_t=c_t,
                    motif_mask=motif_mask,
                )
            else:
                # Plain Euler update (step_scale = 1); the epsilon guards against t_hat == 0.
                delta_L = (X_noisy_L - X_denoised_L) / (t_hat + 1e-8)
                d_t = c_t - t_hat
                X_L = X_noisy_L + d_t * delta_L
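
            # `scheduler.step` is expected to apply the same update with an extra
            # step_scale factor (sketch of the assumed interface):
            #     X = X_noisy + step_scale * (c_t - t_hat) * (X_noisy - X_pred) / t_hat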

            X_denoised_L_traj.append(X_denoised_L.clone())

            if callback is not None and step_num % callback_steps == 0:
                callback(step_num, c_t_minus_1, X_L)

        block_state.xyz = X_L
        block_state.single = single
        block_state.pair = pair
        block_state.sequence_logits = sequence_logits
        block_state.sequence_indices = sequence_indices
        block_state.trajectory = X_denoised_L_traj

        self.set_block_state(state, block_state)
        return components, state
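

if __name__ == "__main__":
    # Smoke test for the reference sketch above, using a dummy denoiser that
    # always predicts the origin (illustrative only; not the RFDiffusion3 model).
    schedule = torch.linspace(10.0, 0.1, steps=20)
    xyz0 = schedule[0] * torch.randn(1, 128, 3)  # [D, L, 3], starting from pure noise
    out = edm_sample_reference(lambda x, t: torch.zeros_like(x), xyz0, schedule)
    print(out.shape)  # torch.Size([1, 128, 3]); coordinates shrink with the noise level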