File size: 252 Bytes
b6b6742
 
 
 
 
 
 
1
2
3
4
5
6
7
8
def denoise_add_noise(x, t, pred_noise, z=None):
    """Remove the predicted noise from ``x`` at index ``t`` and add scaled noise.

    Relies on ``betas``, ``alphas`` and ``alphas_hat`` being defined at module
    level and indexable by ``t``; they are not defined inside this function.
    NOTE(review): the formula has the shape of a DDPM reverse (sampling)
    step -- confirm against where the schedules are built.

    Args:
        x: Current sample.
        t: Index into the schedule tensors.
        pred_noise: Noise estimate to subtract from ``x``; presumably the
            output of a noise-prediction model -- verify against the caller.
        z: Noise to add back. When ``None``, it is drawn with
            ``torch.randn_like(x)``. Pass an explicit value (e.g. zeros) to
            make the call deterministic.

    Returns:
        The rescaled, noise-subtracted sample plus ``sqrt(betas)[t] * z``.
    """
    # Sample fresh Gaussian noise only when the caller supplied none.
    extra = torch.randn_like(x) if z is None else z
    scaled_noise = betas.sqrt()[t] * extra

    alpha_t = alphas[t]
    # Weight applied to the predicted noise before subtracting it from x.
    pred_weight = (1 - alpha_t) / (1 - alphas_hat[t]).sqrt()
    denoised = (x - pred_noise * pred_weight) / alpha_t.sqrt()

    return denoised + scaled_noise