manbeast3b committed on
Commit
07e13fd
·
verified ·
1 Parent(s): 2bc936c

Update src/loss.py

Browse files
Files changed (1) hide show
  1. src/loss.py +96 -13
src/loss.py CHANGED
@@ -4,29 +4,112 @@ image_count=0
4
  import torch
5
  from tqdm import tqdm
6
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
class LossSchedulerModel(torch.nn.Module):
    """Linear model predicting a denoised latent from the initial latent and
    the history of per-step model outputs.

    wx: 1-D tensor, one weight per timestep for the initial latent.
    we: 2-D (T, T) tensor weighting each previous model output per timestep.
    """

    def __init__(self, wx, we):
        super(LossSchedulerModel, self).__init__()
        # Shapes must agree: wx is (T,) and we is (T, T).
        assert len(wx.shape) == 1 and len(we.shape) == 2
        steps = wx.shape[0]
        assert steps == we.shape[0] and steps == we.shape[1]
        self.register_parameter('wx', torch.nn.Parameter(wx))
        self.register_parameter('we', torch.nn.Parameter(we))

    def forward(self, t, xT, e_prev):
        # Exactly one collected output per elapsed step up to t.
        assert t - len(e_prev) + 1 == 0
        acc = xT * self.wx[t]
        for out, weight in zip(e_prev, self.we[t]):
            acc += out * weight
        return acc.to(xT.dtype)
 
13
class LossScheduler:
    """Diffusers-style scheduler whose step is a learned linear combination
    of the initial latent and the history of model outputs."""

    def __init__(self, timesteps, model):
        self.timesteps = timesteps
        self.model = model
        self.init_noise_sigma = 1.
        self.order = 1

    @staticmethod
    def load(path):
        # Checkpoint layout: (timesteps, wx, we).
        timesteps, wx, we = torch.load(path, map_location='cpu')
        return LossScheduler(timesteps, LossSchedulerModel(wx, we))

    def save(self, path):
        torch.save((self.timesteps, self.model.wx, self.model.we), path)

    def set_timesteps(self, num_inference_steps, device='cuda'):
        # Reset per-image state and move model/timesteps onto `device`.
        self.xT = _A
        self.e_prev = []
        self.t_prev = -1
        self.model = self.model.to(device)
        self.timesteps = self.timesteps.to(device)

    def scale_model_input(self, sample, *args, **kwargs):
        return sample

    @torch.no_grad()
    def step(self, model_output, timestep, sample, *args, **kwargs):
        idx = self.timesteps.tolist().index(timestep)
        # Steps must arrive strictly in order.
        assert self.t_prev == -1 or idx == self.t_prev + 1
        if self.t_prev == -1:
            self.xT = sample
        self.e_prev.append(model_output)
        result = self.model(idx, self.xT, self.e_prev)
        if idx + 1 == len(self.timesteps):
            # Final step: clear state for the next image.
            self.xT = _A
            self.e_prev = []
            self.t_prev = -1
        else:
            self.t_prev = idx
        return result,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  class SchedulerWrapper:
29
- def __init__(A,scheduler,loss_params_path='loss_params.pth'):A.scheduler=scheduler;A.catch_x,A.catch_e,A.catch_x_={},{},{};A.loss_scheduler=_A;A.loss_params_path=loss_params_path
 
 
 
 
30
  def set_timesteps(A,num_inference_steps,**C):
31
  D=20
32
  if A.loss_scheduler is _A:B=A.scheduler.set_timesteps(D,**C);A.timesteps=A.scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
@@ -34,11 +117,11 @@ class SchedulerWrapper:
34
  def step(B,model_output,timestep,sample,**F):
35
  global image_count
36
  global count
37
- D=sample;E=model_output;A=timestep
38
- if timestep==B.timesteps[0]: print("resetting"); image_count+=1;count=0; torch.save(model_output, f"/home/mbhat/edge-maxxing/miner/miner/latents/latent_orig_{image_count}_0.pth")
39
- if True:
 
40
  C=B.scheduler.step(E,A,D,**F);A=A.tolist();
