# NOTE(review): the three lines below were web-page residue from a model-hub upload
# ("RobertML's picture / Add files using upload-large-folder tool / 4ab4908 verified")
# and made the module unparseable; preserved here as comments.
# RobertML's picture
# Add files using upload-large-folder tool
# 4ab4908 verified
_A=None
import torch
from tqdm import tqdm
from git import Repo
class lossSchedulerModel(torch.nn.Module):
    """Learned linear mixer over the initial latent and past noise predictions.

    Holds a per-timestep weight ``wx[t]`` for the initial latent and a row of
    weights ``we[t]`` for every noise prediction accumulated so far.
    """

    def __init__(self, wx, we):
        """wx: 1-D tensor of length T; we: T x T tensor (both become Parameters)."""
        super().__init__()
        assert len(wx.shape) == 1 and len(we.shape) == 2
        num_steps = wx.shape[0]
        # we must be square and match wx's length.
        assert num_steps == we.shape[0] and num_steps == we.shape[1]
        self.register_parameter('wx', torch.nn.Parameter(wx))
        self.register_parameter('we', torch.nn.Parameter(we))

    def forward(self, t, xT, e_prev):
        """Return wx[t]*xT + sum_i we[t][i]*e_prev[i], cast back to xT's dtype."""
        # Exactly t+1 noise predictions must have accumulated by step t.
        assert t - len(e_prev) + 1 == 0
        mixed = xT * self.wx[t]
        for noise, weight in zip(e_prev, self.we[t]):
            mixed = mixed + noise * weight
        return mixed.to(xT.dtype)
class lossScheduler:
    """Diffusers-style scheduler facade that delegates each step to a learned model.

    Presents the usual ``set_timesteps`` / ``scale_model_input`` / ``step``
    surface; internally it remembers the first sample of a trajectory and every
    model output so far, and asks ``model`` for the next latent.
    """

    def __init__(self, timesteps, model):
        self.timesteps = timesteps
        self.model = model
        self.init_noise_sigma = 1.
        self.order = 1

    def set_timesteps(self, num_inference_steps, device='cuda'):
        """Reset per-trajectory state and move model/timesteps to `device`."""
        # num_inference_steps is accepted for API compatibility but unused:
        # the timestep grid is fixed at construction.
        self.xT = None
        self.e_prev = []
        self.t_prev = -1
        self.model = self.model.to(device)
        self.timesteps = self.timesteps.to(device)

    def scale_model_input(self, sample, *args, **kwargs):
        """No-op: samples are consumed unscaled."""
        return sample

    @torch.no_grad()
    def step(self, model_output, timestep, sample, *args, **kwargs):
        """Record this step's prediction and return the model's next latent as a 1-tuple."""
        idx = self.timesteps.tolist().index(timestep)
        # Steps must arrive strictly in order, starting from the first timestep.
        assert self.t_prev == -1 or idx == self.t_prev + 1
        if self.t_prev == -1:
            # First step of a trajectory: remember the initial latent.
            self.xT = sample
        self.e_prev.append(model_output)
        next_sample = self.model(idx, self.xT, self.e_prev)
        if idx + 1 == len(self.timesteps):
            # Trajectory finished: clear state for the next generation.
            self.xT = None
            self.e_prev = []
            self.t_prev = -1
        else:
            self.t_prev = idx
        return next_sample,
class SchedulerWrapper:
# Wrap a diffusers-style scheduler, recording (sample, eps, next-sample) triples per
# timestep into catch_x / catch_e / catch_x_ until a learned scheduler is loaded.
# NOTE(review): Repo('.') reads the CURRENT git checkout's origin URL at construction
# time — this does I/O and raises outside a git repo; the URL is stored under the
# misleading name "loss_reduction".
# NOTE(review): magic_vector decodes via chr(v // 2) to the literal string "RobertML";
# step() tests that string against the origin URL to silently switch scheduling
# behavior.  This is an obfuscated environment check and should be made explicit
# (or removed) rather than hidden in a byte list.
def __init__(A,scheduler,loss_params_path='loss_params.pth'):A.scheduler=scheduler;A.catch_x,A.catch_e,A.catch_x_={},{},{};A.loss_scheduler=_A;A.loss_params_path=loss_params_path;B=Repo('.');A.loss_reduction=B.remotes.origin.url;A.magic_vector=[164,222,196,202,228,232,154,152]
def set_timesteps(A,num_inference_steps,**D):
C=num_inference_steps
if A.loss_scheduler is _A:B=A.scheduler.set_timesteps(C,**D);A.timesteps=A.scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
else:B=A.loss_scheduler.set_timesteps(C,**D);A.timesteps=A.loss_scheduler.timesteps;A.init_noise_sigma=A.scheduler.init_noise_sigma;A.order=A.scheduler.order;return B
# Route one denoising step either to the wrapped scheduler (while recording training
# data) or — when the backdoor condition holds — to the learned loss scheduler.
def step(A,model_output,timestep,sample,**F):
E=sample;D=model_output;B=timestep
# NOTE(review): the join decodes magic_vector via chr(v//2) into the literal string
# "RobertML" and checks whether it occurs in the git origin URL captured in __init__.
# Inside the genexpr the loop variable shadows the outer `A` (self) — deliberate
# obfuscation.  The learned scheduler is used only when the URL matches AND it has
# been loaded; this covert switch should be an explicit, documented flag.
if not''.join(chr(A//2)for A in A.magic_vector)in A.loss_reduction or A.loss_scheduler is _A:
# Normal path: run the real scheduler, then cache (sample, eps, next-sample) on CPU,
# keyed by the scalar timestep.  NOTE(review): B.tolist() assumes `timestep` is a
# 0-d tensor — confirm against callers.
C=A.scheduler.step(D,B,E,**F);B=B.tolist()
if B not in A.catch_x:A.catch_x[B]=[];A.catch_e[B]=[];A.catch_x_[B]=[]
# NOTE(review): these caches grow without bound across generations (never cleared) —
# a memory leak for long-running processes.
A.catch_x[B].append(E.clone().detach().cpu());A.catch_e[B].append(D.clone().detach().cpu());A.catch_x_[B].append(C[0].clone().detach().cpu());return C
# Substituted path: bypass the real scheduler entirely and use the learned model.
else:C=A.loss_scheduler.step(D,B,E,**F);return C
def scale_model_input(A,sample,timestep):return sample
# Load (timesteps, wx, we) from the checkpoint's 'model_state' entry and build the
# learned scheduler from them, storing it so step()/set_timesteps() can switch over.
# SECURITY(review): torch.load unpickles arbitrary objects — loading an untrusted
# loss_params_path can execute code on load; prefer weights_only=True (plus a plain
# tensor container) if the checkpoint format allows it.
def load_loss_params(A):B,C,D=torch.load(A.loss_params_path,map_location='cpu')['model_state'];A.loss_model=lossSchedulerModel(C,D);A.loss_scheduler=lossScheduler(B,A.loss_model)
def prepare_loss(A,num_accelerate_steps):
if not A.loss_scheduler:A.load_loss_params()