|
|
from diffusers import DDPMScheduler |
|
|
import torch, math |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SurgeScheduler(DDPMScheduler):
    """DDPM scheduler defaulting to v-prediction, with offset noise and Min-SNR loss weights.

    Args:
        *args: Positional arguments forwarded unchanged to ``DDPMScheduler``.
        snr_gamma: Clamp value gamma used by :meth:`loss_weight` (Min-SNR-gamma).
        noise_offset: Scale of the per-(sample, channel) constant offset that
            :meth:`add_noise` adds to the noise. A falsy value (``0``/``None``)
            disables the offset.
        **kwargs: Keyword arguments forwarded to ``DDPMScheduler``.
            ``prediction_type`` defaults to ``"v_prediction"`` but may be
            overridden explicitly.
    """

    def __init__(self, *args, snr_gamma: float = 5.0, noise_offset: float = 0.1, **kwargs):
        # Default to v-prediction without forbidding an override: hard-coding the
        # keyword next to **kwargs raised "got multiple values for keyword argument"
        # whenever a caller also passed prediction_type.
        kwargs.setdefault("prediction_type", "v_prediction")
        super().__init__(*args, **kwargs)
        # NOTE(review): these two settings are plain attributes, not registered in
        # the diffusers config, so they are presumably not saved/restored with it.
        self.snr_gamma = snr_gamma
        self.noise_offset = noise_offset

    def add_noise(self, x0, noise, timesteps):
        """Noise ``x0`` at ``timesteps``, first adding offset noise if enabled.

        When ``noise_offset`` is truthy, a random constant per (sample, channel),
        scaled by ``noise_offset``, is added to ``noise`` before delegating to
        ``DDPMScheduler.add_noise``. The offset is applied on every call,
        including any use of this method outside training.

        NOTE(review): the offset is drawn here and not returned, so a training
        target the caller computes from its own ``noise`` (e.g. via
        ``get_velocity(x0, noise, timesteps)``) does not include it. Confirm
        this mismatch is intended.

        Args:
            x0: Clean samples to be noised.
            noise: Noise tensor. Assumed channel-second layout ``(B, C, ...)``;
                TODO confirm against callers.
            timesteps: Timestep indices, forwarded to the parent implementation.

        Returns:
            The result of ``DDPMScheduler.add_noise`` on the (possibly offset) noise.
        """
        if self.noise_offset:
            # One offset value per (sample, channel), broadcast over the remaining
            # dims. Adding i.i.d. per-element noise instead (the previous behaviour)
            # only inflates the noise variance and never shifts the channel means.
            offset_shape = tuple(noise.shape[:2]) + (1,) * max(noise.ndim - 2, 0)
            offset = torch.randn(offset_shape, device=noise.device, dtype=noise.dtype)
            noise = noise + self.noise_offset * offset
        return super().add_noise(x0, noise, timesteps)

    def loss_weight(self, timesteps):
        """Return Min-SNR-gamma loss weights for ``timesteps``.

        ``SNR = alpha_bar / (1 - alpha_bar)`` with ``alpha_bar`` taken from
        ``alphas_cumprod``. The weight depends on the configured prediction
        target:

        * ``"epsilon"``:      ``min(SNR, gamma) / SNR``
        * ``"v_prediction"``: ``min(SNR, gamma) / (SNR + 1)``
        * ``"sample"``:       ``min(SNR, gamma)``

        Args:
            timesteps: Integer tensor of indices into ``alphas_cumprod``.

        Returns:
            A tensor of weights with the same shape as ``timesteps``.

        Raises:
            ValueError: If ``config.prediction_type`` is not one of the above.
        """
        alphas = self.alphas_cumprod.to(timesteps.device)[timesteps]
        snr = alphas / (1.0 - alphas)
        prediction_type = self.config.prediction_type
        if prediction_type == "epsilon":
            # Written as min(gamma / SNR, 1) so that SNR == 0 yields 1, not 0/0 = NaN.
            return torch.minimum(self.snr_gamma / snr, torch.ones_like(snr))
        if prediction_type == "v_prediction":
            # The velocity target carries an extra (SNR + 1) factor relative to x0.
            return torch.clamp(snr, max=self.snr_gamma) / (snr + 1.0)
        if prediction_type == "sample":
            return torch.clamp(snr, max=self.snr_gamma)
        raise ValueError(
            f"Unsupported prediction_type {prediction_type!r} for Min-SNR weighting; "
            "expected 'epsilon', 'v_prediction' or 'sample'."
        )
|
|
|