File size: 1,211 Bytes
1d3821a
 
 
3ef487b
78742f1
 
1d3821a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from diffusers import DDPMScheduler
import torch, math

# Doesn't work like the multi-model scheduler, but it's a fair substitute, albeit missing some of the noise functions (e.g. min_snr and similar) for this one.
# It's highly basic and doesn't fit the correct paradigm for training SD 1.5 from scratch, which is why I didn't bother posting it until now.
# There are more robust iterations that will need adapting to create the necessary multi-model scheduler.
class SurgeScheduler(DDPMScheduler):
    """``DDPMScheduler`` defaulting to v-prediction, with noise offset and Min-SNR-gamma weights.

    Args:
        *args: Positional arguments forwarded to ``DDPMScheduler``.
        snr_gamma: Clamp value (gamma) used by ``loss_weight``.
        noise_offset: Scale of the offset noise mixed into ``add_noise``; a falsy
            value (``0`` / ``0.0`` / ``None``) disables it.
        **kwargs: Keyword arguments forwarded to ``DDPMScheduler``.
            ``prediction_type`` defaults to ``"v_prediction"`` but may be overridden.

    NOTE(review): ``snr_gamma`` and ``noise_offset`` are stored as plain attributes
    and are not passed to ``DDPMScheduler.__init__``, so presumably they are not part
    of the saved diffusers config — confirm before relying on
    ``save_config`` / ``from_config`` to round-trip them.
    """

    def __init__(self, *args, snr_gamma: float = 5.0, noise_offset: float = 0.1, **kwargs):
        # Use setdefault rather than an explicit prediction_type= keyword. Passing
        # both an explicit prediction_type= and a caller-supplied
        # kwargs["prediction_type"] raised "got multiple values for keyword argument".
        kwargs.setdefault("prediction_type", "v_prediction")
        super().__init__(*args, **kwargs)
        self.snr_gamma = snr_gamma
        self.noise_offset = noise_offset

    @staticmethod
    def _offset_noise_like(noise):
        """Return standard-normal offset noise that broadcasts against ``noise``.

        One value is drawn per entry of the first two dims (presumably
        (batch, channel) for latent tensors — confirm with callers). That value is
        constant across all remaining dims, which is what makes it an *offset*
        rather than extra per-element noise. Tensors with fewer than two dims fall
        back to element-wise noise.
        """
        if noise.ndim < 2:
            return torch.randn_like(noise)
        shape = (*noise.shape[:2], *([1] * (noise.ndim - 2)))
        return torch.randn(shape, device=noise.device, dtype=noise.dtype)

    def add_noise(self, x0, noise, timesteps):
        """Diffuse ``x0`` to ``timesteps`` with ``noise`` plus an optional noise offset.

        When ``self.noise_offset`` is truthy, ``noise_offset`` times a
        per-(dim0, dim1) constant Gaussian offset is added to ``noise``. The result
        is handed to ``DDPMScheduler.add_noise``.

        NOTE(review): the offset is drawn here and is not returned. A training
        target computed by the caller from its own ``noise`` (for example via
        ``get_velocity``) will therefore not include it. Also, if this scheduler is
        used for inference paths that call ``add_noise``, the random offset applies
        there too. Confirm both are intended.
        """
        if self.noise_offset:
            noise = noise + self.noise_offset * self._offset_noise_like(noise)
        return super().add_noise(x0, noise, timesteps)

    def loss_weight(self, timesteps):
        """Return Min-SNR-gamma loss weights for ``timesteps``.

        With ``snr = alpha_bar / (1 - alpha_bar)`` and ``c = min(snr, snr_gamma)``,
        the weight depends on the configured ``prediction_type``:

        - ``"epsilon"``: ``c / snr``
        - ``"v_prediction"``: ``c / (snr + 1)``
        - ``"sample"``: ``c``

        Args:
            timesteps: Integer tensor of timestep indices into ``alphas_cumprod``.

        Returns:
            Tensor of weights on ``timesteps.device``, one per timestep.

        Raises:
            ValueError: If ``prediction_type`` is not one of the three above.
        """
        alphas = self.alphas_cumprod.to(timesteps.device)[timesteps]
        snr = alphas / (1.0 - alphas)
        clipped = torch.clamp(snr, max=self.snr_gamma)
        prediction_type = self.config.prediction_type
        if prediction_type == "epsilon":
            return clipped / snr
        if prediction_type == "v_prediction":
            # The v-target mixes noise and signal. The Min-SNR weight must
            # therefore be divided by (snr + 1), not snr; the previous formula was
            # the epsilon one and over-weighted high-noise steps.
            return clipped / (snr + 1.0)
        if prediction_type == "sample":
            return clipped
        raise ValueError(f"Unsupported prediction_type for loss_weight: {prediction_type!r}")