41
- if timestep==1: torch.save(model_output, f"/home/mbhat/edge-maxxing/miner/miner/latents/latent_orig_{image_count}_20.pth")
42
  if A not in B.catch_x:B.catch_x[A]=[];B.catch_e[A]=[];B.catch_x_[A]=[]
43
  B.catch_x[A].append(D.clone().detach().cpu());B.catch_e[A].append(E.clone().detach().cpu());B.catch_x_[A].append(C[0].clone().detach().cpu());return C
44
  else:C=B.loss_scheduler.step(E,A,D,**F);return C
@@ -48,5 +131,5 @@ class SchedulerWrapper:
48
  A=sorted([A for A in C.catch_x],reverse=True);B,D=[],[]
49
  for E in A:F=torch.cat(C.catch_x[E],dim=0);B.append(F);G=torch.cat(C.catch_e[E],dim=0);D.append(G)
50
  H=A[-1];I=torch.cat(C.catch_x_[H],dim=0);B.append(I);A=torch.tensor(A,dtype=torch.int32);B=torch.stack(B);D=torch.stack(D);return A,B,D
51
- def load_loss_params(A):B,C,D=torch.load(A.loss_params_path,map_location='cpu');A.loss_model=LossSchedulerModel(C,D);A.loss_scheduler=LossScheduler(B,A.loss_model)
52
- def prepare_loss(A,num_accelerate_steps=15):A.load_loss_params()
 
4
  import torch
5
  from tqdm import tqdm
6
  import numpy as np
7
+
8
+
9
class ExponentialScalingMatrix(torch.nn.Module):
    """Produces a (size x size) matrix decaying exponentially with the
    timestep: base * exp(-t / 10)."""

    def __init__(self, size):
        super().__init__()
        # BUG FIX: `torch.nn.Parameter(...).to('cuda')` returns a plain Tensor,
        # which silently drops `base` from module.parameters() and hard-pins
        # the module to CUDA (crashing on CPU-only hosts). Register the
        # Parameter normally and move the module with .to(device) instead.
        self.base = torch.nn.Parameter(torch.ones(size, size))

    def forward(self, timestep):
        # Scalar decay factor, created on the same device as the parameter.
        decay = torch.exp(torch.tensor(-float(timestep) / 10.0,
                                       device=self.base.device))
        return self.base * decay
15
+
16
class RecurrentDynamicScaling(torch.nn.Module):
    """Gates the correction with a sigmoid of a GRU hidden state, then adds
    the gated correction to x."""

    def __init__(self, size):
        super().__init__()
        self.rnn = torch.nn.GRUCell(size, size)
        # Learnable *initial* hidden state. BUG FIX: the original reassigned
        # `self.state` (a registered Parameter) with the GRU's output Tensor,
        # which raises TypeError in nn.Module.__setattr__ on the first call.
        # The running hidden state lives in a plain attribute instead.
        self.state = torch.nn.Parameter(torch.zeros(size))
        self._hidden = None  # current hidden state; None -> start from self.state

    def forward(self, timestep, x, correction):
        hidden = self.state if self._hidden is None else self._hidden
        # Broadcast the (size,) hidden state across a batched correction.
        # Assumes the batch size stays constant between calls — TODO confirm.
        if correction.dim() == 2 and hidden.dim() == 1:
            hidden = hidden.unsqueeze(0).expand(correction.shape[0], -1)
        # BUG FIX: GRUCell's signature is (input, hx); the original passed
        # (hidden, input). Gradients still flow through the GRU.
        hidden = self.rnn(correction, hidden)
        self._hidden = hidden
        scaled_correction = correction * torch.sigmoid(hidden)
        return x + scaled_correction
26
+
27
class AttentionDynamicScaling(torch.nn.Module):
    """Adds an attention-modulated correction onto x."""

    def __init__(self, size):
        super().__init__()
        self.attention = torch.nn.MultiheadAttention(size, num_heads=1)

    def forward(self, timestep, x, correction):
        # MultiheadAttention (batch_first=False) expects (seq, batch, embed);
        # each call is treated as a length-1 sequence. The correction queries
        # and supplies values; x provides the keys.
        q, k, v = (t.unsqueeze(0) for t in (correction, x, correction))
        attended, _ = self.attention(q, k, v)
        return x + attended.squeeze(0)
