File size: 6,778 Bytes
24e51a4 2500952 e73d351 2500952 07e13fd e73d351 07e13fd ece2e8f 1edb460 07e13fd e73d351 07e13fd 086e3dd e73d351 07e13fd e73d351 086e3dd 1edb460 e73d351 1edb460 e73d351 086e3dd f34974b cfca40b 086e3dd bb312c6 086e3dd e98e299 086e3dd 9dddae7 ea3e51b 086e3dd cd9e403 086e3dd 07e13fd 086e3dd 07e13fd e73d351 07e13fd e73d351 0e4c4ae e73d351 07e13fd f90950e ce845a0 e73d351 07e13fd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 | _A=None
# Number of LossScheduler.step calls since the first timestep was last seen;
# LossScheduler.step increments it on every call and zeroes it on that reset.
count = 0
# Incremented by LossScheduler.step each time the first timestep is seen again;
# used there to build the reference-latent and weight-checkpoint file names.
image_count = 0
import torch
from tqdm import tqdm
import numpy as np
class ExponentialScalingMatrix(torch.nn.Module):
    """Learnable square matrix attenuated exponentially with the timestep.

    ``forward(timestep)`` returns ``base * exp(-timestep / time_constant)``
    where ``base`` is a ``size x size`` parameter initialised to ones.

    Args:
        size: Number of rows and columns of the matrix.
        time_constant: Divisor applied to the timestep inside the exponential.
            Defaults to 10.0.
    """

    def __init__(self, size, time_constant=10.0):
        super().__init__()
        # Assign the Parameter itself. Calling ``.to('cuda')`` on it would
        # store the result of ``.to()`` (a plain tensor) instead, so ``base``
        # would not show up in ``parameters()`` and would not follow
        # ``module.to(...)``; it would also crash on hosts without CUDA.
        self.base = torch.nn.Parameter(torch.ones(size, size))
        self.time_constant = time_constant

    def forward(self, timestep):
        """Return the scaled matrix for ``timestep`` (a number or a tensor)."""
        # Build the scalar on the parameter's own device/dtype so the module
        # works wherever it has been moved.
        step = torch.as_tensor(
            timestep, dtype=self.base.dtype, device=self.base.device
        )
        return self.base * torch.exp(-step / self.time_constant)
class RecurrentDynamicScaling(torch.nn.Module):
    """Scale a correction term with a gate produced by a GRU cell.

    ``state`` is a learnable initial hidden state. Every ``forward`` call
    feeds ``correction`` to the GRU cell as its input, updates the running
    hidden state, and returns ``x + correction * sigmoid(hidden)``.

    Args:
        size: Feature size of ``correction`` and of the hidden state.
    """

    def __init__(self, size):
        super().__init__()
        self.rnn = torch.nn.GRUCell(size, size)
        self.state = torch.nn.Parameter(torch.zeros(size))  # Learnable initial state
        # Running hidden state. It is kept apart from ``state`` because
        # ``nn.Module`` refuses to rebind a registered Parameter name to a
        # plain tensor (assigning the GRU output to ``self.state`` raises
        # ``TypeError``).
        self._hidden = None

    def reset_state(self):
        """Forget the running hidden state; the next call starts from ``state``."""
        self._hidden = None

    def forward(self, timestep, x, correction):
        """Return ``x`` plus the gated ``correction``.

        Args:
            timestep: Accepted for interface parity; not used here.
            x: Tensor the scaled correction is added to.
            correction: GRU input, shaped ``(size,)`` or ``(batch, size)``.

        Note:
            The running hidden state is stored with its autograd graph
            attached. Call ``reset_state()`` between independent sequences.
        """
        hidden = self.state if self._hidden is None else self._hidden
        if hidden.shape != correction.shape:
            # Broadcast the per-feature initial state over the batch.
            hidden = hidden.expand_as(correction).contiguous()
        # GRUCell's signature is (input, hx): the correction is the input and
        # the recurrent state is the hidden vector.
        hidden = self.rnn(correction, hidden)
        self._hidden = hidden
        scaled_correction = correction * torch.sigmoid(hidden)  # Scale dynamically
        return x + scaled_correction
class AttentionDynamicScaling(torch.nn.Module):
    """Add an attention-transformed correction to ``x``.

    A single-head ``MultiheadAttention`` layer is queried with the
    correction, keyed by ``x`` and reads its values from the correction; the
    result is added back onto ``x``.

    Args:
        size: Embedding dimension handed to the attention layer.
    """

    def __init__(self, size):
        super().__init__()
        self.attention = torch.nn.MultiheadAttention(size, num_heads=1)

    def forward(self, timestep, x, correction):
        """Return ``x`` plus the attention output.

        ``timestep`` is accepted but not used. Both tensors get a new leading
        axis before the attention call and the output drops it again.
        NOTE(review): presumably ``x`` and ``correction`` are
        ``(batch, size)`` - confirm against the caller.
        """
        lifted_query = correction.unsqueeze(0)
        lifted_key = x.unsqueeze(0)
        lifted_value = correction.unsqueeze(0)
        attended = self.attention(lifted_query, lifted_key, lifted_value)[0]
        return x + attended.squeeze(0)
