# dn6's picture
# dn6 HF Staff
# Upload folder using huggingface_hub
# a376829 verified
# Copyright 2025 Dhruv Nair. All rights reserved.
# Licensed under the Apache License, Version 2.0
"""
RF3 Scheduler.
A diffusers-compatible wrapper around the foundry EDM noise schedule
for RF3. Same schedule formula as RFD3 but with gamma_0=0.8 (vs 0.6).
"""
from typing import Optional
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from rf3.diffusion_samplers.inference_sampler import SampleDiffusion
class RF3Scheduler(ConfigMixin):
    """
    Diffusers-compatible scheduler wrapping the foundry RF3 EDM sampler.

    The constructor arguments are forwarded unchanged to ``SampleDiffusion``;
    ``min_t=0``, ``max_t=1`` and ``solver="af3"`` are fixed by this wrapper.
    ``add_noise`` and ``step`` are meant to be called as a pair for the same
    ``(c_t_minus_1, c_t)`` so that both use the same inflated noise level.
    """

    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        num_timesteps: int = 200,
        sigma_data: float = 16.0,
        s_min: float = 4e-4,
        s_max: float = 160.0,
        p: float = 7.0,
        gamma_0: float = 0.8,
        gamma_min: float = 1.0,
        noise_scale: float = 1.003,
        step_scale: float = 1.5,
    ):
        """Build the underlying ``SampleDiffusion`` sampler from the given hyper-parameters."""
        self._sampler = SampleDiffusion(
            num_timesteps=num_timesteps,
            min_t=0,
            max_t=1,
            sigma_data=sigma_data,
            s_min=s_min,
            s_max=s_max,
            p=p,
            gamma_0=gamma_0,
            gamma_min=gamma_min,
            noise_scale=noise_scale,
            step_scale=step_scale,
            solver="af3",
        )

    @property
    def sampler(self) -> SampleDiffusion:
        """The wrapped ``SampleDiffusion`` instance."""
        return self._sampler

    def get_noise_schedule(self, device: Optional[torch.device] = None) -> torch.Tensor:
        """Construct the EDM noise schedule.

        Args:
            device: Device passed to the sampler's schedule constructor.
                Falls back to CPU when omitted.

        Returns:
            Whatever ``SampleDiffusion._construct_inference_noise_schedule``
            produces for that device.
        """
        return self._sampler._construct_inference_noise_schedule(
            device=device or torch.device("cpu")
        )

    def _inflated_noise_level(
        self,
        c_t_minus_1: torch.Tensor,
        c_t: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``t_hat = c_t_minus_1 * (gamma + 1)`` for one sampling step.

        ``gamma`` is the sampler's ``gamma_0`` when ``c_t`` exceeds the
        sampler's ``gamma_min`` and ``0.0`` otherwise. Shared by ``add_noise``
        and ``step`` so the two can never disagree on ``t_hat``.
        """
        # The comparison feeds a Python conditional, so ``c_t`` has to be a
        # scalar (or single-element tensor) here.
        gamma = self._sampler.gamma_0 if c_t > self._sampler.gamma_min else 0.0
        return c_t_minus_1 * (gamma + 1.0)

    def add_noise(
        self,
        xyz: torch.Tensor,
        c_t_minus_1: torch.Tensor,
        c_t: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Inject stochastic noise before the model call.

        Args:
            xyz: Tensor to perturb; the noise has the same shape and dtype.
            c_t_minus_1: Noise level of the previous step.
            c_t: Noise level of the current step (selects ``gamma``).

        Returns:
            ``(xyz + epsilon, t_hat)`` where ``epsilon`` is Gaussian noise with
            standard deviation
            ``noise_scale * sqrt(t_hat**2 - c_t_minus_1**2)``.
        """
        t_hat = self._inflated_noise_level(c_t_minus_1, c_t)
        # When gamma is 0, t_hat equals c_t_minus_1 and the std is exactly 0,
        # i.e. no noise is added.
        noise_std = self._sampler.noise_scale * torch.sqrt(t_hat**2 - c_t_minus_1**2)
        epsilon = noise_std * torch.randn_like(xyz)
        return xyz + epsilon, t_hat

    def step(
        self,
        xyz_pred: torch.Tensor,
        xyz_noisy: torch.Tensor,
        c_t_minus_1: torch.Tensor,
        c_t: torch.Tensor,
    ) -> torch.Tensor:
        """Perform one Euler denoising step.

        Args:
            xyz_pred: Model prediction for the current step.
            xyz_noisy: Noised input that was given to the model.
            c_t_minus_1: Noise level of the previous step.
            c_t: Noise level of the current step.

        Returns:
            ``xyz_noisy + step_scale * (c_t - t_hat) * (xyz_noisy - xyz_pred) / t_hat``,
            with ``t_hat`` recomputed from ``(c_t_minus_1, c_t)`` exactly as in
            ``add_noise``.
        """
        t_hat = self._inflated_noise_level(c_t_minus_1, c_t)
        delta = (xyz_noisy - xyz_pred) / t_hat
        d_t = c_t - t_hat
        return xyz_noisy + self._sampler.step_scale * d_t * delta