38
+
39
class LossSchedulerModel(torch.nn.Module):
    """Learned linear scheduler model.

    forward() combines the initial latent (weighted by wx[t]) with the history
    of model outputs (weighted by we[t]); forward_other() additionally perturbs
    the weights with a timestep-decayed jitter matrix from `matrix_generator`.
    """

    def __init__(self, wx, we, scale_factor=None, strategy="exponential_scaling"):
        super(LossSchedulerModel, self).__init__()
        # Shapes must agree: wx is (T,) and we is (T, T).
        assert len(wx.shape) == 1 and len(we.shape) == 2
        steps = wx.shape[0]
        assert steps == we.shape[0] and steps == we.shape[1]
        self.register_parameter('wx', torch.nn.Parameter(wx))
        self.register_parameter('we', torch.nn.Parameter(we))
        size = 13  # NOTE(review): hard-coded — presumably matches T; confirm
        # BUG FIX: the original tested `type(scale_factor) != None`, which is
        # always true (a type object is never None), so the ones() default was
        # unreachable and an empty Parameter(None) was registered instead.
        if scale_factor is not None:
            self.register_parameter('scale_factor',
                                    torch.nn.Parameter(scale_factor))
        else:
            self.scale_factor = torch.nn.Parameter(torch.ones(size))
        self.decay_rate = 0.1
        if strategy == "exponential_scaling":
            self.matrix_generator = ExponentialScalingMatrix(size)
        else:
            raise ValueError("Unknown strategy")

    def forward_other(self, t, xT, e_prev):
        """Variant of forward() that adds decayed jitter to the weights."""
        assert t - len(e_prev) + 1 == 0
        acc = xT * self.wx[t]
        jitter = self.matrix_generator(t)  # (size, size) decayed matrix
        # Loop-invariant scale; grows with t toward tanh saturation.
        dynamic_scale = torch.tanh(self.decay_rate * torch.tensor(t)).to(xT.device)
        for out, weight, j in zip(e_prev, self.we[t], jitter[t]):
            correction = out * (weight + j.to(out.device))
            acc += correction * dynamic_scale
        return acc.to(xT.dtype)

    def forward(self, t, xT, e_prev):
        # Exactly one collected output per elapsed step up to t.
        assert t - len(e_prev) + 1 == 0
        acc = xT * self.wx[t]
        for out, weight in zip(e_prev, self.we[t]):
            acc += out * weight
        return acc.to(xT.dtype)
70
+
71
class LossScheduler:
    """Scheduler wrapping a LossSchedulerModel. `step` runs inference-style;
    `step_other` additionally takes one SGD step on the model per call."""

    def __init__(self, timesteps, model, lr=0.01):
        self.timesteps = timesteps
        self.model = model
        self.init_noise_sigma = 1.
        self.order = 1
        # Online optimizer over the learned weights.
        self.optimizer = torch.optim.SGD(
            [self.model.scale_factor, self.model.wx, self.model.we], lr=lr)

    @staticmethod
    def load(path):
        # Checkpoint layout: (timesteps, wx, we, scale_factor).
        timesteps, wx, we, scale_factor = torch.load(path, map_location='cpu')
        model = LossSchedulerModel(wx, we, scale_factor=scale_factor)
        return LossScheduler(timesteps, model)

    def save(self, path):
        torch.save((self.timesteps, self.model.wx, self.model.we,
                    self.model.scale_factor), path)

    def set_timesteps(self, num_inference_steps, device='cuda'):
        # Reset per-image state and move model/timesteps onto `device`.
        self.xT = _A
        self.e_prev = []
        self.t_prev = -1
        self.model = self.model.to(device)
        self.timesteps = self.timesteps.to(device)

    def scale_model_input(self, sample, *args, **kwargs):
        return sample

    def step(self, model_output, timestep, sample, *args, **kwargs):
        idx = self.timesteps.tolist().index(timestep)
        # Steps must arrive strictly in order.
        assert self.t_prev == -1 or idx == self.t_prev + 1
        if self.t_prev == -1:
            self.xT = sample
        self.e_prev.append(model_output)
        result = self.model(idx, self.xT, self.e_prev)
        if idx + 1 == len(self.timesteps):
            # Final step: clear state for the next image.
            self.xT = _A
            self.e_prev = []
            self.t_prev = -1
        else:
            self.t_prev = idx
        return result,

    def step_other(self, model_output, timestep, sample, *args, **kwargs):
        # BUG FIX: the original did `count += 1` without a global declaration,
        # raising UnboundLocalError on every call. `count` is the module-level
        # step counter (set by SchedulerWrapper.step).
        global count
        idx = self.timesteps.tolist().index(timestep)
        assert self.t_prev == -1 or idx == self.t_prev + 1
        count += 1
        if self.t_prev == -1:
            self.xT = sample
        self.e_prev.append(model_output)
        with torch.enable_grad():
            self.optimizer.zero_grad()
            result = self.model(idx, self.xT, self.e_prev)
            # Train the combined prediction toward the latest raw output.
            loss = torch.nn.functional.mse_loss(result, model_output)
            loss.backward()
            self.optimizer.step()
        if idx + 1 == len(self.timesteps):
            self.xT = _A
            self.e_prev = []
            self.t_prev = -1
            # self.save("/home/mbhat/weights.pth")
        else:
            self.t_prev = idx
        return result,
