Update src/loss.py
Browse files — src/loss.py (+2 −2)
src/loss.py
CHANGED
|
@@ -96,7 +96,7 @@ class LossScheduler:
|
|
| 96 |
if timestep==A.timesteps[0]: print("resetting"); image_count+=1;count=0;
|
| 97 |
if A.t_prev==-1:A.xT=sample
|
| 98 |
A.e_prev.append(model_output)
|
| 99 |
-
if timestep==
|
| 100 |
with torch.enable_grad():
|
| 101 |
orig_output = torch.load(f"/home/mbhat/edge-maxxing/miner/miner/latents/latent_oig_{image_count}_20.pth")
|
| 102 |
A.optimizer.zero_grad()
|
|
@@ -107,7 +107,7 @@ class LossScheduler:
|
|
| 107 |
if B+1==len(A.timesteps):
|
| 108 |
A.xT=_A;A.e_prev=[];
|
| 109 |
A.t_prev=-1
|
| 110 |
-
A.save("/home/mbhat/
|
| 111 |
else:A.t_prev=B
|
| 112 |
return C,
|
| 113 |
def step_other(self,model_output,timestep,sample,*D,**E):
|
|
|
|
| 96 |
if timestep==A.timesteps[0]: print("resetting"); image_count+=1;count=0;
|
| 97 |
if A.t_prev==-1:A.xT=sample
|
| 98 |
A.e_prev.append(model_output)
|
| 99 |
+
if timestep==A.timesteps[-1]:
|
| 100 |
with torch.enable_grad():
|
| 101 |
orig_output = torch.load(f"/home/mbhat/edge-maxxing/miner/miner/latents/latent_oig_{image_count}_20.pth")
|
| 102 |
A.optimizer.zero_grad()
|
|
|
|
| 107 |
if B+1==len(A.timesteps):
|
| 108 |
A.xT=_A;A.e_prev=[];
|
| 109 |
A.t_prev=-1
|
| 110 |
+
A.save(f"/home/mbhat/weights/weights_latest_{int(timestep)}.pth")
|
| 111 |
else:A.t_prev=B
|
| 112 |
return C,
|
| 113 |
def step_other(self,model_output,timestep,sample,*D,**E):
|