Fix SDE inference: batch size mismatch with CFG and incorrect timestep schedule

#1

Problem

The SDE inference path in generate_audio has two bugs that make it unusable with the SFT models:

Bug 1: Batch size mismatch with CFG (SFT model only)

In acestep-v15-sft, the SDE path uses t_curr_tensor which has shape [x.shape[0]] (CFG-doubled, e.g. 2), but get_x0_from_noise operates on xt which has shape [bsz] (actual batch size, e.g. 1). The broadcasting of [2, 1, 1] against [1, seq, feat] causes xt to become batch=2. On the next iteration, torch.cat([xt, xt], dim=0) produces batch=4, while context_latents remains batch=2, causing a RuntimeError.

Note: this was already fixed in acestep-v15-base by using t_curr_bsz, but acestep-v15-sft still uses t_curr_tensor.
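The shape growth can be reproduced in isolation. This is a minimal sketch (shapes chosen for illustration, not taken from the actual model) showing how broadcasting a [2]-shaped timestep tensor against a batch-1 latent silently doubles the batch, so the next CFG duplication produces batch 4:

```python
import torch

# Hypothetical shapes mirroring the bug: t_curr_tensor is CFG-doubled
# (batch 2) while xt has the true batch size (1).
bsz, seq, feat = 1, 4, 8
xt = torch.zeros(bsz, seq, feat)   # [1, 4, 8]
t_curr_tensor = torch.ones(2)      # [2] -- CFG-doubled

# A get_x0_from_noise-style broadcast: [2, 1, 1] against [1, 4, 8]
xt = xt - t_curr_tensor.view(-1, 1, 1) * torch.zeros_like(xt)
print(xt.shape)  # torch.Size([2, 4, 8]) -- batch silently doubled

# Next iteration's CFG duplication now yields batch 4, mismatching
# context_latents, which is still batch 2.
print(torch.cat([xt, xt], dim=0).shape)  # torch.Size([4, 4, 8])
```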

Bug 2: Linear timestep ignores shift schedule

next_timestep is computed as:

next_timestep = 1.0 - (float(step_idx + 1) / infer_steps)

This is a linear schedule that ignores the shift transformation applied earlier:

t = shift * t / (1 + (shift - 1) * t)

When shift != 1.0, the noise level in renoise doesn't match the shifted schedule the model expects, causing accumulated error and noise-only output.

t_prev is already available from the loop iterator and includes the shift transformation.
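The mismatch is easy to see numerically. A small sketch, assuming the shift transform quoted above (shift value and step count chosen for illustration only):

```python
# Shift transform from the description: t = shift * t / (1 + (shift - 1) * t)
shift = 3.0       # assumed example value
infer_steps = 4   # assumed example value

def apply_shift(t: float) -> float:
    return shift * t / (1 + (shift - 1) * t)

linear = [1.0 - (i + 1) / infer_steps for i in range(infer_steps)]
shifted = [apply_shift(t) for t in linear]

for lin, sh in zip(linear, shifted):
    print(f"linear next_timestep={lin:.3f}  shifted t_prev={sh:.3f}")
# linear next_timestep=0.750  shifted t_prev=0.900
# linear next_timestep=0.500  shifted t_prev=0.750
# linear next_timestep=0.250  shifted t_prev=0.500
# linear next_timestep=0.000  shifted t_prev=0.000
```

With shift != 1.0 the two disagree at every intermediate step, so renoise injects the wrong noise level each iteration and the error compounds.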

Fix

Current:

if infer_method == "sde":
    # Stochastic Differential Equation: predict clean, then re-add noise
    pred_clean = self.get_x0_from_noise(xt, vt, t_curr_tensor)
    next_timestep = 1.0 - (float(step_idx + 1) / infer_steps)
    xt = self.renoise(pred_clean, next_timestep)

Fixed:

if infer_method == "sde":
    # Stochastic Differential Equation: predict clean, then re-add noise
    t_curr_bsz = t_curr * torch.ones((bsz,), device=device, dtype=dtype)
    pred_clean = self.get_x0_from_noise(xt, vt, t_curr_bsz)
    xt = self.renoise(pred_clean, t_prev)
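For context, here is a hedged sketch of what the two helpers plausibly compute under a rectified-flow parameterization (xt = (1 - t) * x0 + t * noise, with the model predicting velocity vt = noise - x0). These signatures and bodies are assumptions for illustration, not the actual ACE-Step implementation:

```python
import torch

def get_x0_from_noise(xt, vt, t):
    # ASSUMED rectified-flow inversion: if xt = (1-t)*x0 + t*noise and
    # vt = noise - x0, then xt - t*vt = x0 exactly.
    t = t.view(-1, 1, 1)
    return xt - t * vt

def renoise(x0, t):
    # ASSUMED: re-interpolate toward fresh noise at level t.
    noise = torch.randn_like(x0)
    return (1 - t) * x0 + t * noise
```

Under this parameterization, passing the shifted t_prev to renoise keeps the injected noise level consistent with the timestep the model will be conditioned on at the next step.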

Notes

  • When shift == 1.0, t_prev equals the linear calculation, so behavior is identical.
  • When shift != 1.0, the schedule is now correct.
ACE-Step org

Please resolve the conflicts. It seems that I already fixed it.

Hi @ChuxiJ, conflicts resolved! I saw that in the meantime you already fixed Bug 1. My remaining change fixes Bug 2 (the timestep schedule ignoring the shift transformation) by replacing the linear next_timestep with t_prev. Ready to merge.

Thank you,
Fabio
