| | _A=None |
| | count=0 |
| | image_count=0 |
| | import torch |
| | from tqdm import tqdm |
| | import numpy as np |
| |
|
| |
|
class ExponentialScalingMatrix(torch.nn.Module):
    """Learnable ``size x size`` matrix scaled by ``exp(-timestep / 10)``.

    The matrix is registered as a real ``torch.nn.Parameter``. It therefore
    appears in ``parameters()`` / ``state_dict()`` and follows ``Module.to(...)``.

    Args:
        size: Side length of the square matrix.
        device: Optional device for the initial matrix. ``None`` uses the
            PyTorch default; move the module with ``.to(...)`` as usual.
    """

    def __init__(self, size, device=None):
        super().__init__()
        # The original wrapped the Parameter in ``.to('cuda')``. That returns a
        # plain non-leaf tensor, so the matrix was never registered as a
        # parameter, and construction failed on hosts without CUDA.
        self.base = torch.nn.Parameter(torch.ones(size, size, device=device))

    def forward(self, timestep):
        """Return ``base * exp(-timestep / 10)`` on ``base``'s device/dtype.

        Args:
            timestep: Python number or scalar tensor.
        """
        t = torch.as_tensor(timestep, dtype=self.base.dtype, device=self.base.device)
        return self.base * torch.exp(-t / 10.0)

class RecurrentDynamicScaling(torch.nn.Module):
    """Gate a correction with a sigmoid driven by a GRU state carried across calls.

    Each call feeds ``correction`` through a ``GRUCell`` whose hidden state
    persists between calls. It then returns ``x + correction * sigmoid(state)``.

    Args:
        size: Feature size of ``x`` / ``correction`` and of the hidden state.
    """

    def __init__(self, size):
        super().__init__()
        self.rnn = torch.nn.GRUCell(size, size)
        # A buffer rather than a Parameter: forward() replaces it with the GRU
        # output, and assigning a plain tensor to a registered Parameter
        # attribute raises TypeError (the original crashed on every call).
        self.register_buffer("state", torch.zeros(size))

    def forward(self, timestep, x, correction):
        """Return ``x`` plus the gated ``correction``; ``timestep`` is unused.

        NOTE(review): ``state`` is 1-D, so this assumes unbatched inputs of
        shape ``(size,)``, as the original did — confirm against callers.
        """
        # GRUCell's signature is (input, hx). The correction is the input and
        # the carried state is the hidden state; the original had them swapped.
        new_state = self.rnn(correction, self.state)
        # Gradients flow through this step's output. The stored state is
        # detached so a later backward() does not traverse an already-freed
        # graph from a previous call.
        self.state = new_state.detach()
        return x + correction * torch.sigmoid(new_state)

class AttentionDynamicScaling(torch.nn.Module):
    """Blend ``correction`` into ``x`` through single-head attention.

    The correction serves as query and value, and ``x`` as key. Each input
    gains a leading length-1 sequence axis before attention, which is removed
    afterwards.

    Args:
        size: Embedding size passed to ``torch.nn.MultiheadAttention``.
    """

    def __init__(self, size):
        super().__init__()
        self.attention = torch.nn.MultiheadAttention(size, num_heads=1)

    def forward(self, timestep, x, correction):
        """Return ``x`` plus the attended correction; ``timestep`` is unused."""
        attended, _weights = self.attention(
            correction.unsqueeze(0),  # query
            x.unsqueeze(0),           # key
            correction.unsqueeze(0),  # value
        )
        return x + attended.squeeze(0)

