manbeast3b commited on
Commit
9612d05
·
verified ·
1 Parent(s): 38011a8

Update src/loss.py

Browse files
Files changed (1) hide show
  1. src/loss.py +1 -1
src/loss.py CHANGED
@@ -22,7 +22,7 @@ class LossSchedulerModel(torch.nn.Module):
22
  """
23
  batch_size, num_channels, height, width = xT.shape
24
  # num_channels = 1
25
- num_prev_steps = e_prev.shape[0]
26
 
27
  # Ensure inputs match expected dimensions
28
  assert timestep - num_prev_steps + 1 == 0, "Mismatch between timestep and e_prev length"
 
22
  """
23
  batch_size, num_channels, height, width = xT.shape
24
  # num_channels = 1
25
+ num_prev_steps = len(e_prev)
26
 
27
  # Ensure inputs match expected dimensions
28
  assert timestep - num_prev_steps + 1 == 0, "Mismatch between timestep and e_prev length"