Update src/loss.py
Browse files- src/loss.py +33 -13
src/loss.py
CHANGED
|
@@ -74,30 +74,53 @@ class LossScheduler:
|
|
| 74 |
A.model=model
|
| 75 |
A.init_noise_sigma=1.
|
| 76 |
A.order=1
|
| 77 |
-
A.optimizer = torch.optim.SGD([A.model.scale_factor, A.model.wx, A.model.we], lr=lr)
|
| 78 |
-
|
|
|
|
| 79 |
@staticmethod
|
| 80 |
def load(path):A,B,C,E=torch.load(path,map_location='cpu');D=LossSchedulerModel(B,C,scale_factor=E);return LossScheduler(A,D)
|
| 81 |
def save(A,path):B,C,D,E=A.timesteps,A.model.wx,A.model.we,A.model.scale_factor;torch.save((B,C,D,E),path)
|
| 82 |
def set_timesteps(A,num_inference_steps,device='cuda'):B=device;A.xT=_A;A.e_prev=[];A.t_prev=-1;A.model=A.model.to(B);A.timesteps=A.timesteps.to(B)
|
| 83 |
def scale_model_input(A,sample,*B,**C):return sample
|
| 84 |
-
def
|
| 85 |
A=self;B=A.timesteps.tolist().index(timestep);assert A.t_prev==-1 or B==A.t_prev+1
|
| 86 |
if A.t_prev==-1:A.xT=sample
|
| 87 |
A.e_prev.append(model_output);C=A.model(B,A.xT,A.e_prev)
|
| 88 |
if B+1==len(A.timesteps):A.xT=_A;A.e_prev=[];A.t_prev=-1
|
| 89 |
else:A.t_prev=B
|
| 90 |
return C,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def step_other(self, model_output, timestep, sample, *args, **kwargs):
    """Accumulate per-trajectory state without producing a prediction.

    Records the initial latent (`xT`) and the per-step model outputs
    (`e_prev`), then clears the state after the final timestep. Returns
    None; callers use it only for its side effects.

    Relies on a module-level counter `count` (step index within the
    current trajectory).
    """
    # BUG FIX: the original did `count += 1` without a `global` statement,
    # which made `count` a local and raised UnboundLocalError on every call.
    global count
    step_idx = self.timesteps.tolist().index(timestep)
    # Steps must arrive in order; -1 marks a fresh trajectory.
    assert self.t_prev == -1 or step_idx == self.t_prev + 1
    count += 1
    if self.t_prev == -1:
        self.xT = sample
    self.e_prev.append(model_output)
    if step_idx + 1 == len(self.timesteps):
        # Trajectory finished: reset state for the next image.
        self.xT = _A
        self.e_prev = []
        self.t_prev = -1
    else:
        # BUG FIX: the original never advanced t_prev, so the assert above
        # was vacuous and xT was overwritten with the current sample on
        # every step; track the index as the other step methods do.
        self.t_prev = step_idx
    return None
|
|
@@ -115,11 +138,8 @@ class SchedulerWrapper:
|
|
| 115 |
if A.loss_scheduler is _A:B=A.scheduler.set_timesteps(D,**C);A.timesteps=A.scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
|
| 116 |
else:B=A.loss_scheduler.set_timesteps(D,**C);A.timesteps=A.loss_scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
|
| 117 |
def step(B,model_output,timestep,sample,**F):
|
| 118 |
-
global image_count
|
| 119 |
-
global count
|
| 120 |
D=sample;E=model_output;A=timestep;
|
| 121 |
-
|
| 122 |
-
if timestep: print(timestep);
|
| 123 |
if B.loss_scheduler is _A:
|
| 124 |
C=B.scheduler.step(E,A,D,**F);A=A.tolist();
|
| 125 |
if A not in B.catch_x:B.catch_x[A]=[];B.catch_e[A]=[];B.catch_x_[A]=[]
|
|
|
|
| 74 |
A.model=model
|
| 75 |
A.init_noise_sigma=1.
|
| 76 |
A.order=1
|
| 77 |
+
# A.optimizer = torch.optim.SGD([A.model.scale_factor, A.model.wx, A.model.we], lr=lr)
|
| 78 |
+
A.optimizer = torch.optim.SGD([A.model.wx, A.model.we], lr=lr)
|
| 79 |
+
A.model.train()
|
@staticmethod
def load(path):
    """Rebuild a LossScheduler from a checkpoint written by `save`.

    The checkpoint is a 4-tuple: (timesteps, wx, we, scale_factor).
    """
    timesteps, wx, we, scale_factor = torch.load(path, map_location='cpu')
    model = LossSchedulerModel(wx, we, scale_factor=scale_factor)
    return LossScheduler(timesteps, model)
|
def save(self, path):
    """Write (timesteps, wx, we, scale_factor) to `path` via torch.save.

    The layout matches what `load` expects to unpack.
    """
    checkpoint = (
        self.timesteps,
        self.model.wx,
        self.model.we,
        self.model.scale_factor,
    )
    torch.save(checkpoint, path)
|
def set_timesteps(self, num_inference_steps, device='cuda'):
    """Reset per-trajectory state and move model/timesteps to `device`.

    `num_inference_steps` is accepted for scheduler-API compatibility but
    the preset `self.timesteps` are used as-is.
    """
    self.xT = _A
    self.e_prev = []
    self.t_prev = -1
    self.model = self.model.to(device)
    self.timesteps = self.timesteps.to(device)
|
def scale_model_input(self, sample, *args, **kwargs):
    """No-op scaling hook: this scheduler consumes samples unscaled."""
    return sample
