Update src/loss.py
Browse files- src/loss.py +6 -4
src/loss.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
|
|
| 1 |
count=0
|
| 2 |
image_count=0
|
| 3 |
-
_A=None
|
| 4 |
import torch
|
| 5 |
from tqdm import tqdm
|
| 6 |
import numpy as np
|
|
@@ -107,11 +107,13 @@ class SchedulerWrapper:
|
|
| 107 |
if A.loss_scheduler is _A:B=A.scheduler.set_timesteps(D,**C);A.timesteps=A.scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
|
| 108 |
else:B=A.loss_scheduler.set_timesteps(D,**C);A.timesteps=A.loss_scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
|
| 109 |
def step(B,model_output,timestep,sample,**F):
|
|
|
|
|
|
|
| 110 |
D=sample;E=model_output;A=timestep;
|
| 111 |
-
if timestep==
|
| 112 |
if True:
|
| 113 |
-
C=B.scheduler.step(E,A,D,**F);A=A.tolist();
|
| 114 |
-
if timestep
|
| 115 |
if A not in B.catch_x:B.catch_x[A]=[];B.catch_e[A]=[];B.catch_x_[A]=[]
|
| 116 |
B.catch_x[A].append(D.clone().detach().cpu());B.catch_e[A].append(E.clone().detach().cpu());B.catch_x_[A].append(C[0].clone().detach().cpu());return C
|
| 117 |
else:C=B.loss_scheduler.step(E,A,D,**F);return C
|
|
|
|
| 1 |
+
_A=None
|
| 2 |
count=0
|
| 3 |
image_count=0
|
|
|
|
| 4 |
import torch
|
| 5 |
from tqdm import tqdm
|
| 6 |
import numpy as np
|
|
|
|
| 107 |
if A.loss_scheduler is _A:B=A.scheduler.set_timesteps(D,**C);A.timesteps=A.scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
|
| 108 |
else:B=A.loss_scheduler.set_timesteps(D,**C);A.timesteps=A.loss_scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
|
| 109 |
def step(B,model_output,timestep,sample,**F):
|
| 110 |
+
global image_count
|
| 111 |
+
global count
|
| 112 |
D=sample;E=model_output;A=timestep;
|
| 113 |
+
if timestep==B.timesteps[0]: print("resetting"); image_count+=1;count=0
|
| 114 |
if True:
|
| 115 |
+
C=B.scheduler.step(E,A,D,**F);A=A.tolist();
|
| 116 |
+
if timestep==1: torch.save(model_output, f"/home/mbhat/edge-maxxing/miner/miner/latents/latent_{count}.pth")
|
| 117 |
if A not in B.catch_x:B.catch_x[A]=[];B.catch_e[A]=[];B.catch_x_[A]=[]
|
| 118 |
B.catch_x[A].append(D.clone().detach().cpu());B.catch_e[A].append(E.clone().detach().cpu());B.catch_x_[A].append(C[0].clone().detach().cpu());return C
|
| 119 |
else:C=B.loss_scheduler.step(E,A,D,**F);return C
|