Fix SDE inference: batch size mismatch with CFG and incorrect timestep schedule
#1
by
FabioSarracino
- opened
modeling_acestep_v15_base.py
CHANGED
|
@@ -1949,8 +1949,7 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
|
|
| 1949 |
# Stochastic Differential Equation: predict clean, then re-add noise
|
| 1950 |
t_curr_bsz = t_curr * torch.ones((bsz,), device=device, dtype=dtype)
|
| 1951 |
pred_clean = self.get_x0_from_noise(xt, vt, t_curr_bsz)
|
| 1952 |
-
|
| 1953 |
-
xt = self.renoise(pred_clean, next_timestep)
|
| 1954 |
elif infer_method == "ode":
|
| 1955 |
# Ordinary Differential Equation: Euler method
|
| 1956 |
# dx/dt = -v, so x_{t+1} = x_t - v_t * dt
|
|
|
|
| 1949 |
# Stochastic Differential Equation: predict clean, then re-add noise
|
| 1950 |
t_curr_bsz = t_curr * torch.ones((bsz,), device=device, dtype=dtype)
|
| 1951 |
pred_clean = self.get_x0_from_noise(xt, vt, t_curr_bsz)
|
| 1952 |
+
xt = self.renoise(pred_clean, t_prev)
|
|
|
|
| 1953 |
elif infer_method == "ode":
|
| 1954 |
# Ordinary Differential Equation: Euler method
|
| 1955 |
# dx/dt = -v, so x_{t+1} = x_t - v_t * dt
|