107
  class SchedulerWrapper:
108
+ def __init__(A,scheduler,loss_params_path='/home/mbhat/weights.pth'):
109
+ A.scheduler=scheduler
110
+ A.catch_x,A.catch_e,A.catch_x_={},{},{}
111
+ A.loss_scheduler=_A
112
+ A.loss_params_path=loss_params_path
113
  def set_timesteps(A,num_inference_steps,**C):
114
  D=20
115
  if A.loss_scheduler is _A:B=A.scheduler.set_timesteps(D,**C);A.timesteps=A.scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
 
117
  def step(B,model_output,timestep,sample,**F):
118
  # Dispatch one denoising step. B is self; A/D/E alias timestep/sample/
  # model_output. With the `if False:` toggle below, the caching branch is
  # disabled and every call is delegated to B.loss_scheduler.step.
  global image_count
119
  global count
120
+ D=sample;E=model_output;A=timestep;
121
  # First timestep of a new image: bump image counter, reset step counter.
+ if timestep==B.timesteps[0]: print("resetting"); image_count+=1;count=0;
122
  # NOTE(review): also bumps image_count at timestep==1 — presumably a debug
  # marker near the end of denoising; this double-counts images. Confirm intent.
+ if timestep==1: print("starting"); image_count+=1;count=0;
123
  # Toggle: True would cache (x, eps, x_next) via the base scheduler; False
  # (current) delegates to the learned loss scheduler.
+ if False:
124
  C=B.scheduler.step(E,A,D,**F);A=A.tolist();

125
  if A not in B.catch_x:B.catch_x[A]=[];B.catch_e[A]=[];B.catch_x_[A]=[]
126
  B.catch_x[A].append(D.clone().detach().cpu());B.catch_e[A].append(E.clone().detach().cpu());B.catch_x_[A].append(C[0].clone().detach().cpu());return C
127
  else:C=B.loss_scheduler.step(E,A,D,**F);return C
 
131
  A=sorted([A for A in C.catch_x],reverse=True);B,D=[],[]
132
  for E in A:F=torch.cat(C.catch_x[E],dim=0);B.append(F);G=torch.cat(C.catch_e[E],dim=0);D.append(G)
133
  H=A[-1];I=torch.cat(C.catch_x_[H],dim=0);B.append(I);A=torch.tensor(A,dtype=torch.int32);B=torch.stack(B);D=torch.stack(D);return A,B,D
134
+ def load_loss_params(A):B,C,D,E=torch.load(A.loss_params_path,map_location='cpu');A.loss_model=LossSchedulerModel(C,D,scale_factor=E);A.loss_scheduler=LossScheduler(B,A.loss_model)
135
+ def prepare_loss(A,num_accelerate_steps=15):A.load_loss_params()