# dn6's picture
# dn6 HF Staff
# Upload folder using huggingface_hub
# 4900749 verified
# Copyright 2025 Dhruv Nair. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
RFDiffusion3 Scheduler.
A thin diffusers-compatible wrapper around the foundry EDM noise schedule
and stochastic sampling logic from `rfd3.model.inference_sampler`.
"""
from typing import Optional
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
# Reuse the original noise schedule and sampling config directly
from rfd3.model.inference_sampler import SampleDiffusionWithMotif
class RFDiffusionScheduler(ConfigMixin):
    """
    Diffusers-compatible scheduler wrapping the foundry EDM sampler.

    Noise schedule construction and sampling parameters are delegated to
    `rfd3.model.inference_sampler.SampleDiffusionWithMotif`. This class only
    adds the `ConfigMixin` plumbing and splits one sampler iteration into two
    calls: `add_noise` (before the model call) and `step` (after it).
    """

    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        num_timesteps: int = 200,
        sigma_data: float = 16.0,
        s_min: float = 4e-4,
        s_max: float = 160.0,
        p: float = 7.0,
        gamma_0: float = 0.6,
        gamma_min: float = 1.0,
        noise_scale: float = 1.003,
        step_scale: float = 1.5,
    ):
        """
        Build the wrapped foundry sampler.

        Every argument is forwarded unchanged, by keyword, to
        `SampleDiffusionWithMotif`; see that class for their exact meaning.
        Of these, this wrapper itself only reads `gamma_0`, `gamma_min`,
        `noise_scale` and `step_scale` back from the sampler instance.
        """
        # Instantiate the foundry sampler with matching parameters
        self._sampler = SampleDiffusionWithMotif(
            num_timesteps=num_timesteps,
            sigma_data=sigma_data,
            s_min=s_min,
            s_max=s_max,
            p=p,
            gamma_0=gamma_0,
            gamma_min=gamma_min,
            noise_scale=noise_scale,
            step_scale=step_scale,
        )

    @property
    def sampler(self) -> SampleDiffusionWithMotif:
        """The underlying foundry sampler instance."""
        return self._sampler

    def _compute_t_hat(self, c_t_minus_1: torch.Tensor, c_t: torch.Tensor) -> torch.Tensor:
        """
        Return the inflated noise level ``t_hat = c_t_minus_1 * (gamma + 1)``.

        ``gamma`` is the sampler's ``gamma_0`` while ``c_t`` is above the
        sampler's ``gamma_min`` and ``0.0`` otherwise. `add_noise` and `step`
        both go through this helper so that the Euler update always uses the
        same ``t_hat`` the noise was injected at.
        """
        gamma = self._sampler.gamma_0 if c_t > self._sampler.gamma_min else 0.0
        return c_t_minus_1 * (gamma + 1.0)

    def get_noise_schedule(self, device: Optional[torch.device] = None) -> torch.Tensor:
        """
        Construct the EDM noise schedule using the foundry implementation.

        Args:
            device: Device to build the schedule on. Falls back to CPU when
                omitted.

        Returns:
            torch.Tensor: Noise schedule [num_timesteps] from high to low noise.
        """
        return self._sampler._construct_inference_noise_schedule(
            device=device or torch.device("cpu")
        )

    def get_initial_noise_level(self, device: Optional[torch.device] = None) -> torch.Tensor:
        """Get the first (largest) noise level from the schedule."""
        return self.get_noise_schedule(device=device)[0]

    def step(
        self,
        xyz_pred: torch.Tensor,
        xyz_noisy: torch.Tensor,
        c_t_minus_1: torch.Tensor,
        c_t: torch.Tensor,
        motif_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Perform one Euler denoising step matching the foundry sampler.

        The foundry ``sample_diffusion_like_af3`` does NOT clamp motif
        coordinates after the Euler update — it relies on noise injection
        having zeroed epsilon for motif atoms so the model's delta is ~0
        there. We replicate that behaviour here.

        Args:
            xyz_pred: Model's denoised prediction X_denoised_L [B, L, 3]
            xyz_noisy: Noise-injected coordinates X_noisy_L [B, L, 3]
            c_t_minus_1: Previous noise level
            c_t: Next (lower) noise level
            motif_mask: Boolean mask for fixed positions (True = fixed) [L]
                (unused — kept for API compatibility)

        Returns:
            Updated coordinates X_L [B, L, 3]
        """
        # Recompute the same t_hat that add_noise used for this
        # (c_t_minus_1, c_t) pair.
        t_hat = self._compute_t_hat(c_t_minus_1, c_t)
        # Direction from the denoised prediction to the noisy sample, per unit noise.
        delta_L = (xyz_noisy - xyz_pred) / t_hat
        # Noise-level decrement for this step (negative when c_t < t_hat).
        d_t = c_t - t_hat
        return xyz_noisy + self._sampler.step_scale * d_t * delta_L

    def add_noise(
        self,
        xyz: torch.Tensor,
        c_t_minus_1: torch.Tensor,
        c_t: torch.Tensor,
        motif_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Inject stochastic noise before the model call, matching the foundry sampler.

        Args:
            xyz: Current coordinates X_L [B, L, 3]
            c_t_minus_1: Previous noise level
            c_t: Current (next lower) noise level
            motif_mask: Boolean mask for fixed positions (True = fixed) [L].
                When given, no noise is added at the masked positions.

        Returns:
            Tuple of (noisy coordinates X_noisy_L, t_hat scalar)
        """
        t_hat = self._compute_t_hat(c_t_minus_1, c_t)
        # Std-dev needed to raise the noise level from c_t_minus_1 to t_hat,
        # scaled by the sampler's noise_scale. It is exactly 0 when gamma == 0.
        noise_std = self._sampler.noise_scale * torch.sqrt(t_hat**2 - c_t_minus_1**2)
        epsilon = noise_std * torch.randn_like(xyz)
        if motif_mask is not None:
            # Keep fixed (motif) positions noise-free; epsilon is a fresh
            # tensor, so the in-place write does not touch the caller's data.
            epsilon[:, motif_mask] = 0.0
        xyz_noisy = xyz + epsilon
        return xyz_noisy, t_hat