ChuxiJ commited on
Commit
e432212
·
verified ·
1 Parent(s): 76bb727

Update modeling_acestep_v15_base.py

Browse files

fix issue: https://github.com/ace-step/ACE-Step-1.5/issues/214

Files changed (1) hide show
  1. modeling_acestep_v15_base.py +2 -1
modeling_acestep_v15_base.py CHANGED
@@ -1939,7 +1939,8 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
1939
  # Update x_t based on inference method
1940
  if infer_method == "sde":
1941
  # Stochastic Differential Equation: predict clean, then re-add noise
1942
- pred_clean = self.get_x0_from_noise(xt, vt, t_curr_tensor)
 
1943
  next_timestep = 1.0 - (float(step_idx + 1) / infer_steps)
1944
  xt = self.renoise(pred_clean, next_timestep)
1945
  elif infer_method == "ode":
 
1939
  # Update x_t based on inference method
1940
  if infer_method == "sde":
1941
  # Stochastic Differential Equation: predict clean, then re-add noise
1942
+ t_curr_bsz = t_curr * torch.ones((bsz,), device=device, dtype=dtype)
1943
+ pred_clean = self.get_x0_from_noise(xt, vt, t_curr_bsz)
1944
  next_timestep = 1.0 - (float(step_idx + 1) / infer_steps)
1945
  xt = self.renoise(pred_clean, next_timestep)
1946
  elif infer_method == "ode":