manbeast3b committed on
Commit
086e3dd
·
verified ·
1 Parent(s): 6ed84fc

Update src/loss.py

Browse files
Files changed (1) hide show
  1. src/loss.py +33 -13
src/loss.py CHANGED
@@ -74,30 +74,53 @@ class LossScheduler:
74
  A.model=model
75
  A.init_noise_sigma=1.
76
  A.order=1
77
- A.optimizer = torch.optim.SGD([A.model.scale_factor, A.model.wx, A.model.we], lr=lr)
78
-
 
79
  @staticmethod
80
  def load(path):A,B,C,E=torch.load(path,map_location='cpu');D=LossSchedulerModel(B,C,scale_factor=E);return LossScheduler(A,D)
81
  def save(A,path):B,C,D,E=A.timesteps,A.model.wx,A.model.we,A.model.scale_factor;torch.save((B,C,D,E),path)
82
  def set_timesteps(A,num_inference_steps,device='cuda'):B=device;A.xT=_A;A.e_prev=[];A.t_prev=-1;A.model=A.model.to(B);A.timesteps=A.timesteps.to(B)
83
  def scale_model_input(A,sample,*B,**C):return sample
84
- def step(self,model_output,timestep,sample,*D,**E):
85
  A=self;B=A.timesteps.tolist().index(timestep);assert A.t_prev==-1 or B==A.t_prev+1
86
  if A.t_prev==-1:A.xT=sample
87
  A.e_prev.append(model_output);C=A.model(B,A.xT,A.e_prev)
88
  if B+1==len(A.timesteps):A.xT=_A;A.e_prev=[];A.t_prev=-1
89
  else:A.t_prev=B
90
  return C,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  def step_other(self,model_output,timestep,sample,*D,**E):
92
  A=self;B=A.timesteps.tolist().index(timestep);assert A.t_prev==-1 or B==A.t_prev+1; count+=1
93
  if A.t_prev==-1:A.xT=sample
94
  A.e_prev.append(model_output)
95
- with torch.enable_grad():
96
- A.optimizer.zero_grad()
97
- C=A.model(B,A.xT,A.e_prev)
98
- L=torch.nn.functional.mse_loss(C, model_output)
99
- L.backward()
100
- A.optimizer.step()
 
101
  if B+1==len(A.timesteps):
102
  A.xT=_A;A.e_prev=[];
103
  A.t_prev=-1
@@ -115,11 +138,8 @@ class SchedulerWrapper:
115
  if A.loss_scheduler is _A:B=A.scheduler.set_timesteps(D,**C);A.timesteps=A.scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
116
  else:B=A.loss_scheduler.set_timesteps(D,**C);A.timesteps=A.loss_scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
117
  def step(B,model_output,timestep,sample,**F):
118
- global image_count
119
- global count
120
  D=sample;E=model_output;A=timestep;
121
- if timestep==B.timesteps[0]: print("resetting");
122
- if timestep: print(timestep);
123
  if B.loss_scheduler is _A:
124
  C=B.scheduler.step(E,A,D,**F);A=A.tolist();
125
  if A not in B.catch_x:B.catch_x[A]=[];B.catch_e[A]=[];B.catch_x_[A]=[]
 
74
  A.model=model
75
  A.init_noise_sigma=1.
76
  A.order=1
77
+ # A.optimizer = torch.optim.SGD([A.model.scale_factor, A.model.wx, A.model.we], lr=lr)
78
+ A.optimizer = torch.optim.SGD([A.model.wx, A.model.we], lr=lr)
79
+ A.model.train()
80
  @staticmethod
81
  def load(path):A,B,C,E=torch.load(path,map_location='cpu');D=LossSchedulerModel(B,C,scale_factor=E);return LossScheduler(A,D)
82
  def save(A,path):B,C,D,E=A.timesteps,A.model.wx,A.model.we,A.model.scale_factor;torch.save((B,C,D,E),path)
83
  def set_timesteps(A,num_inference_steps,device='cuda'):B=device;A.xT=_A;A.e_prev=[];A.t_prev=-1;A.model=A.model.to(B);A.timesteps=A.timesteps.to(B)
84
  def scale_model_input(A,sample,*B,**C):return sample
85
+ def step_orig(self,model_output,timestep,sample,*D,**E):
86
  A=self;B=A.timesteps.tolist().index(timestep);assert A.t_prev==-1 or B==A.t_prev+1
87
  if A.t_prev==-1:A.xT=sample
88
  A.e_prev.append(model_output);C=A.model(B,A.xT,A.e_prev)
89
  if B+1==len(A.timesteps):A.xT=_A;A.e_prev=[];A.t_prev=-1
90
  else:A.t_prev=B
91
  return C,
92
+ def step(self,model_output,timestep,sample,*D,**E):
93
+ global image_count
94
+ global count
95
+ if timestep==B.timesteps[0]: print("resetting"); image_count+=1;count=0;
96
+ A=self;B=A.timesteps.tolist().index(timestep);assert A.t_prev==-1 or B==A.t_prev+1; count+=1
97
+ if A.t_prev==-1:A.xT=sample
98
+ A.e_prev.append(model_output)
99
+ if timestep==B.timesteps[-1]:
100
+ with torch.enable_grad():
101
+ orig_output = torch.load(f"/home/mbhat/edge-maxxing/miner/miner/latents/latent_oig_{image_count}_20.pth")
102
+ A.optimizer.zero_grad()
103
+ C=A.model(B,A.xT,A.e_prev)
104
+ L=torch.nn.functional.mse_loss(C, orig_output)
105
+ L.backward()
106
+ A.optimizer.step()
107
+ if B+1==len(A.timesteps):
108
+ A.xT=_A;A.e_prev=[];
109
+ A.t_prev=-1
110
+ # A.save("/home/mbhat/weights.pth")
111
+ else:A.t_prev=B
112
+ return C,
113
  def step_other(self,model_output,timestep,sample,*D,**E):
114
  A=self;B=A.timesteps.tolist().index(timestep);assert A.t_prev==-1 or B==A.t_prev+1; count+=1
115
  if A.t_prev==-1:A.xT=sample
116
  A.e_prev.append(model_output)
117
+ if timestep==A.timesteps.tolist()[-1]:
118
+ with torch.enable_grad():
119
+ A.optimizer.zero_grad()
120
+ C=A.model(B,A.xT,A.e_prev)
121
+ L=torch.nn.functional.mse_loss(C, model_output)
122
+ L.backward()
123
+ A.optimizer.step()
124
  if B+1==len(A.timesteps):
125
  A.xT=_A;A.e_prev=[];
126
  A.t_prev=-1
 
138
  if A.loss_scheduler is _A:B=A.scheduler.set_timesteps(D,**C);A.timesteps=A.scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
139
  else:B=A.loss_scheduler.set_timesteps(D,**C);A.timesteps=A.loss_scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
140
  def step(B,model_output,timestep,sample,**F):
 
 
141
  D=sample;E=model_output;A=timestep;
142
+
 
143
  if B.loss_scheduler is _A:
144
  C=B.scheduler.step(E,A,D,**F);A=A.tolist();
145
  if A not in B.catch_x:B.catch_x[A]=[];B.catch_e[A]=[];B.catch_x_[A]=[]