|
def step_orig(self, model_output, timestep, sample, *args, **kwargs):
    """Run one denoising step through the learned combiner model.

    Tracks trajectory state — the initial latent `xT` and the per-step
    model outputs `e_prev` — and clears it after the final timestep.
    Returns a 1-tuple containing the model's prediction, mimicking the
    diffusers scheduler `step` return convention.
    """
    step_idx = self.timesteps.tolist().index(timestep)
    # Steps must arrive in order; -1 marks a fresh trajectory.
    assert self.t_prev == -1 or step_idx == self.t_prev + 1
    if self.t_prev == -1:
        self.xT = sample
    self.e_prev.append(model_output)
    prediction = self.model(step_idx, self.xT, self.e_prev)
    if step_idx + 1 == len(self.timesteps):
        # Trajectory finished: reset state for the next image.
        self.xT = _A
        self.e_prev = []
        self.t_prev = -1
    else:
        self.t_prev = step_idx
    return (prediction,)
|
def step(self, model_output, timestep, sample, *args, **kwargs):
    """Denoising step that also trains the combiner against a saved target.

    Behaves like `step_orig` (tracks `xT`/`e_prev`, returns the model's
    prediction as a 1-tuple), but on the final timestep of each trajectory
    it additionally runs one SGD update: the prediction is regressed (MSE)
    onto a reference latent loaded from disk.

    Relies on module-level counters `image_count` (which trajectory this
    is) and `count` (step index within the trajectory).
    """
    global image_count
    global count
    step_idx = self.timesteps.tolist().index(timestep)
    # BUG FIX: the original tested `timestep == B.timesteps[0]` before `B`
    # was assigned — and `B` is the int step index, not the scheduler.
    # The first-timestep reset must read `self.timesteps`.
    if timestep == self.timesteps[0]:
        print("resetting")
        image_count += 1
        count = 0
    # Steps must arrive in order; -1 marks a fresh trajectory.
    assert self.t_prev == -1 or step_idx == self.t_prev + 1
    count += 1
    if self.t_prev == -1:
        self.xT = sample
    self.e_prev.append(model_output)
    is_last = step_idx + 1 == len(self.timesteps)
    if is_last:
        with torch.enable_grad():
            # NOTE(review): hardcoded user-specific path — should be made
            # configurable before this leaves experimentation.
            orig_output = torch.load(
                f"/home/mbhat/edge-maxxing/miner/miner/latents/latent_oig_{image_count}_20.pth"
            )
            self.optimizer.zero_grad()
            prediction = self.model(step_idx, self.xT, self.e_prev)
            loss = torch.nn.functional.mse_loss(prediction, orig_output)
            loss.backward()
            self.optimizer.step()
    else:
        # BUG FIX: the original computed `C` only on the final timestep yet
        # executed `return C,` on every step, raising UnboundLocalError on
        # all intermediate steps. Compute the prediction every step, as
        # `step_orig` does.
        prediction = self.model(step_idx, self.xT, self.e_prev)
    if is_last:
        # Trajectory finished: reset state for the next image.
        self.xT = _A
        self.e_prev = []
        self.t_prev = -1
    else:
        self.t_prev = step_idx
    return (prediction,)
|
def step_other(self, model_output, timestep, sample, *args, **kwargs):
    """Variant of `step` that trains the combiner online.

    On the final timestep the model's prediction is regressed (MSE) onto
    the current `model_output` itself, then state is reset. Returns None;
    callers use it purely for its training side effect.

    Relies on a module-level counter `count` (step index within the
    current trajectory).
    """
    # BUG FIX: the original did `count += 1` without a `global` statement,
    # which made `count` a local and raised UnboundLocalError on every call.
    global count
    step_idx = self.timesteps.tolist().index(timestep)
    # Steps must arrive in order; -1 marks a fresh trajectory.
    assert self.t_prev == -1 or step_idx == self.t_prev + 1
    count += 1
    if self.t_prev == -1:
        self.xT = sample
    self.e_prev.append(model_output)
    if timestep == self.timesteps.tolist()[-1]:
        with torch.enable_grad():
            self.optimizer.zero_grad()
            prediction = self.model(step_idx, self.xT, self.e_prev)
            loss = torch.nn.functional.mse_loss(prediction, model_output)
            loss.backward()
            self.optimizer.step()
    if step_idx + 1 == len(self.timesteps):
        # Trajectory finished: reset state for the next image.
        self.xT = _A
        self.e_prev = []
        self.t_prev = -1
    else:
        # BUG FIX: the original never advanced t_prev, so the assert above
        # was vacuous and xT was overwritten with the current sample on
        # every step; track the index as the other step methods do.
        self.t_prev = step_idx
    return None
|
|
|
|
| 138 |
if A.loss_scheduler is _A:B=A.scheduler.set_timesteps(D,**C);A.timesteps=A.scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
|
| 139 |
else:B=A.loss_scheduler.set_timesteps(D,**C);A.timesteps=A.loss_scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
|
| 140 |
def step(B,model_output,timestep,sample,**F):
|
|
|
|
|
|
|
| 141 |
D=sample;E=model_output;A=timestep;
|
| 142 |
+
|
|
|
|
| 143 |
if B.loss_scheduler is _A:
|
| 144 |
C=B.scheduler.step(E,A,D,**F);A=A.tolist();
|
| 145 |
if A not in B.catch_x:B.catch_x[A]=[];B.catch_e[A]=[];B.catch_x_[A]=[]
|