class LossSchedulerModel(torch.nn.Module):
    """Learned linear combination of the initial sample and past model outputs.

    For step ``t``, ``forward`` returns
    ``xT * wx[t] + sum_i e_prev[i] * we[t, i]``. ``forward_other`` also adds a
    per-weight jitter taken from ``matrix_generator``.

    Args:
        wx: 1-D tensor of per-step weights for ``xT`` (length ``B``).
        we: ``B x B`` tensor of weights for the previous outputs.
        scale_factor: Optional tensor stored as the ``scale_factor``
            parameter. When ``None``, it defaults to ``torch.ones(size)``.
        strategy: Jitter strategy; only ``"exponential_scaling"`` is supported.
        size: Side length of the jitter matrix and length of the default
            ``scale_factor`` (previously hard-coded to 13).

    Raises:
        ValueError: if ``strategy`` is unknown.
    """

    def __init__(self, wx, we, scale_factor=None, strategy="exponential_scaling", size=13):
        super().__init__()
        assert len(wx.shape) == 1 and len(we.shape) == 2
        n_steps = wx.shape[0]
        assert n_steps == we.shape[0] and n_steps == we.shape[1]
        self.register_parameter("wx", torch.nn.Parameter(wx))
        self.register_parameter("we", torch.nn.Parameter(we))
        # The original tested ``type(scale_factor) != None``, which is always
        # true. A ``None`` therefore became an empty Parameter, and the default
        # below was unreachable.
        if scale_factor is None:
            scale_factor = torch.ones(size)
        self.register_parameter("scale_factor", torch.nn.Parameter(scale_factor))
        # Slope of the tanh ramp in forward_other. Set unconditionally: the
        # original only set it in the unreachable branch, so forward_other
        # raised AttributeError.
        self.decay_rate = 0.1
        if strategy == "exponential_scaling":
            self.matrix_generator = ExponentialScalingMatrix(size)
        else:
            raise ValueError("Unknown strategy")

    def forward_other(self, t, xT, e_prev):
        """Like ``forward``, but adds jitter to each weight.

        Each weight ``we[t, i]`` becomes ``we[t, i] + jitter[t, i]``, and every
        term is scaled by ``tanh(decay_rate * t)``.
        """
        assert len(e_prev) == t + 1
        out = xT * self.wx[t]
        jitter_row = self.matrix_generator(t)[t]
        # Depends only on t, so compute it once rather than per term.
        ramp = torch.tanh(self.decay_rate * torch.tensor(t))
        for e, w, j in zip(e_prev, self.we[t], jitter_row):
            out += e * (w + j.to(e.device)) * ramp.to(e.device)
        return out.to(xT.dtype)

    def forward(self, t, xT, e_prev):
        """Return ``xT * wx[t] + sum_i e_prev[i] * we[t, i]`` cast to ``xT.dtype``.

        Args:
            t: Integer step index; ``e_prev`` must hold exactly ``t + 1`` tensors.
            xT: Sample captured at the first step.
            e_prev: Model outputs recorded so far, oldest first.
        """
        assert len(e_prev) == t + 1
        out = xT * self.wx[t]
        for e, w in zip(e_prev, self.we[t]):
            out += e * w
        return out.to(xT.dtype)

