Fix SDE inference: batch size mismatch with CFG and incorrect timestep schedule

#1
Files changed (1) hide show
  1. modeling_acestep_v15_base.py +1 -2
modeling_acestep_v15_base.py CHANGED
@@ -1949,8 +1949,7 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
1949
  # Stochastic Differential Equation: predict clean, then re-add noise
1950
  t_curr_bsz = t_curr * torch.ones((bsz,), device=device, dtype=dtype)
1951
  pred_clean = self.get_x0_from_noise(xt, vt, t_curr_bsz)
1952
- next_timestep = 1.0 - (float(step_idx + 1) / infer_steps)
1953
- xt = self.renoise(pred_clean, next_timestep)
1954
  elif infer_method == "ode":
1955
  # Ordinary Differential Equation: Euler method
1956
  # dx/dt = -v, so x_{t+1} = x_t - v_t * dt
 
1949
  # Stochastic Differential Equation: predict clean, then re-add noise
1950
  t_curr_bsz = t_curr * torch.ones((bsz,), device=device, dtype=dtype)
1951
  pred_clean = self.get_x0_from_noise(xt, vt, t_curr_bsz)
1952
+ xt = self.renoise(pred_clean, t_prev)
 
1953
  elif infer_method == "ode":
1954
  # Ordinary Differential Equation: Euler method
1955
  # dx/dt = -v, so x_{t+1} = x_t - v_t * dt