Update src/loss.py
Browse files- src/loss.py +2 -0
src/loss.py
CHANGED
|
@@ -104,6 +104,8 @@ class LossScheduler:
|
|
| 104 |
L=torch.nn.functional.mse_loss(C, orig_output)
|
| 105 |
L.backward()
|
| 106 |
A.optimizer.step(); print("optimized")
|
|
|
|
|
|
|
| 107 |
if B+1==len(A.timesteps):
|
| 108 |
A.xT=_A;A.e_prev=[];
|
| 109 |
A.t_prev=-1
|
|
|
|
| 104 |
L=torch.nn.functional.mse_loss(C, orig_output)
|
| 105 |
L.backward()
|
| 106 |
A.optimizer.step(); print("optimized")
|
| 107 |
+
else:
|
| 108 |
+
C=A.model(B,A.xT,A.e_prev)
|
| 109 |
if B+1==len(A.timesteps):
|
| 110 |
A.xT=_A;A.e_prev=[];
|
| 111 |
A.t_prev=-1
|