Update src/loss.py
Browse files- src/loss.py +2 -2
src/loss.py
CHANGED
|
@@ -81,7 +81,7 @@ class LossScheduler:
|
|
| 81 |
def save(self, path):
    """Serialize scheduler state to *path* via ``torch.save``.

    The saved object is a plain tuple ``(timesteps, wx, we, scale_factor)``
    so it can be reloaded with a single ``torch.load`` call.
    """
    state = (
        self.timesteps,
        self.model.wx,
        self.model.we,
        self.model.scale_factor,
    )
    torch.save(state, path)
|
| 82 |
def set_timesteps(self, num_inference_steps, device='cuda'):
    """Reset per-run state and move the model and timestep grid to *device*.

    NOTE(review): ``num_inference_steps`` is accepted for scheduler-API
    compatibility but is not used — the timestep grid appears fixed; confirm.
    """
    # _A is a module-level sentinel defined elsewhere in this file
    # (presumably None — TODO confirm).
    self.xT = _A
    self.e_prev = []
    self.t_prev = -1
    self.model = self.model.to(device)
    self.timesteps = self.timesteps.to(device)
|
| 83 |
def scale_model_input(self, sample, *args, **kwargs):
    """Return *sample* unchanged — this scheduler applies no input scaling.

    Present only to satisfy the diffusers scheduler interface; extra
    positional/keyword arguments are accepted and ignored.
    """
    return sample
|
| 84 |
-
|
| 85 |
A=self;B=A.timesteps.tolist().index(timestep);assert A.t_prev==-1 or B==A.t_prev+1
|
| 86 |
if A.t_prev==-1:A.xT=sample
|
| 87 |
A.e_prev.append(model_output);C=A.model(B,A.xT,A.e_prev)
|
|
@@ -119,7 +119,7 @@ class SchedulerWrapper:
|
|
| 119 |
global count
|
| 120 |
D=sample;E=model_output;A=timestep;
|
| 121 |
if timestep==B.timesteps[0]: print("resetting"); image_count+=1;count=0;
|
| 122 |
-
|
| 123 |
if False:
|
| 124 |
C=B.scheduler.step(E,A,D,**F);A=A.tolist();
|
| 125 |
if A not in B.catch_x:B.catch_x[A]=[];B.catch_e[A]=[];B.catch_x_[A]=[]
|
|
|
|
| 81 |
def save(self, path):
    """Persist scheduler state (timesteps plus model weights) to *path*.

    Writes the tuple ``(timesteps, wx, we, scale_factor)`` with
    ``torch.save`` so the state round-trips through ``torch.load``.
    """
    payload = (
        self.timesteps,
        self.model.wx,
        self.model.we,
        self.model.scale_factor,
    )
    torch.save(payload, path)
|
| 82 |
def set_timesteps(self, num_inference_steps, device='cuda'):
    """Clear per-run buffers and move model/timesteps onto *device*.

    NOTE(review): ``num_inference_steps`` is unused here — kept only for
    interface compatibility with other schedulers; verify against callers.
    """
    # _A is a module-level name defined elsewhere in this file
    # (presumably a None sentinel — TODO confirm).
    self.xT = _A
    self.e_prev = []
    self.t_prev = -1
    self.model = self.model.to(device)
    self.timesteps = self.timesteps.to(device)
|
| 83 |
def scale_model_input(self, sample, *args, **kwargs):
    """No-op input scaling: hand *sample* back untouched.

    Exists to match the scheduler interface; any additional arguments
    are silently ignored.
    """
    return sample
|
| 84 |
+
def step(self,model_output,timestep,sample,*D,**E):
|
| 85 |
A=self;B=A.timesteps.tolist().index(timestep);assert A.t_prev==-1 or B==A.t_prev+1
|
| 86 |
if A.t_prev==-1:A.xT=sample
|
| 87 |
A.e_prev.append(model_output);C=A.model(B,A.xT,A.e_prev)
|
|
|
|
| 119 |
global count
|
| 120 |
D=sample;E=model_output;A=timestep;
|
| 121 |
if timestep==B.timesteps[0]: print("resetting"); image_count+=1;count=0;
|
| 122 |
+
if timestep==1: print("starting"); image_count+=1;count=0;
|
| 123 |
if False:
|
| 124 |
C=B.scheduler.step(E,A,D,**F);A=A.tolist();
|
| 125 |
if A not in B.catch_x:B.catch_x[A]=[];B.catch_e[A]=[];B.catch_x_[A]=[]
|