from diffusers import DDPMScheduler
import math
import torch

# A lightweight DDPM-based scheduler: v-prediction by default, min-SNR loss
# weighting and offset noise. It is a fair substitute for the multi-model
# scheduler but does not work the same way, and it lacks several noise
# utilities that one has.
# It is too basic to train SD 1.5 correctly from scratch, which is why it was
# not published earlier; more robust iterations still need to be fitted
# together to build the full multi-model scheduler.


class SurgeScheduler(DDPMScheduler):
    """DDPM scheduler defaulting to v-prediction, with offset noise and min-SNR weights.

    Args:
        *args: Positional arguments forwarded unchanged to ``DDPMScheduler``.
        snr_gamma: Clipping value ``gamma`` used by ``loss_weight``.
        noise_offset: Scale of the random offset added to the noise by
            ``apply_noise_offset`` / ``add_noise``. A falsy value (e.g. ``0``)
            disables it.
        **kwargs: Keyword arguments forwarded to ``DDPMScheduler``;
            ``prediction_type`` defaults to ``"v_prediction"``.
    """

    # NOTE(review): this __init__ is not decorated with register_to_config and
    # hides DDPMScheduler's parameters behind *args/**kwargs. As a result
    # snr_gamma / noise_offset are not stored in the scheduler config, and
    # from_pretrained/from_config may not forward the saved scheduler
    # settings to DDPMScheduler. Confirm against the installed diffusers
    # version before relying on config round-trips.
    def __init__(
        self,
        *args,
        snr_gamma: float = 5.0,
        noise_offset: float = 0.1,
        **kwargs,
    ):
        # Default to v-prediction but allow an explicit override; previously
        # passing prediction_type raised "got multiple values for keyword".
        kwargs.setdefault("prediction_type", "v_prediction")
        super().__init__(*args, **kwargs)
        self.snr_gamma = snr_gamma
        self.noise_offset = noise_offset

    def apply_noise_offset(self, noise: torch.Tensor) -> torch.Tensor:
        """Return ``noise`` plus a random offset shared across trailing dims.

        One Gaussian value is drawn per index of the first two dimensions and
        broadcast over all remaining dimensions. This is the "offset noise"
        technique; it assumes a ``(batch, channels, ...)`` layout, so
        TODO confirm for non-image inputs. Returns ``noise`` unchanged when
        ``noise_offset`` is falsy.
        """
        if not self.noise_offset:
            return noise
        # Keep the leading two dims and collapse the rest to size 1, so the
        # offset shifts each channel's mean instead of adding i.i.d. noise.
        offset_shape = tuple(noise.shape[:2]) + (1,) * max(noise.ndim - 2, 0)
        offset = torch.randn(offset_shape, device=noise.device, dtype=noise.dtype)
        return noise + self.noise_offset * offset

    def add_noise(self, x0, noise, timesteps, *, apply_offset: bool = True):
        """Noise ``x0`` to ``timesteps`` via ``DDPMScheduler.add_noise``.

        Args:
            x0: Clean samples.
            noise: Noise tensor forwarded to ``DDPMScheduler.add_noise``.
            timesteps: Timesteps forwarded to ``DDPMScheduler.add_noise``.
            apply_offset: When True (default), ``noise`` is first passed
                through ``apply_noise_offset``. That offset is not returned,
                so a training target computed from the caller's ``noise``
                would not include it. To keep input and target consistent,
                call ``apply_noise_offset`` yourself, pass
                ``apply_offset=False`` here, and reuse the same noise for the
                target (e.g. ``get_velocity``).
        """
        # NOTE(review): pipelines may also call add_noise at inference time
        # (e.g. img2img), where the offset is applied too unless
        # noise_offset is 0.
        if apply_offset:
            noise = self.apply_noise_offset(noise)
        return super().add_noise(x0, noise, timesteps)

    def loss_weight(self, timesteps: torch.Tensor) -> torch.Tensor:
        """Return per-timestep min-SNR-gamma loss weights.

        With ``snr = alpha_bar / (1 - alpha_bar)`` and ``gamma = snr_gamma``,
        the weight depends on ``self.config.prediction_type``:

        * ``"epsilon"``: ``min(snr, gamma) / snr``
        * ``"v_prediction"``: ``min(snr, gamma) / (snr + 1)``
        * ``"sample"``: ``min(snr, gamma)``

        Raises:
            ValueError: if ``prediction_type`` is none of the above.
        """
        alphas_cumprod = self.alphas_cumprod.to(timesteps.device)[timesteps]
        snr = alphas_cumprod / (1.0 - alphas_cumprod)
        prediction_type = self.config.prediction_type
        if prediction_type == "epsilon":
            # min(gamma / snr, 1) == min(snr, gamma) / snr, but stays finite
            # (== 1) when snr is 0.
            return torch.minimum(self.snr_gamma / snr, torch.ones_like(snr))
        clipped_snr = torch.clamp(snr, max=self.snr_gamma)
        if prediction_type == "v_prediction":
            # The v target mixes noise and signal, so normalize by snr + 1
            # rather than snr.
            return clipped_snr / (snr + 1.0)
        if prediction_type == "sample":
            return clipped_snr
        raise ValueError(
            f"Unsupported prediction_type for min-SNR weighting: {prediction_type!r}"
        )