class LossSchedulerModel(torch.nn.Module):
    """Learned linear combiner of the start sample and past model outputs.

    For step index ``t`` the prediction is
    ``xT * wx[t] + sum_i e_prev[i] * we[t][i]``.

    Args:
        wx: 1-D tensor of length ``n`` with the per-step weight on ``xT``.
        we: ``n x n`` tensor; row ``t`` weights the model outputs seen so far.
        scale_factor: Optional tensor stored as the ``scale_factor``
            parameter. When ``None`` it defaults to ``ones(n)``. Neither
            ``forward`` nor ``forward_other`` reads it.
        strategy: Only ``"exponential_scaling"`` is supported.

    Raises:
        AssertionError: If ``wx`` is not 1-D, ``we`` is not 2-D, or their
            sizes disagree.
        ValueError: If ``strategy`` is not recognised.
    """

    def __init__(self, wx, we, scale_factor=None, strategy="exponential_scaling"):
        super().__init__()
        assert len(wx.shape) == 1 and len(we.shape) == 2
        num_steps = wx.shape[0]
        assert num_steps == we.shape[0] and num_steps == we.shape[1]
        self.register_parameter('wx', torch.nn.Parameter(wx))
        self.register_parameter('we', torch.nn.Parameter(we))
        # ``type(x) != None`` is always true, which used to turn a missing
        # scale factor into an empty parameter; test the value itself.
        if scale_factor is None:
            scale_factor = torch.ones(num_steps)
        self.register_parameter('scale_factor', torch.nn.Parameter(scale_factor))
        self.decay_rate = 0.1
        if strategy == "exponential_scaling":
            # Sized from the weights rather than a fixed constant, so
            # ``jitter[t]`` exists for every valid step index.
            self.matrix_generator = ExponentialScalingMatrix(num_steps)
        else:
            raise ValueError("Unknown strategy")

    def forward_other(self, t, xT, e_prev):
        """Variant of ``forward`` that perturbs and rescales each weight.

        Each ``we[t][i]`` is offset by the matching entry of row ``t`` of the
        generated matrix, and every term is multiplied by
        ``tanh(decay_rate * t)``. ``e_prev`` must hold exactly ``t + 1``
        tensors.
        """
        assert t - len(e_prev) + 1 == 0
        combined = xT * self.wx[t]
        jitter = self.matrix_generator(t)
        # Depends only on ``t``: compute once instead of once per term.
        dynamic_scale = torch.tanh(self.decay_rate * torch.tensor(t))
        for eps, weight, offset in zip(e_prev, self.we[t], jitter[t]):
            correction = eps * (weight + offset.to(eps.device))
            combined += correction * dynamic_scale.to(eps.device)
        return combined.to(xT.dtype)

    def forward(self, t, xT, e_prev):
        """Return ``xT * wx[t] + sum_i e_prev[i] * we[t][i]`` cast to ``xT.dtype``.

        ``e_prev`` must hold exactly ``t + 1`` tensors.
        """
        assert t - len(e_prev) + 1 == 0
        combined = xT * self.wx[t]
        for eps, weight in zip(e_prev, self.we[t]):
            combined += eps * weight
        return combined.to(xT.dtype)
