Update modeling_acestep_v15_base.py
Browse filesfix issue: https://github.com/ace-step/ACE-Step-1.5/issues/214
modeling_acestep_v15_base.py
CHANGED
|
@@ -1939,7 +1939,8 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
|
|
| 1939 |
# Update x_t based on inference method
|
| 1940 |
if infer_method == "sde":
|
| 1941 |
# Stochastic Differential Equation: predict clean, then re-add noise
|
| 1942 |
-
|
|
|
|
| 1943 |
next_timestep = 1.0 - (float(step_idx + 1) / infer_steps)
|
| 1944 |
xt = self.renoise(pred_clean, next_timestep)
|
| 1945 |
elif infer_method == "ode":
|
|
|
|
| 1939 |
# Update x_t based on inference method
|
| 1940 |
if infer_method == "sde":
|
| 1941 |
# Stochastic Differential Equation: predict clean, then re-add noise
|
| 1942 |
+
t_curr_bsz = t_curr * torch.ones((bsz,), device=device, dtype=dtype)
|
| 1943 |
+
pred_clean = self.get_x0_from_noise(xt, vt, t_curr_bsz)
|
| 1944 |
next_timestep = 1.0 - (float(step_idx + 1) / infer_steps)
|
| 1945 |
xt = self.renoise(pred_clean, next_timestep)
|
| 1946 |
elif infer_method == "ode":
|