manbeast3b commited on
Commit
ea3e51b
·
verified ·
1 Parent(s): cdc1358

Update src/loss.py

Browse files
Files changed (1) hide show
  1. src/loss.py +2 -0
src/loss.py CHANGED
@@ -104,6 +104,8 @@ class LossScheduler:
104
  L=torch.nn.functional.mse_loss(C, orig_output)
105
  L.backward()
106
  A.optimizer.step(); print("optimized")
 
 
107
  if B+1==len(A.timesteps):
108
  A.xT=_A;A.e_prev=[];
109
  A.t_prev=-1
 
104
  L=torch.nn.functional.mse_loss(C, orig_output)
105
  L.backward()
106
  A.optimizer.step(); print("optimized")
107
+ else:
108
+ C=A.model(B,A.xT,A.e_prev)
109
  if B+1==len(A.timesteps):
110
  A.xT=_A;A.e_prev=[];
111
  A.t_prev=-1