class LossScheduler:
    """Scheduler whose update is a learned model over the trajectory so far.

    Each ``step*`` call records the model output and returns
    ``(model(idx, xT, e_prev),)``. Here ``idx`` is the position of ``timestep``
    in ``self.timesteps`` and ``xT`` is the sample from the first step. The
    trajectory state is reset after the last timestep.

    Args:
        timesteps: 1-D tensor of timesteps, in the order ``step`` is called.
        model: Module exposing ``wx`` / ``we`` parameters (optimized with SGD)
            and callable as ``model(idx, xT, e_prev)``.
        lr: SGD learning rate.
    """

    def __init__(self, timesteps, model, lr=0.01):
        self.timesteps = timesteps
        self.model = model
        self.init_noise_sigma = 1.0
        self.order = 1
        self.optimizer = torch.optim.SGD([self.model.wx, self.model.we], lr=lr)
        self.model.train()
        # Start with an empty trajectory, so stepping before set_timesteps()
        # does not raise AttributeError.
        self._reset_trajectory()

    @staticmethod
    def load(path):
        """Rebuild a scheduler from a ``(timesteps, wx, we, scale_factor)`` file written by ``save``."""
        timesteps, wx, we, scale_factor = torch.load(path, map_location='cpu')
        model = LossSchedulerModel(wx, we, scale_factor=scale_factor)
        return LossScheduler(timesteps, model)

    def save(self, path):
        """Write ``(timesteps, wx, we, scale_factor)`` to ``path`` with ``torch.save``."""
        torch.save((self.timesteps, self.model.wx, self.model.we, self.model.scale_factor), path)

    def set_timesteps(self, num_inference_steps, device='cuda'):
        """Reset the trajectory and move the model and timesteps to ``device``.

        ``num_inference_steps`` is accepted for scheduler-API compatibility but
        ignored; the timesteps are the ones given at construction.
        """
        self._reset_trajectory()
        self.model = self.model.to(device)
        self.timesteps = self.timesteps.to(device)

    def scale_model_input(self, sample, *args, **kwargs):
        """Return ``sample`` unchanged (no input scaling)."""
        return sample

    def _reset_trajectory(self):
        """Forget the captured initial sample and recorded model outputs."""
        self.xT = None
        self.e_prev = []
        self.t_prev = -1

    def _index_of(self, timestep):
        """Return ``timestep``'s position in ``self.timesteps``, asserting in-order stepping."""
        idx = self.timesteps.tolist().index(timestep)
        assert self.t_prev == -1 or idx == self.t_prev + 1
        return idx

    def _record(self, sample, model_output):
        """Capture ``sample`` on the first step and append ``model_output``."""
        if self.t_prev == -1:
            self.xT = sample
        self.e_prev.append(model_output)

    def _finish_step(self, idx):
        """Reset after the final timestep; otherwise remember ``idx`` as the previous step."""
        if idx + 1 == len(self.timesteps):
            self._reset_trajectory()
        else:
            self.t_prev = idx

    def step_orig(self, model_output, timestep, sample, *args, **kwargs):
        """Evaluate the model for this step without training."""
        idx = self._index_of(timestep)
        self._record(sample, model_output)
        out = self.model(idx, self.xT, self.e_prev)
        self._finish_step(idx)
        return (out,)

    def step(self, model_output, timestep, sample, *args, **kwargs):
        """Step, training once per image at timestep 51 and saving weights at the end.

        At timestep 51, the model is optimized with an MSE loss against a latent
        loaded from disk. After the final timestep, the weights are saved.
        Both file paths are hard-coded developer paths (NOTE(review)).
        """
        global image_count, count
        idx = self._index_of(timestep)
        count += 1
        if timestep == self.timesteps[0]:
            # First timestep of a new image: bump the image counter and
            # restart the per-image step counter.
            print("resetting")
            image_count += 1
            count = 0
        print(timestep)
        self._record(sample, model_output)
        if timestep == 51:
            with torch.enable_grad():
                orig_output = torch.load(f"/home/mbhat/edge-maxxing/miner/miner/latents/latent_orig_{image_count}_20.pth")
                self.optimizer.zero_grad()
                out = self.model(idx, self.xT, self.e_prev)
                loss = torch.nn.functional.mse_loss(out, orig_output)
                loss.backward()
                self.optimizer.step()
                print("optimized")
        else:
            out = self.model(idx, self.xT, self.e_prev)
        is_last = idx + 1 == len(self.timesteps)
        self._finish_step(idx)
        if is_last:
            self.save(f"/home/mbhat/weights/weights_latest_{image_count}_{int(timestep)}.pth")
        return (out,)

    def step_other(self, model_output, timestep, sample, *args, **kwargs):
        """Step, training on the final timestep against ``model_output``."""
        # The original incremented ``count`` without declaring it global,
        # which raised UnboundLocalError.
        global count
        idx = self._index_of(timestep)
        count += 1
        self._record(sample, model_output)
        if timestep == self.timesteps.tolist()[-1]:
            with torch.enable_grad():
                self.optimizer.zero_grad()
                out = self.model(idx, self.xT, self.e_prev)
                loss = torch.nn.functional.mse_loss(out, model_output)
                loss.backward()
                self.optimizer.step()
        else:
            # The original never computed the output on non-final steps and
            # failed with NameError at ``return C,``.
            out = self.model(idx, self.xT, self.e_prev)
        self._finish_step(idx)
        return (out,)
