manbeast3b committed on
Commit
d198bf3
·
verified ·
1 Parent(s): a5650d6

Update src/loss.py

Browse files
Files changed (1) hide show
  1. src/loss.py +12 -0
src/loss.py CHANGED
@@ -1,6 +1,7 @@
1
  _A=None
2
  import torch
3
  from tqdm import tqdm
 
4
  class LossSchedulerModel(torch.nn.Module):
5
  # Register wx (expected 1-D, length n) and we (expected n-by-n) as learnable parameters;
  # the asserts enforce that both dimensions of we match the length of wx.
  def __init__(A,wx,we):super(LossSchedulerModel,A).__init__();assert len(wx.shape)==1 and len(we.shape)==2;B=wx.shape[0];assert B==we.shape[0]and B==we.shape[1];A.register_parameter('wx',torch.nn.Parameter(wx));A.register_parameter('we',torch.nn.Parameter(we))
6
  def forward(A,t,xT,e_prev):
@@ -44,3 +45,14 @@ class SchedulerWrapper:
44
  # Load a 3-tuple from A.loss_params_path and rebuild the loss model and scheduler.
  # C must be 1-D and D 2-D square (LossSchedulerModel asserts this); B is presumably
  # a step count/schedule consumed by LossScheduler — TODO confirm against the saved format.
  def load_loss_params(A):B,C,D=torch.load(A.loss_params_path,map_location='cpu');A.loss_model=LossSchedulerModel(C,D);A.loss_scheduler=LossScheduler(B,A.loss_model)
45
  # Same as load_loss_params, but loads from an explicit path instead of A.loss_params_path.
  def load_loss_params_path(A, path):B,C,D=torch.load(path,map_location='cpu');A.loss_model=LossSchedulerModel(C,D);A.loss_scheduler=LossScheduler(B,A.loss_model)
46
  # Convenience entry point: loads the loss parameters from the default path.
  # NOTE(review): num_accelerate_steps is accepted but never used here — confirm callers rely on the keyword.
  def prepare_loss(A,num_accelerate_steps=15):A.load_loss_params()
 
 
 
 
 
 
 
 
 
 
 
 
1
  _A=None
2
  import torch
3
  from tqdm import tqdm
4
+ import output
5
  class LossSchedulerModel(torch.nn.Module):
6
  def __init__(A,wx,we):super(LossSchedulerModel,A).__init__();assert len(wx.shape)==1 and len(we.shape)==2;B=wx.shape[0];assert B==we.shape[0]and B==we.shape[1];A.register_parameter('wx',torch.nn.Parameter(wx));A.register_parameter('we',torch.nn.Parameter(we))
7
  def forward(A,t,xT,e_prev):
 
45
  def load_loss_params(A):B,C,D=torch.load(A.loss_params_path,map_location='cpu');A.loss_model=LossSchedulerModel(C,D);A.loss_scheduler=LossScheduler(B,A.loss_model)
46
  def load_loss_params_path(A, path):B,C,D=torch.load(path,map_location='cpu');A.loss_model=LossSchedulerModel(C,D);A.loss_scheduler=LossScheduler(B,A.loss_model)
47
  def prepare_loss(A,num_accelerate_steps=15):A.load_loss_params()
48
def get_instance(device, weights_path="taesdxl_decoder.pth"):
    """Build the frozen decoder on `device`.

    Args:
        device: torch device (or device string) the decoder is moved to.
        weights_path: checkpoint holding the decoder state dict. Defaults to
            "taesdxl_decoder.pth", preserving the original hard-coded path,
            but is now parameterizable for non-default deployments.

    Returns:
        The output.Decoder instance with gradients disabled.
    """
    compress = output.Decoder().to(device).requires_grad_(False)
    # weights_only=True stops torch.load from unpickling arbitrary objects
    # out of the checkpoint file.
    compress.load_state_dict(
        torch.load(weights_path, map_location=device, weights_only=True)
    )
    return compress
52
@torch.no_grad()
def hook_pipe(pipe, compress, mul, sub, scaling_factor):
    """Replace `pipe.vae.decode` with a lightweight decoder wrapper.

    BUG FIX: the original only decorated `hook_pipe` itself with
    `torch.no_grad()`, which is in effect only while the hook is being
    installed — not later, when the pipeline actually calls
    `compress_machine`. The inner function now carries its own
    `@torch.no_grad()` so decoding runs without autograd as intended.

    Args:
        pipe: pipeline object; its `vae.decode` attribute is overwritten.
        compress: decoder module/callable applied to the scaled latents.
        mul, sub: post-decode affine rescaling (out = decoded * mul - sub).
        scaling_factor: latent scale (stands in for pipe.vae.config.scaling_factor).
    """
    @torch.no_grad()
    def compress_machine(magic, *args, **kwargs):
        # Cast to float and undo the VAE latent scaling.
        latents = magic.float().mul(scaling_factor)
        # Decode, rescale into image range in-place, and move to CPU.
        out_magic = compress(latents).mul_(mul).sub_(sub).cpu()
        # Tuple return mimics the interface of the decode call being replaced.
        return (out_magic, )

    pipe.vae.decode = compress_machine