Update src/loss.py
Browse files- src/loss.py +2 -2
src/loss.py
CHANGED
|
@@ -111,7 +111,7 @@ class SchedulerWrapper:
|
|
| 111 |
A.loss_scheduler=_A
|
| 112 |
A.loss_params_path=loss_params_path
|
| 113 |
def set_timesteps(A,num_inference_steps,**C):
|
| 114 |
-
D=
|
| 115 |
if A.loss_scheduler is _A:B=A.scheduler.set_timesteps(D,**C);A.timesteps=A.scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
|
| 116 |
else:B=A.loss_scheduler.set_timesteps(D,**C);A.timesteps=A.loss_scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
|
| 117 |
def step(B,model_output,timestep,sample,**F):
|
|
@@ -120,7 +120,7 @@ class SchedulerWrapper:
|
|
| 120 |
D=sample;E=model_output;A=timestep;
|
| 121 |
if timestep==B.timesteps[0]: print("resetting"); image_count+=1;count=0;
|
| 122 |
if timestep==1: print("starting"); image_count+=1;count=0;
|
| 123 |
-
if
|
| 124 |
C=B.scheduler.step(E,A,D,**F);A=A.tolist();
|
| 125 |
if A not in B.catch_x:B.catch_x[A]=[];B.catch_e[A]=[];B.catch_x_[A]=[]
|
| 126 |
B.catch_x[A].append(D.clone().detach().cpu());B.catch_e[A].append(E.clone().detach().cpu());B.catch_x_[A].append(C[0].clone().detach().cpu());return C
|
|
|
|
| 111 |
A.loss_scheduler=_A
|
| 112 |
A.loss_params_path=loss_params_path
|
| 113 |
def set_timesteps(A,num_inference_steps,**C):
|
| 114 |
+
D=num_inference_steps
|
| 115 |
if A.loss_scheduler is _A:B=A.scheduler.set_timesteps(D,**C);A.timesteps=A.scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
|
| 116 |
else:B=A.loss_scheduler.set_timesteps(D,**C);A.timesteps=A.loss_scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
|
| 117 |
def step(B,model_output,timestep,sample,**F):
|
|
|
|
| 120 |
D=sample;E=model_output;A=timestep;
|
| 121 |
if timestep==B.timesteps[0]: print("resetting"); image_count+=1;count=0;
|
| 122 |
if timestep==1: print("starting"); image_count+=1;count=0;
|
| 123 |
+
if B.loss_scheduler is _A:
|
| 124 |
C=B.scheduler.step(E,A,D,**F);A=A.tolist();
|
| 125 |
if A not in B.catch_x:B.catch_x[A]=[];B.catch_e[A]=[];B.catch_x_[A]=[]
|
| 126 |
B.catch_x[A].append(D.clone().detach().cpu());B.catch_e[A].append(E.clone().detach().cpu());B.catch_x_[A].append(C[0].clone().detach().cpu());return C
|