class SchedulerWrapper:
    """Wrap a scheduler and record its trajectory until a learned scheduler is loaded.

    While ``loss_scheduler`` is ``None``, steps go to ``scheduler``, and each
    step's input sample, model output and scheduler output are captured per
    timestep. After ``load_loss_params()``, steps go to the ``LossScheduler``.
    """

    def __init__(self, scheduler, loss_params_path='/home/mbhat/weights.pth'):
        self.scheduler = scheduler
        # Per-timestep captures: input sample, model output, scheduler output.
        self.catch_x = {}
        self.catch_e = {}
        self.catch_x_ = {}
        self.loss_scheduler = None
        self.loss_params_path = loss_params_path

    def set_timesteps(self, num_inference_steps, **kwargs):
        """Configure the active scheduler; ``num_inference_steps`` is ignored (11 is used)."""
        forced_steps = 11
        active = self.scheduler if self.loss_scheduler is None else self.loss_scheduler
        result = active.set_timesteps(forced_steps, **kwargs)
        self.timesteps = active.timesteps
        # Always mirrored from the wrapped scheduler, whichever one is active.
        self.init_noise_sigma = self.scheduler.init_noise_sigma
        self.order = self.scheduler.order
        return result

    def step(self, model_output, timestep, sample, **kwargs):
        """Delegate a step, capturing CPU copies of tensors while no loss scheduler is loaded."""
        if self.loss_scheduler is not None:
            return self.loss_scheduler.step(model_output, timestep, sample, **kwargs)
        output = self.scheduler.step(model_output, timestep, sample, **kwargs)
        key = timestep.tolist()
        if key not in self.catch_x:
            self.catch_x[key] = []
            self.catch_e[key] = []
            self.catch_x_[key] = []
        self.catch_x[key].append(sample.clone().detach().cpu())
        self.catch_e[key].append(model_output.clone().detach().cpu())
        self.catch_x_[key].append(output[0].clone().detach().cpu())
        return output

    def scale_model_input(self, sample, timestep):
        """Return ``sample`` unchanged."""
        return sample

    def add_noise(self, original_samples, noise, timesteps):
        """Forward to the wrapped scheduler's ``add_noise``."""
        noisy = self.scheduler.add_noise(original_samples, noise, timesteps)
        return noisy

    def get_path(self):
        """Return ``(timesteps, samples, outputs)`` stacked in descending timestep order.

        ``samples`` has one extra trailing entry: the scheduler output captured
        at the smallest timestep.
        """
        keys = sorted(self.catch_x, reverse=True)
        xs = [torch.cat(self.catch_x[k], dim=0) for k in keys]
        es = [torch.cat(self.catch_e[k], dim=0) for k in keys]
        xs.append(torch.cat(self.catch_x_[keys[-1]], dim=0))
        return torch.tensor(keys, dtype=torch.int32), torch.stack(xs), torch.stack(es)

    def load_loss_params(self):
        """Build ``loss_model`` / ``loss_scheduler`` from the tuple saved at ``loss_params_path``."""
        timesteps, wx, we, scale_factor = torch.load(self.loss_params_path, map_location='cpu')
        self.loss_model = LossSchedulerModel(wx, we, scale_factor=scale_factor)
        self.loss_scheduler = LossScheduler(timesteps, self.loss_model)

    def prepare_loss(self, num_accelerate_steps=15):
        """Load the learned scheduler; ``num_accelerate_steps`` is currently unused."""
        self.load_loss_params()
|