manbeast3b commited on
Commit
641932b
·
verified ·
1 Parent(s): 07e13fd

Update src/loss.py

Browse files
Files changed (1) hide show
  1. src/loss.py +2 -2
src/loss.py CHANGED
@@ -81,7 +81,7 @@ class LossScheduler:
81
  def save(A,path):B,C,D,E=A.timesteps,A.model.wx,A.model.we,A.model.scale_factor;torch.save((B,C,D,E),path)
82
  def set_timesteps(A,num_inference_steps,device='cuda'):B=device;A.xT=_A;A.e_prev=[];A.t_prev=-1;A.model=A.model.to(B);A.timesteps=A.timesteps.to(B)
83
  def scale_model_input(A,sample,*B,**C):return sample
84
- def step(self,model_output,timestep,sample,*D,**E):
85
  A=self;B=A.timesteps.tolist().index(timestep);assert A.t_prev==-1 or B==A.t_prev+1
86
  if A.t_prev==-1:A.xT=sample
87
  A.e_prev.append(model_output);C=A.model(B,A.xT,A.e_prev)
@@ -119,7 +119,7 @@ class SchedulerWrapper:
119
  global count
120
  D=sample;E=model_output;A=timestep;
121
  if timestep==B.timesteps[0]: print("resetting"); image_count+=1;count=0;
122
- if timestep==1: print("starting"); image_count+=1;count=0;
123
  if False:
124
  C=B.scheduler.step(E,A,D,**F);A=A.tolist();
125
  if A not in B.catch_x:B.catch_x[A]=[];B.catch_e[A]=[];B.catch_x_[A]=[]
 
81
  def save(A,path):B,C,D,E=A.timesteps,A.model.wx,A.model.we,A.model.scale_factor;torch.save((B,C,D,E),path)
82
  def set_timesteps(A,num_inference_steps,device='cuda'):B=device;A.xT=_A;A.e_prev=[];A.t_prev=-1;A.model=A.model.to(B);A.timesteps=A.timesteps.to(B)
83
  def scale_model_input(A,sample,*B,**C):return sample
84
+ def step(self,model_output,timestep,sample,*D,**E):
85
  A=self;B=A.timesteps.tolist().index(timestep);assert A.t_prev==-1 or B==A.t_prev+1
86
  if A.t_prev==-1:A.xT=sample
87
  A.e_prev.append(model_output);C=A.model(B,A.xT,A.e_prev)
 
119
  global count
120
  D=sample;E=model_output;A=timestep;
121
  if timestep==B.timesteps[0]: print("resetting"); image_count+=1;count=0;
122
+ if timestep==1: print("starting"); image_count+=1;count=0;
123
  if False:
124
  C=B.scheduler.step(E,A,D,**F);A=A.tolist();
125
  if A not in B.catch_x:B.catch_x[A]=[];B.catch_e[A]=[];B.catch_x_[A]=[]