manbeast3b committed on
Commit
37e1ccf
·
verified ·
1 Parent(s): 9449493

Update src/loss.py

Browse files
Files changed (1) hide show
  1. src/loss.py +12 -0
src/loss.py CHANGED
@@ -1,6 +1,7 @@
1
  _A=None
2
  import torch
3
  from tqdm import tqdm
 
4
  class LossSchedulerModel(torch.nn.Module):
5
  def __init__(A,wx,we):super(LossSchedulerModel,A).__init__();assert len(wx.shape)==1 and len(we.shape)==2;B=wx.shape[0];assert B==we.shape[0]and B==we.shape[1];A.register_parameter('wx',torch.nn.Parameter(wx));A.register_parameter('we',torch.nn.Parameter(we))
6
  def forward(A,t,xT,e_prev):
@@ -43,3 +44,14 @@ class SchedulerWrapper:
43
  H=A[-1];I=torch.cat(C.catch_x_[H],dim=0);B.append(I);A=torch.tensor(A,dtype=torch.int32);B=torch.stack(B);D=torch.stack(D);return A,B,D
44
  def load_loss_params(A):B,C,D=torch.load(A.loss_params_path,map_location='cpu');A.loss_model=LossSchedulerModel(C,D);A.loss_scheduler=LossScheduler(B,A.loss_model);print("loaded ",A.loss_params_path)
45
  def prepare_loss(A,num_accelerate_steps=15):A.load_loss_params()
 
 
 
 
 
 
 
 
 
 
 
 
1
  _A=None
2
  import torch
3
  from tqdm import tqdm
4
+ import output
5
  class LossSchedulerModel(torch.nn.Module):
  # Register the loss-schedule weights as trainable parameters.
  # Shapes are checked up front: wx must be 1-D of length N and we a square
  # (N, N) matrix, so mismatches fail fast at construction time.
  # NOTE(review): `assert` is stripped under `python -O`; raise ValueError
  # instead if this invariant must hold in production — TODO confirm.
  def __init__(A,wx,we):super(LossSchedulerModel,A).__init__();assert len(wx.shape)==1 and len(we.shape)==2;B=wx.shape[0];assert B==we.shape[0]and B==we.shape[1];A.register_parameter('wx',torch.nn.Parameter(wx));A.register_parameter('we',torch.nn.Parameter(we))
7
  def forward(A,t,xT,e_prev):
 
44
  H=A[-1];I=torch.cat(C.catch_x_[H],dim=0);B.append(I);A=torch.tensor(A,dtype=torch.int32);B=torch.stack(B);D=torch.stack(D);return A,B,D
45
  # Load (scheduler_state, wx, we) from A.loss_params_path and rebuild the
  # loss model / scheduler pair; map_location='cpu' keeps the load
  # device-agnostic. Prints the path as a simple progress confirmation.
  # NOTE(review): torch.load here does not pass weights_only=True, so it can
  # unpickle arbitrary objects — only load trusted checkpoint files.
  def load_loss_params(A):B,C,D=torch.load(A.loss_params_path,map_location='cpu');A.loss_model=LossSchedulerModel(C,D);A.loss_scheduler=LossScheduler(B,A.loss_model);print("loaded ",A.loss_params_path)
46
  def prepare_loss(A,num_accelerate_steps=15):A.load_loss_params()
47
def get_instance(device, weights_path="xl_compressor.pth"):
    """Build the frozen decoder used as the lightweight VAE replacement.

    Parameters:
        device: torch device (or device string) the decoder is moved to.
        weights_path: checkpoint file holding the decoder state dict.
            Defaults to the previously hard-coded "xl_compressor.pth",
            so existing callers are unaffected.

    Returns:
        The `output.Decoder` instance on `device` with gradients disabled.
    """
    compress = output.Decoder().to(device).requires_grad_(False)
    # weights_only=True restricts unpickling to plain tensors/containers,
    # which is the safe way to load a state-dict checkpoint.
    compress.load_state_dict(
        torch.load(weights_path, map_location=device, weights_only=True)
    )
    return compress
51
def hook_pipe(pipe, compress, mul, sub, scaling_factor):
    """Monkey-patch `pipe.vae.decode` with the lightweight `compress` decoder.

    The replacement scales incoming latents by `scaling_factor` (the VAE's
    latent scaling factor), runs `compress`, applies the affine
    post-processing `out * mul - sub` in place, moves the result to CPU, and
    returns a 1-tuple to mimic the decode API the caller unpacks.

    Fix: the original placed @torch.no_grad() on `hook_pipe` itself, which
    only disabled grad while *installing* the hook — not when the patched
    decode is invoked later. The decorator now guards the replacement
    function, so every decode call runs gradient-free.
    """
    @torch.no_grad()
    def compress_machine(magic, *args, **kwargs):
        # Undo the latent scaling (pipe.vae.config.scaling_factor).
        magic = magic.float().mul(scaling_factor)
        # In-place mul_/sub_ avoids extra temporaries; .cpu() frees device
        # memory as soon as the frame is decoded.
        out_magic = compress(magic).mul_(mul).sub_(sub).cpu()
        return (out_magic,)

    pipe.vae.decode = compress_machine