manbeast3b commited on
Commit
cdc1358
·
verified ·
1 Parent(s): f34974b

Update src/loss.py

Browse files
Files changed (1) hide show
  1. src/loss.py +2 -2
src/loss.py CHANGED
@@ -96,7 +96,7 @@ class LossScheduler:
96
  if timestep==A.timesteps[0]: print("resetting"); image_count+=1;count=0;
97
  if A.t_prev==-1:A.xT=sample
98
  A.e_prev.append(model_output)
99
- if timestep==B.timesteps[-1]:
100
  with torch.enable_grad():
101
  orig_output = torch.load(f"/home/mbhat/edge-maxxing/miner/miner/latents/latent_oig_{image_count}_20.pth")
102
  A.optimizer.zero_grad()
@@ -107,7 +107,7 @@ class LossScheduler:
107
  if B+1==len(A.timesteps):
108
  A.xT=_A;A.e_prev=[];
109
  A.t_prev=-1
110
- A.save("/home/mbhat/weights_latest.pth")
111
  else:A.t_prev=B
112
  return C,
113
  def step_other(self,model_output,timestep,sample,*D,**E):
 
96
  if timestep==A.timesteps[0]: print("resetting"); image_count+=1;count=0;
97
  if A.t_prev==-1:A.xT=sample
98
  A.e_prev.append(model_output)
99
+ if timestep==A.timesteps[-1]:
100
  with torch.enable_grad():
101
  orig_output = torch.load(f"/home/mbhat/edge-maxxing/miner/miner/latents/latent_oig_{image_count}_20.pth")
102
  A.optimizer.zero_grad()
 
107
  if B+1==len(A.timesteps):
108
  A.xT=_A;A.e_prev=[];
109
  A.t_prev=-1
110
+ A.save(f"/home/mbhat/weights/weights_latest_{int(timestep)}.pth")
111
  else:A.t_prev=B
112
  return C,
113
  def step_other(self,model_output,timestep,sample,*D,**E):