FabioSarracino commited on
Commit
b2e8955
·
verified ·
1 Parent(s): e432212

Fix SDE inference producing noise when shift != 1.0

Browse files

The SDE inference path computes `next_timestep` using a linear formula:
```
next_timestep = 1.0 - (float(step_idx + 1) / infer_steps)
```
This ignores the shift transformation applied to the timestep schedule:
```
t = shift * t / (1 + (shift - 1) * t)
```
When shift != 1.0, the renoise step uses a timestep that doesn't match the shifted schedule the model expects, causing accumulated error and noise-only output.

`t_prev` is already available from the loop iterator and includes the shift transformation. The fix replaces the linear calculation with t_prev:

# Before:
```
next_timestep = 1.0 - (float(step_idx + 1) / infer_steps)
xt = self.renoise(pred_clean, next_timestep)
```

# After:
```
xt = self.renoise(pred_clean, t_prev)
```
When shift == 1.0, the behavior is identical to the current code. When shift != 1.0, the schedule is now correct.

Files changed (1) hide show
  1. modeling_acestep_v15_base.py +1 -2
modeling_acestep_v15_base.py CHANGED
@@ -1941,8 +1941,7 @@ class AceStepConditionGenerationModel(AceStepPreTrainedModel):
1941
  # Stochastic Differential Equation: predict clean, then re-add noise
1942
  t_curr_bsz = t_curr * torch.ones((bsz,), device=device, dtype=dtype)
1943
  pred_clean = self.get_x0_from_noise(xt, vt, t_curr_bsz)
1944
- next_timestep = 1.0 - (float(step_idx + 1) / infer_steps)
1945
- xt = self.renoise(pred_clean, next_timestep)
1946
  elif infer_method == "ode":
1947
  # Ordinary Differential Equation: Euler method
1948
  # dx/dt = -v, so x_{t+1} = x_t - v_t * dt
 
1941
  # Stochastic Differential Equation: predict clean, then re-add noise
1942
  t_curr_bsz = t_curr * torch.ones((bsz,), device=device, dtype=dtype)
1943
  pred_clean = self.get_x0_from_noise(xt, vt, t_curr_bsz)
1944
+ xt = self.renoise(pred_clean, t_prev)
 
1945
  elif infer_method == "ode":
1946
  # Ordinary Differential Equation: Euler method
1947
  # dx/dt = -v, so x_{t+1} = x_t - v_t * dt