# Copyright 2025 Dhruv Nair. All rights reserved.
# Licensed under the Apache License, Version 2.0
"""
RF3 Scheduler.

A diffusers-compatible wrapper around the foundry EDM noise schedule for RF3.
Same schedule formula as RFD3 but with gamma_0=0.8 (vs 0.6).
"""

from typing import Optional

import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from rf3.diffusion_samplers.inference_sampler import SampleDiffusion


class RF3Scheduler(ConfigMixin):
    """
    Diffusers-compatible scheduler wrapping the foundry RF3 EDM sampler.

    The constructor arguments are recorded through ``register_to_config`` and
    forwarded to a ``SampleDiffusion`` instance, which is used for building
    the noise schedule and as the source of the ``gamma_0``, ``gamma_min``,
    ``noise_scale`` and ``step_scale`` values read by :meth:`add_noise` and
    :meth:`step`.
    """

    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        num_timesteps: int = 200,
        sigma_data: float = 16.0,
        s_min: float = 4e-4,
        s_max: float = 160.0,
        p: float = 7.0,
        gamma_0: float = 0.8,
        gamma_min: float = 1.0,
        noise_scale: float = 1.003,
        step_scale: float = 1.5,
    ):
        """Build the wrapped ``SampleDiffusion`` sampler.

        Every argument is passed through unchanged to ``SampleDiffusion``
        under the same keyword. In addition this wrapper always passes
        ``min_t=0``, ``max_t=1`` and ``solver="af3"``.

        NOTE(review): the precise meaning of each hyper-parameter is defined
        by ``SampleDiffusion``; consult that class rather than inferring it
        from the names here.
        """
        self._sampler = SampleDiffusion(
            num_timesteps=num_timesteps,
            min_t=0,
            max_t=1,
            sigma_data=sigma_data,
            s_min=s_min,
            s_max=s_max,
            p=p,
            gamma_0=gamma_0,
            gamma_min=gamma_min,
            noise_scale=noise_scale,
            step_scale=step_scale,
            solver="af3",
        )

    @property
    def sampler(self) -> SampleDiffusion:
        """Return the wrapped ``SampleDiffusion`` instance."""
        return self._sampler

    def get_noise_schedule(
        self, device: Optional[torch.device] = None
    ) -> torch.Tensor:
        """Construct the EDM noise schedule.

        Args:
            device: Device handed to the wrapped sampler. ``None`` (the
                default) selects the CPU.

        Returns:
            Whatever ``SampleDiffusion._construct_inference_noise_schedule``
            returns for that device.
        """
        if device is None:
            device = torch.device("cpu")
        return self._sampler._construct_inference_noise_schedule(device=device)

    def _inflated_noise_level(
        self,
        c_t_minus_1: torch.Tensor,
        c_t: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``t_hat = c_t_minus_1 * (1 + gamma)`` for one schedule step.

        ``gamma`` is the sampler's ``gamma_0`` when ``c_t`` exceeds the
        sampler's ``gamma_min`` and ``0.0`` otherwise. Shared by
        :meth:`add_noise` and :meth:`step` so both always use the same value
        for a given ``(c_t_minus_1, c_t)`` pair.

        ``c_t`` is evaluated in a plain Python conditional, so it must be a
        float or a single-element tensor.
        """
        gamma = self._sampler.gamma_0 if c_t > self._sampler.gamma_min else 0.0
        return c_t_minus_1 * (gamma + 1.0)

    def add_noise(
        self,
        xyz: torch.Tensor,
        c_t_minus_1: torch.Tensor,
        c_t: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Inject stochastic noise before the model call.

        Args:
            xyz: Tensor to perturb; the noise has the same shape, dtype and
                device (it is drawn with ``torch.randn_like``).
            c_t_minus_1: Current schedule value.
            c_t: Next schedule value; only used to decide whether ``gamma``
                is ``gamma_0`` or ``0.0``.

        Returns:
            ``(xyz + epsilon, t_hat)`` where ``epsilon`` is standard normal
            noise scaled by
            ``noise_scale * sqrt(t_hat**2 - c_t_minus_1**2)``.
        """
        t_hat = self._inflated_noise_level(c_t_minus_1, c_t)
        noise_std = self._sampler.noise_scale * torch.sqrt(
            t_hat**2 - c_t_minus_1**2
        )
        # The random draw happens even when gamma is 0 (noise_std is then 0),
        # so the RNG stream advances identically on every call.
        epsilon = noise_std * torch.randn_like(xyz)
        return xyz + epsilon, t_hat

    def step(
        self,
        xyz_pred: torch.Tensor,
        xyz_noisy: torch.Tensor,
        c_t_minus_1: torch.Tensor,
        c_t: torch.Tensor,
    ) -> torch.Tensor:
        """Perform one Euler denoising step.

        ``t_hat`` is recomputed from ``(c_t_minus_1, c_t)`` with the same rule
        :meth:`add_noise` uses, so callers do not need to pass it back in.

        Args:
            xyz_pred: Model prediction for the denoised tensor.
            xyz_noisy: Noised tensor that was fed to the model.
            c_t_minus_1: Current schedule value.
            c_t: Next schedule value.

        Returns:
            ``xyz_noisy + step_scale * (c_t - t_hat) * (xyz_noisy - xyz_pred) / t_hat``.
        """
        t_hat = self._inflated_noise_level(c_t_minus_1, c_t)
        delta = (xyz_noisy - xyz_pred) / t_hat
        d_t = c_t - t_hat
        return xyz_noisy + self._sampler.step_scale * d_t * delta