Update src/loss.py
Browse files- src/loss.py +3 -3
src/loss.py
CHANGED
|
@@ -28,17 +28,17 @@ class LossScheduler:
|
|
| 28 |
class SchedulerWrapper:
|
| 29 |
def __init__(A,scheduler,loss_params_path='loss_params.pth'):A.scheduler=scheduler;A.catch_x,A.catch_e,A.catch_x_={},{},{};A.loss_scheduler=_A;A.loss_params_path=loss_params_path
|
| 30 |
def set_timesteps(A,num_inference_steps,**C):
|
| 31 |
-
D=
|
| 32 |
if A.loss_scheduler is _A:B=A.scheduler.set_timesteps(D,**C);A.timesteps=A.scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
|
| 33 |
else:B=A.loss_scheduler.set_timesteps(D,**C);A.timesteps=A.loss_scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
|
| 34 |
def step(B,model_output,timestep,sample,**F):
|
| 35 |
global image_count
|
| 36 |
global count
|
| 37 |
D=sample;E=model_output;A=timestep
|
| 38 |
-
if timestep==B.timesteps[0]: print("resetting"); image_count+=1;count=0; torch.save(model_output, f"/home/mbhat/edge-maxxing/miner/miner/latents/
|
| 39 |
if True:
|
| 40 |
C=B.scheduler.step(E,A,D,**F);A=A.tolist();
|
| 41 |
-
if timestep==1: torch.save(model_output, f"/home/mbhat/edge-maxxing/miner/miner/latents/
|
| 42 |
if A not in B.catch_x:B.catch_x[A]=[];B.catch_e[A]=[];B.catch_x_[A]=[]
|
| 43 |
B.catch_x[A].append(D.clone().detach().cpu());B.catch_e[A].append(E.clone().detach().cpu());B.catch_x_[A].append(C[0].clone().detach().cpu());return C
|
| 44 |
else:C=B.loss_scheduler.step(E,A,D,**F);return C
|
|
|
|
| 28 |
class SchedulerWrapper:
|
| 29 |
def __init__(A,scheduler,loss_params_path='loss_params.pth'):A.scheduler=scheduler;A.catch_x,A.catch_e,A.catch_x_={},{},{};A.loss_scheduler=_A;A.loss_params_path=loss_params_path
|
| 30 |
def set_timesteps(A,num_inference_steps,**C):
|
| 31 |
+
D=20
|
| 32 |
if A.loss_scheduler is _A:B=A.scheduler.set_timesteps(D,**C);A.timesteps=A.scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
|
| 33 |
else:B=A.loss_scheduler.set_timesteps(D,**C);A.timesteps=A.loss_scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
|
| 34 |
def step(B,model_output,timestep,sample,**F):
|
| 35 |
global image_count
|
| 36 |
global count
|
| 37 |
D=sample;E=model_output;A=timestep
|
| 38 |
+
if timestep==B.timesteps[0]: print("resetting"); image_count+=1;count=0; torch.save(model_output, f"/home/mbhat/edge-maxxing/miner/miner/latents/latent_orig_{image_count}_0.pth")
|
| 39 |
if True:
|
| 40 |
C=B.scheduler.step(E,A,D,**F);A=A.tolist();
|
| 41 |
+
if timestep==1: torch.save(model_output, f"/home/mbhat/edge-maxxing/miner/miner/latents/latent_orig_{image_count}_20.pth")
|
| 42 |
if A not in B.catch_x:B.catch_x[A]=[];B.catch_e[A]=[];B.catch_x_[A]=[]
|
| 43 |
B.catch_x[A].append(D.clone().detach().cpu());B.catch_e[A].append(E.clone().detach().cpu());B.catch_x_[A].append(C[0].clone().detach().cpu());return C
|
| 44 |
else:C=B.loss_scheduler.step(E,A,D,**F);return C
|