# Source: uploaded by Eji-Sensei14 ("Upload folder using huggingface_hub"),
# commit c6535db (verified).
import torch
def get_eps_step_official(sigma, sigma_next, eta=0.0):
    """Split an EPS ancestral step into its noise and deterministic parts.

    Uses the same formula as k-diffusion's ancestral step::

        sigma_up   = min(sigma_next,
                         eta * sqrt(sigma_next**2 * (sigma**2 - sigma_next**2) / sigma**2))
        sigma_down = sqrt(sigma_next**2 - sigma_up**2)

    Args:
        sigma: Current noise level. A tensor (typically 0-dim) or a Python
            number; numbers are converted to the default float dtype.
        sigma_next: Target noise level. A tensor or a Python number; numbers
            are converted to ``sigma``'s dtype and device.
        eta: Ancestral noise scale. ``None`` or a value ``<= 0`` disables
            the stochastic part.

    Returns:
        ``(sigma_up, sigma_down)`` as tensors in ``sigma``'s dtype (and
        device). If eta is ``None`` or ``<= 0``, ``sigma_up`` is zero and
        ``sigma_down`` equals ``sigma_next``.
    """
    if not torch.is_tensor(sigma):
        sigma = torch.as_tensor(sigma, dtype=torch.get_default_dtype())
    if not torch.is_tensor(sigma_next):
        sigma_next = torch.as_tensor(sigma_next, dtype=sigma.dtype, device=sigma.device)

    if eta is None or float(eta) <= 0.0:
        # Deterministic step: no fresh noise. Cast sigma_next so both outputs
        # share sigma's dtype/device, matching the eta > 0 branch below.
        # (.to() returns the same tensor when nothing needs to change.)
        return torch.zeros_like(sigma), sigma_next.to(device=sigma.device, dtype=sigma.dtype)

    # Compute in float64 for precision, except on MPS, which has no float64
    # support (casting an MPS tensor to float64 raises).
    work_dtype = torch.float32 if sigma.device.type == "mps" else torch.float64
    s = sigma.to(work_dtype)
    sn = sigma_next.to(work_dtype)
    # Clamp the denominator so sigma == 0 does not divide by zero.
    s2 = torch.clamp(s**2, min=1e-12)
    sn2 = sn**2
    # Clamp at 0 so sqrt never sees a negative value (e.g. sigma_next > sigma).
    num = torch.clamp(sn2 * (s2 - sn2) / s2, min=0.0)
    # Noise added back can never exceed the target noise level.
    su = torch.minimum(torch.sqrt(num) * float(eta), sn)
    sd = torch.sqrt(torch.clamp(sn2 - su**2, min=0.0))
    return su.to(sigma.dtype), sd.to(sigma.dtype)