class LossScheduler:
    """Scheduler that predicts each sample with a learned model.

    The first ``sample`` of a trajectory is stored as ``xT`` and every
    ``model_output`` is appended to ``e_prev``. Each step returns
    ``model(step_index, xT, e_prev)`` wrapped in a 1-tuple. ``wx`` and ``we``
    of the model are updated by an SGD optimizer; ``scale_factor`` is saved
    and loaded but never handed to the optimizer.

    Args:
        timesteps: Tensor with the timestep value of every step, in order.
        model: Module exposing ``wx``, ``we`` and ``scale_factor`` and
            callable as ``model(step_index, xT, e_prev)``.
        lr: Learning rate of the SGD optimizer.
    """

    # Timestep value at which ``step`` performs one optimisation update.
    TRAIN_TIMESTEP = 51
    # Reference latent that ``step`` regresses towards at TRAIN_TIMESTEP.
    REFERENCE_LATENT_TEMPLATE = (
        "/home/mbhat/edge-maxxing/miner/miner/latents/"
        "latent_orig_{image_count}_20.pth"
    )
    # Checkpoint written by ``step`` after the last step of a trajectory.
    WEIGHTS_TEMPLATE = "/home/mbhat/weights/weights_latest_{image_count}_{timestep}.pth"

    def __init__(self, timesteps, model, lr=0.01):
        self.timesteps = timesteps
        self.model = model
        self.init_noise_sigma = 1.
        self.order = 1
        self.optimizer = torch.optim.SGD([self.model.wx, self.model.we], lr=lr)
        self.model.train()
        # Start with an empty trajectory so the step methods are usable even
        # before ``set_timesteps`` has been called.
        self._reset_trajectory()

    @staticmethod
    def load(path):
        """Build a scheduler from a ``(timesteps, wx, we, scale_factor)`` file.

        NOTE: ``torch.load`` unpickles the file; only load trusted paths.
        """
        timesteps, wx, we, scale_factor = torch.load(path, map_location='cpu')
        model = LossSchedulerModel(wx, we, scale_factor=scale_factor)
        return LossScheduler(timesteps, model)

    def save(self, path):
        """Write ``(timesteps, wx, we, scale_factor)`` to ``path``."""
        payload = (
            self.timesteps,
            self.model.wx,
            self.model.we,
            self.model.scale_factor,
        )
        torch.save(payload, path)

    def set_timesteps(self, num_inference_steps, device='cuda'):
        """Reset the trajectory and move model and timesteps to ``device``.

        ``num_inference_steps`` is accepted but ignored; the stored
        ``timesteps`` are used as they are.
        """
        self._reset_trajectory()
        self.model = self.model.to(device)
        self.timesteps = self.timesteps.to(device)

    def scale_model_input(self, sample, *args, **kwargs):
        """Return ``sample`` unchanged."""
        return sample

    def _reset_trajectory(self):
        """Clear the per-trajectory state."""
        self.xT = None
        self.e_prev = []
        self.t_prev = -1

    def _begin_step(self, model_output, timestep, sample):
        """Record the inputs of one step and return its index in ``timesteps``."""
        index = self.timesteps.tolist().index(timestep)
        assert self.t_prev == -1 or index == self.t_prev + 1
        if self.t_prev == -1:
            self.xT = sample
        self.e_prev.append(model_output)
        return index

    def _end_step(self, index):
        """Advance past ``index``; return True when the trajectory just ended."""
        if index + 1 == len(self.timesteps):
            self._reset_trajectory()
            return True
        self.t_prev = index
        return False

    def step_orig(self, model_output, timestep, sample, *args, **kwargs):
        """Inference-only step: return ``(prediction,)`` without training."""
        index = self._begin_step(model_output, timestep, sample)
        prediction = self.model(index, self.xT, self.e_prev)
        self._end_step(index)
        return prediction,

    def step(self, model_output, timestep, sample, *args, **kwargs):
        """Step that trains against a stored reference latent.

        When ``timestep`` equals ``TRAIN_TIMESTEP`` the reference latent for
        the current ``image_count`` is loaded and one SGD update on the MSE
        between prediction and reference is applied. After the last step the
        weights are saved to ``WEIGHTS_TEMPLATE``. Updates the module-level
        ``count`` and ``image_count`` and prints progress.

        Returns:
            ``(prediction,)``. On the training step the prediction is the one
            computed before the optimizer update.
        """
        global image_count
        global count
        index = self._begin_step(model_output, timestep, sample)
        count += 1
        if timestep == self.timesteps[0]:
            print("resetting")
            image_count += 1
            count = 0
        print(timestep)
        if timestep == self.TRAIN_TIMESTEP:
            # The caller may be running under no_grad; re-enable autograd
            # for the update.
            with torch.enable_grad():
                orig_output = torch.load(
                    self.REFERENCE_LATENT_TEMPLATE.format(image_count=image_count)
                )
                self.optimizer.zero_grad()
                prediction = self.model(index, self.xT, self.e_prev)
                loss = torch.nn.functional.mse_loss(prediction, orig_output)
                loss.backward()
                self.optimizer.step()
                print("optimized")
        else:
            prediction = self.model(index, self.xT, self.e_prev)
        if self._end_step(index):
            self.save(
                self.WEIGHTS_TEMPLATE.format(
                    image_count=image_count, timestep=int(timestep)
                )
            )
        return prediction,

    def step_other(self, model_output, timestep, sample, *args, **kwargs):
        """Step that trains on the final timestep against ``model_output``.

        On every step but the last the prediction is returned as-is. On the
        last timestep one SGD update on the MSE between the prediction and
        ``model_output`` is applied. Unlike ``step`` this method does not
        touch the module-level counters and writes no files.

        Returns:
            ``(prediction,)``.
        """
        index = self._begin_step(model_output, timestep, sample)
        if timestep == self.timesteps.tolist()[-1]:
            with torch.enable_grad():
                self.optimizer.zero_grad()
                prediction = self.model(index, self.xT, self.e_prev)
                loss = torch.nn.functional.mse_loss(prediction, model_output)
                loss.backward()
                self.optimizer.step()
        else:
            # Every step must yield a prediction, not only the training one.
            prediction = self.model(index, self.xT, self.e_prev)
        self._end_step(index)
        return prediction,
