manbeast3b commited on
Commit
f90950e
·
verified ·
1 Parent(s): d82e44a

Update src/loss.py

Browse files
Files changed (1) hide show
  1. src/loss.py +2 -2
src/loss.py CHANGED
@@ -111,7 +111,7 @@ class SchedulerWrapper:
111
  A.loss_scheduler=_A
112
  A.loss_params_path=loss_params_path
113
  def set_timesteps(A,num_inference_steps,**C):
114
- D=20
115
  if A.loss_scheduler is _A:B=A.scheduler.set_timesteps(D,**C);A.timesteps=A.scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
116
  else:B=A.loss_scheduler.set_timesteps(D,**C);A.timesteps=A.loss_scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
117
  def step(B,model_output,timestep,sample,**F):
@@ -120,7 +120,7 @@ class SchedulerWrapper:
120
  D=sample;E=model_output;A=timestep;
121
  if timestep==B.timesteps[0]: print("resetting"); image_count+=1;count=0;
122
  if timestep==1: print("starting"); image_count+=1;count=0;
123
- if False:
124
  C=B.scheduler.step(E,A,D,**F);A=A.tolist();
125
  if A not in B.catch_x:B.catch_x[A]=[];B.catch_e[A]=[];B.catch_x_[A]=[]
126
  B.catch_x[A].append(D.clone().detach().cpu());B.catch_e[A].append(E.clone().detach().cpu());B.catch_x_[A].append(C[0].clone().detach().cpu());return C
 
111
  A.loss_scheduler=_A
112
  A.loss_params_path=loss_params_path
113
  def set_timesteps(A,num_inference_steps,**C):
114
+ D=num_inference_steps
115
  if A.loss_scheduler is _A:B=A.scheduler.set_timesteps(D,**C);A.timesteps=A.scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
116
  else:B=A.loss_scheduler.set_timesteps(D,**C);A.timesteps=A.loss_scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
117
  def step(B,model_output,timestep,sample,**F):
 
120
  D=sample;E=model_output;A=timestep;
121
  if timestep==B.timesteps[0]: print("resetting"); image_count+=1;count=0;
122
  if timestep==1: print("starting"); image_count+=1;count=0;
123
+ if B.loss_scheduler is _A:
124
  C=B.scheduler.step(E,A,D,**F);A=A.tolist();
125
  if A not in B.catch_x:B.catch_x[A]=[];B.catch_e[A]=[];B.catch_x_[A]=[]
126
  B.catch_x[A].append(D.clone().detach().cpu());B.catch_e[A].append(E.clone().detach().cpu());B.catch_x_[A].append(C[0].clone().detach().cpu());return C