| import torch | |
def _as_sigma_tensor(value, like=None):
    """Return *value* as a tensor, borrowing dtype/device from *like* if needed."""
    if torch.is_tensor(value):
        return value
    if torch.is_tensor(like):
        return torch.as_tensor(value, dtype=like.dtype, device=like.device)
    # Plain Python/NumPy scalars: use the default float dtype so that integer
    # literals such as ``1`` do not silently produce an integer tensor.
    return torch.as_tensor(value, dtype=torch.get_default_dtype())


def get_eps_step_official(sigma, sigma_next, eta=0.0):
    """Split an ancestral step into its noise and deterministic parts.

    For ``eta > 0`` the values are computed in float64 as::

        sigma_up   = min(sigma_next,
                         eta * sqrt(sigma_next^2 * (sigma^2 - sigma_next^2) / sigma^2))
        sigma_down = sqrt(sigma_next^2 - sigma_up^2)

    with ``sigma^2`` floored at 1e-12 to avoid division by zero and both
    radicands clamped at zero so rounding can never produce NaN.

    Args:
        sigma: Current noise level. A tensor, or a Python/NumPy scalar which
            is converted to a tensor.
        sigma_next: Target noise level. Same accepted types as ``sigma``.
        eta: Noise multiplier. ``None`` or any value ``<= 0`` disables the
            noise part of the step.

    Returns:
        Tuple ``(sigma_up, sigma_down)`` of tensors in ``sigma``'s dtype.
        When ``eta`` is ``None`` or ``<= 0``, ``sigma_up`` is all zeros and
        ``sigma_down`` is ``sigma_next`` (the very same tensor object if it
        already has ``sigma``'s dtype).
    """
    sigma = _as_sigma_tensor(sigma, like=sigma_next)
    sigma_next = _as_sigma_tensor(sigma_next, like=sigma)
    out_dtype = sigma.dtype

    if eta is None or float(eta) <= 0.0:
        # ``Tensor.to`` is a no-op (returns self) when the dtype already
        # matches, so the common same-dtype case is unchanged while the
        # mixed-dtype case now agrees with the eta > 0 branch.
        return torch.zeros_like(sigma), sigma_next.to(out_dtype)

    eta = float(eta)

    # Work in float64: the difference of squares loses precision in half/float.
    sigma_sq = torch.clamp(sigma.to(torch.float64) ** 2, min=1e-12)
    next_f64 = sigma_next.to(torch.float64)
    next_sq = next_f64 ** 2

    up_sq = torch.clamp(next_sq * (sigma_sq - next_sq) / sigma_sq, min=0.0)
    # Cap at sigma_next so that sigma_down's radicand stays non-negative.
    sigma_up = torch.minimum(torch.sqrt(up_sq) * eta, next_f64)
    sigma_down = torch.sqrt(torch.clamp(next_sq - sigma_up ** 2, min=0.0))

    return sigma_up.to(out_dtype), sigma_down.to(out_dtype)