# Copyright 2025 Dhruv Nair. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
RFDiffusion3 Scheduler.

A thin diffusers-compatible wrapper around the foundry EDM noise schedule and
stochastic sampling logic from `rfd3.model.inference_sampler`.
"""

from typing import Optional

import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config

# Reuse the original noise schedule and sampling config directly
from rfd3.model.inference_sampler import SampleDiffusionWithMotif


class RFDiffusionScheduler(ConfigMixin):
    """
    Diffusers-compatible scheduler wrapping the foundry EDM sampler.

    Delegates noise schedule construction and sampling parameters to
    `rfd3.model.inference_sampler.SampleDiffusionWithMotif`. The per-step
    update is split into two calls around the model evaluation:
    `add_noise` (stochastic noise injection, returns `t_hat`) followed by
    `step` (Euler update from the model's denoised prediction).
    """

    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        num_timesteps: int = 200,
        sigma_data: float = 16.0,
        s_min: float = 4e-4,
        s_max: float = 160.0,
        p: float = 7.0,
        gamma_0: float = 0.6,
        gamma_min: float = 1.0,
        noise_scale: float = 1.003,
        step_scale: float = 1.5,
    ):
        # Instantiate the foundry sampler with matching parameters; every
        # hyperparameter used below is read back from this object so the
        # wrapper and the original sampler share a single source of truth.
        self._sampler = SampleDiffusionWithMotif(
            num_timesteps=num_timesteps,
            sigma_data=sigma_data,
            s_min=s_min,
            s_max=s_max,
            p=p,
            gamma_0=gamma_0,
            gamma_min=gamma_min,
            noise_scale=noise_scale,
            step_scale=step_scale,
        )

    @property
    def sampler(self) -> SampleDiffusionWithMotif:
        """The wrapped foundry `SampleDiffusionWithMotif` instance."""
        return self._sampler

    def get_noise_schedule(self, device: Optional[torch.device] = None) -> torch.Tensor:
        """
        Construct the EDM noise schedule using the foundry implementation.

        Args:
            device: Device to build the schedule on. Defaults to CPU when
                not given.

        Returns:
            torch.Tensor: Noise schedule [num_timesteps] from high to low noise.
        """
        return self._sampler._construct_inference_noise_schedule(
            device=device or torch.device("cpu")
        )

    def get_initial_noise_level(self, device: Optional[torch.device] = None) -> torch.Tensor:
        """Get the first (largest) noise level from the schedule."""
        return self.get_noise_schedule(device=device)[0]

    def _t_hat(self, c_t_minus_1, c_t):
        """
        Return the noise-inflated level `t_hat` for the step c_t_minus_1 -> c_t.

        Shared by `add_noise` and `step` so both compute the exact same value
        for a given (c_t_minus_1, c_t) pair.
        """
        # Extra noise (gamma_0) is only injected while the target level c_t
        # is still above gamma_min; below that the step is deterministic.
        gamma = self._sampler.gamma_0 if c_t > self._sampler.gamma_min else 0.0
        return c_t_minus_1 * (gamma + 1.0)

    def step(
        self,
        xyz_pred: torch.Tensor,
        xyz_noisy: torch.Tensor,
        c_t_minus_1: torch.Tensor,
        c_t: torch.Tensor,
        motif_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Perform one Euler denoising step matching the foundry sampler.

        The foundry ``sample_diffusion_like_af3`` does NOT clamp motif
        coordinates after the Euler update — it relies on noise injection
        having zeroed epsilon for motif atoms so the model's delta is ~0
        there. We replicate that behaviour here.

        `t_hat` is recomputed from (c_t_minus_1, c_t) with the same rule used
        by `add_noise`, so it equals the `t_hat` that call returned.

        Args:
            xyz_pred: Model's denoised prediction X_denoised_L [B, L, 3]
            xyz_noisy: Noise-injected coordinates X_noisy_L [B, L, 3]
            c_t_minus_1: Previous noise level
            c_t: Next (lower) noise level
            motif_mask: Boolean mask for fixed positions (True = fixed) [L]
                (unused — kept for API compatibility)

        Returns:
            Updated coordinates X_L [B, L, 3]
        """
        t_hat = self._t_hat(c_t_minus_1, c_t)
        # Direction from the denoised estimate towards the noisy sample,
        # normalised by the current (inflated) noise level.
        delta_L = (xyz_noisy - xyz_pred) / t_hat
        d_t = c_t - t_hat
        xyz_next = xyz_noisy + self._sampler.step_scale * d_t * delta_L
        return xyz_next

    def add_noise(
        self,
        xyz: torch.Tensor,
        c_t_minus_1: torch.Tensor,
        c_t: torch.Tensor,
        motif_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Inject stochastic noise before the model call, matching the foundry sampler.

        Args:
            xyz: Current coordinates X_L [B, L, 3] (an unbatched [L, 3] array
                is also accepted).
            c_t_minus_1: Previous noise level (tensor or Python float).
            c_t: Current (next lower) noise level (tensor or Python float).
            motif_mask: Boolean mask for fixed positions (True = fixed),
                shape [L] or [B, L]. Masked positions receive no noise.

        Returns:
            Tuple of (noisy coordinates X_noisy_L, t_hat scalar)
        """
        t_hat = self._t_hat(c_t_minus_1, c_t)
        # torch.as_tensor is a no-op for tensors and lets plain Python float
        # noise levels through (torch.sqrt rejects raw floats).
        noise_var = torch.as_tensor(t_hat**2 - c_t_minus_1**2)
        noise_std = self._sampler.noise_scale * torch.sqrt(noise_var)
        epsilon = noise_std * torch.randn_like(xyz)
        if motif_mask is not None:
            # Index the residue axis (second-to-last) explicitly so the mask
            # never lands on the xyz axis for unbatched input, and so both
            # [L] and per-batch [B, L] masks are supported.
            epsilon[..., motif_mask, :] = 0.0
        xyz_noisy = xyz + epsilon
        return xyz_noisy, t_hat