Fix SDE inference: batch size mismatch with CFG and incorrect timestep schedule
## Problem
The SDE inference path in `generate_audio` has two bugs that make it unusable with the SFT models:
### Bug 1: Batch size mismatch with CFG (SFT model only)
In `acestep-v15-sft`, the SDE path uses `t_curr_tensor`, which has shape `[x.shape[0]]` (CFG-doubled, e.g. 2), but `get_x0_from_noise` operates on `xt`, whose batch dimension is `bsz` (the actual batch size, e.g. 1). Broadcasting `[2, 1, 1]` against `[1, seq, feat]` causes `xt` to become batch 2. On the next iteration, `torch.cat([xt, xt], dim=0)` produces batch 4 while `context_latents` remains batch 2, raising a `RuntimeError`.
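NumPy follows the same broadcasting rules as PyTorch, so a minimal sketch shows how the batch silently doubles. Shapes are toy values, and `xt - t * vt` stands in for whatever reconstruction `get_x0_from_noise` actually performs:

```python
import numpy as np

bsz, seq, feat = 1, 4, 8                 # toy shapes, not from the repo
xt = np.zeros((bsz, seq, feat))          # [1, 4, 8] -- actual batch
vt = np.zeros((bsz, seq, feat))
t_curr_tensor = np.ones((2 * bsz,))      # [2] -- CFG-doubled batch

# A reconstruction like `xt - t * vt` first reshapes t for broadcasting:
t = t_curr_tensor.reshape(-1, 1, 1)      # [2, 1, 1]
pred_clean = xt - t * vt                 # [2, 1, 1] vs [1, 4, 8] -> [2, 4, 8]
print(pred_clean.shape)                  # (2, 4, 8): batch silently doubled

# Next iteration CFG doubles it again, while context stays at batch 2:
xt_next = np.concatenate([pred_clean, pred_clean], axis=0)
print(xt_next.shape)                     # (4, 4, 8) -- shape mismatch follows
```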
Note: this was already fixed in `acestep-v15-base` by using `t_curr_bsz`, but `acestep-v15-sft` still uses `t_curr_tensor`.
### Bug 2: Linear timestep ignores shift schedule
`next_timestep` is computed as:

```python
next_timestep = 1.0 - (float(step_idx + 1) / infer_steps)
```
This linear schedule ignores the shift transformation applied earlier:

```python
t = shift * t / (1 + (shift - 1) * t)
```
When `shift != 1.0`, the noise level in `renoise` doesn't match the shifted schedule the model expects, causing accumulated error and noise-only output.
`t_prev` is already available from the loop iterator and includes the shift transformation.
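The mismatch is easy to see numerically. A sketch with illustrative values (`shift = 3.0`, 10 inference steps — not taken from the repo):

```python
shift, infer_steps = 3.0, 10  # illustrative values

def shifted(t):
    # the shift transformation from the sampling schedule
    return shift * t / (1 + (shift - 1) * t)

step_idx = 0
linear_next = 1.0 - float(step_idx + 1) / infer_steps  # 0.9, the buggy next_timestep
t_prev = shifted(linear_next)                          # ~0.9643, what the model expects
print(linear_next, round(t_prev, 4))

# With shift == 1.0 the transformation is the identity and the two coincide.
```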
## Fix
Current:

```python
if infer_method == "sde":
    # Stochastic Differential Equation: predict clean, then re-add noise
    pred_clean = self.get_x0_from_noise(xt, vt, t_curr_tensor)
    next_timestep = 1.0 - (float(step_idx + 1) / infer_steps)
    xt = self.renoise(pred_clean, next_timestep)
```
Fixed:

```python
if infer_method == "sde":
    # Stochastic Differential Equation: predict clean, then re-add noise
    t_curr_bsz = t_curr * torch.ones((bsz,), device=device, dtype=dtype)
    pred_clean = self.get_x0_from_noise(xt, vt, t_curr_bsz)
    xt = self.renoise(pred_clean, t_prev)
```
## Notes
- When `shift == 1.0`, `t_prev` equals the linear calculation, so behavior is identical.
- When `shift != 1.0`, the schedule is now correct.
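For context, here is a self-contained sketch of what one fixed SDE step does under a standard rectified-flow convention, `x_t = (1 - t) * x0 + t * noise` with `v = noise - x0`. This is an assumption for illustration; the repo's actual `get_x0_from_noise` and `renoise` may differ in detail:

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed convention (a sketch, not the repo's exact code):
#   x_t = (1 - t) * x0 + t * noise,   v = noise - x0
def get_x0_from_noise(xt, vt, t):
    return xt - t * vt                   # x0 = x_t - t * v

def renoise(x0, t):
    noise = rng.standard_normal(x0.shape)
    return (1 - t) * x0 + t * noise      # re-add fresh noise at level t

x0 = rng.standard_normal((1, 4, 8))
noise = rng.standard_normal(x0.shape)
t_curr, t_prev = 0.9643, 0.9286          # illustrative shifted-schedule values

xt = (1 - t_curr) * x0 + t_curr * noise  # state at t_curr
vt = noise - x0                          # exact velocity
x0_hat = get_x0_from_noise(xt, vt, t_curr)
print(np.allclose(x0_hat, x0))           # True -- exact v recovers x0
xt = renoise(x0_hat, t_prev)             # next state, at the *shifted* t_prev
```

The key point of the fix is the last line: the re-noising level must come from the shifted schedule (`t_prev`), not from a linear countdown.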
Please resolve the conflicts. It seems I already fixed it.
Hi @ChuxiJ, conflicts resolved! I saw that in the meantime you already fixed Bug 1. My remaining change fixes Bug 2 (the timestep schedule ignoring the shift transformation) by replacing the linear `next_timestep` with `t_prev`. Ready to merge.
Thank you,
Fabio