Update src/loss.py
Browse files- src/loss.py +1 -1
src/loss.py
CHANGED
|
@@ -22,7 +22,7 @@ class LossSchedulerModel(torch.nn.Module):
|
|
| 22 |
"""
|
| 23 |
batch_size, num_channels, height, width = xT.shape
|
| 24 |
# num_channels = 1
|
| 25 |
-
num_prev_steps = e_prev
|
| 26 |
|
| 27 |
# Ensure inputs match expected dimensions
|
| 28 |
assert timestep - num_prev_steps + 1 == 0, "Mismatch between timestep and e_prev length"
|
|
|
|
| 22 |
"""
|
| 23 |
batch_size, num_channels, height, width = xT.shape
|
| 24 |
# num_channels = 1
|
| 25 |
+
num_prev_steps = len(e_prev)
|
| 26 |
|
| 27 |
# Ensure inputs match expected dimensions
|
| 28 |
assert timestep - num_prev_steps + 1 == 0, "Mismatch between timestep and e_prev length"
|