# SD15-Surge-V1 / 1model_cascade.py
# Author: AbstractPhil
# Last commit: 3ef487b "Update 1model_cascade.py" (verified)
from diffusers import DDPMScheduler
import torch, math
# Doesn't work like the multi-model scheduler, but it's a fair substitute - albeit missing some of the noise functions that one has (e.g. min_snr and similar).
# It's highly basic and doesn't fit the paradigm needed to train SD1.5 correctly from scratch, which is why I didn't bother posting it until now.
# There are more robust iterations that still need fitting to create the necessary multi-model scheduler.
class SurgeScheduler(DDPMScheduler):
    """DDPM scheduler fixed to v-prediction, with offset noise and Min-SNR weights.

    Args:
        *args, **kwargs: Forwarded to ``DDPMScheduler.__init__``. Any
            ``prediction_type`` given here is replaced by ``"v_prediction"``
            (a warning is emitted when it differed).
        snr_gamma: Min-SNR gamma; the per-timestep SNR is clipped to at most
            this value in :meth:`loss_weight`.
        noise_offset: Scale of the per-(sample, channel) offset that
            :meth:`add_noise` adds to the noise. A falsy value (e.g. ``0``)
            disables it.
    """

    def __init__(self, *args, snr_gamma: float = 5.0, noise_offset: float = 0.1, **kwargs):
        import warnings

        # Hard-coding prediction_type next to **kwargs used to raise
        # "got multiple values for keyword argument" whenever a caller (or a
        # config dict) supplied it; pop it and force v-prediction explicitly.
        requested = kwargs.pop("prediction_type", "v_prediction")
        if requested != "v_prediction":
            warnings.warn(
                "SurgeScheduler always uses prediction_type='v_prediction'; "
                f"ignoring prediction_type={requested!r}.",
                stacklevel=2,
            )
        super().__init__(*args, prediction_type="v_prediction", **kwargs)
        self.snr_gamma = snr_gamma
        self.noise_offset = noise_offset

    def add_noise(self, x0, noise, timesteps):
        """Diffuse ``x0`` to ``timesteps`` using ``noise`` plus an optional offset.

        When ``self.noise_offset`` is truthy, one Gaussian value per
        (sample, channel) is drawn, broadcast over every remaining dimension,
        scaled by ``self.noise_offset`` and added to ``noise``. This shifts each
        channel's mean. Drawing per element (``randn_like``) would only
        inflate the variance.

        NOTE(review): the offset is applied inside this method and is not
        returned. A training loop that builds its v-prediction target from
        the ``noise`` it passed in (e.g. ``get_velocity(x0, noise, t)``) gets a
        target that excludes the offset. Confirm callers account for this.
        """
        if self.noise_offset:
            # One scalar per leading (batch, channel) pair; 1s elsewhere so it
            # broadcasts across the spatial dimensions.
            offset_shape = tuple(noise.shape[:2]) + (1,) * max(noise.ndim - 2, 0)
            offset = torch.randn(offset_shape, device=noise.device, dtype=noise.dtype)
            noise = noise + self.noise_offset * offset
        return super().add_noise(x0, noise, timesteps)

    def loss_weight(self, timesteps):
        """Return Min-SNR-gamma loss weights for ``timesteps``.

        With ``snr = alpha_bar / (1 - alpha_bar)``, the clipped SNR
        ``min(snr, snr_gamma)`` is converted to a weight for the model's
        target: divided by ``snr + 1`` for v-prediction (the mode this
        scheduler forces), expressed as ``min(snr_gamma / snr, 1)`` for
        epsilon-prediction, and used unchanged for sample prediction.
        The result lives on ``timesteps.device``.
        """
        alphas = self.alphas_cumprod.to(timesteps.device)[timesteps]
        snr = alphas / (1.0 - alphas)
        prediction_type = self.config.prediction_type
        if prediction_type == "v_prediction":
            # min(snr, gamma) / (snr + 1); stays finite even where snr == 0.
            return torch.clamp(snr, max=self.snr_gamma) / (snr + 1.0)
        if prediction_type == "epsilon":
            return torch.minimum(self.snr_gamma / snr, torch.ones_like(snr))
        return torch.clamp(snr, max=self.snr_gamma)