class SchedulerWrapper:
    """Wrap a scheduler to record its trajectory or swap in a learned one.

    While ``loss_scheduler`` is ``None`` every call is forwarded to the
    wrapped ``scheduler`` and the inputs/outputs of ``step`` are captured per
    timestep in ``catch_x`` (samples), ``catch_e`` (model outputs) and
    ``catch_x_`` (first element of the step result). After
    ``load_loss_params`` the loaded ``LossScheduler`` handles ``step``.

    Args:
        scheduler: The scheduler being wrapped.
        loss_params_path: File read by ``load_loss_params``.
    """

    def __init__(self, scheduler, loss_params_path='/home/mbhat/weights.pth'):
        self.scheduler = scheduler
        self.catch_x = {}
        self.catch_e = {}
        self.catch_x_ = {}
        self.loss_scheduler = None
        self.loss_params_path = loss_params_path

    def set_timesteps(self, num_inference_steps, **kwargs):
        """Configure the active scheduler and mirror its public attributes.

        NOTE: ``num_inference_steps`` is ignored; the active scheduler is
        always asked for 11 steps. ``init_noise_sigma`` and ``order`` always
        come from the wrapped scheduler.
        """
        forced_steps = 11
        if self.loss_scheduler is None:
            active = self.scheduler
        else:
            active = self.loss_scheduler
        result = active.set_timesteps(forced_steps, **kwargs)
        self.timesteps = active.timesteps
        self.init_noise_sigma = self.scheduler.init_noise_sigma
        self.order = self.scheduler.order
        return result

    def step(self, model_output, timestep, sample, **kwargs):
        """Delegate one step and, in capture mode, record it.

        In capture mode ``timestep.tolist()`` is the dictionary key, so
        ``timestep`` must provide ``tolist``.
        """
        if self.loss_scheduler is not None:
            return self.loss_scheduler.step(model_output, timestep, sample, **kwargs)
        result = self.scheduler.step(model_output, timestep, sample, **kwargs)
        key = timestep.tolist()
        if key not in self.catch_x:
            for store in (self.catch_x, self.catch_e, self.catch_x_):
                store[key] = []
        # Keep detached CPU copies of everything captured.
        self.catch_x[key].append(sample.clone().detach().cpu())
        self.catch_e[key].append(model_output.clone().detach().cpu())
        self.catch_x_[key].append(result[0].clone().detach().cpu())
        return result

    def scale_model_input(self, sample, timestep):
        """Return ``sample`` unchanged."""
        return sample

    def add_noise(self, original_samples, noise, timesteps):
        """Forward to the wrapped scheduler's ``add_noise``."""
        return self.scheduler.add_noise(original_samples, noise, timesteps)

    def get_path(self):
        """Return the captured trajectory as ``(timesteps, samples, outputs)``.

        Timesteps are sorted in descending order and returned as an int32
        tensor. ``samples`` stacks the concatenated samples of every timestep
        followed by the concatenated step results of the smallest timestep;
        ``outputs`` stacks the concatenated model outputs.
        """
        ordered = sorted(self.catch_x, reverse=True)
        samples = [torch.cat(self.catch_x[t], dim=0) for t in ordered]
        outputs = [torch.cat(self.catch_e[t], dim=0) for t in ordered]
        final_timestep = ordered[-1]
        samples.append(torch.cat(self.catch_x_[final_timestep], dim=0))
        timestep_tensor = torch.tensor(ordered, dtype=torch.int32)
        return timestep_tensor, torch.stack(samples), torch.stack(outputs)

    def load_loss_params(self):
        """Load ``loss_params_path`` and activate the learned scheduler."""
        timesteps, wx, we, scale_factor = torch.load(
            self.loss_params_path, map_location='cpu'
        )
        self.loss_model = LossSchedulerModel(wx, we, scale_factor=scale_factor)
        self.loss_scheduler = LossScheduler(timesteps, self.loss_model)

    def prepare_loss(self, num_accelerate_steps=15):
        """Load the learned parameters; ``num_accelerate_steps`` is ignored."""
        self.load_loss_params()
|