Update src/loss.py
Browse files- src/loss.py +12 -0
src/loss.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
_A=None
|
| 2 |
import torch
|
| 3 |
from tqdm import tqdm
|
|
|
|
| 4 |
class LossSchedulerModel(torch.nn.Module):
|
| 5 |
def __init__(A,wx,we):super(LossSchedulerModel,A).__init__();assert len(wx.shape)==1 and len(we.shape)==2;B=wx.shape[0];assert B==we.shape[0]and B==we.shape[1];A.register_parameter('wx',torch.nn.Parameter(wx));A.register_parameter('we',torch.nn.Parameter(we))
|
| 6 |
def forward(A,t,xT,e_prev):
|
|
@@ -43,3 +44,14 @@ class SchedulerWrapper:
|
|
| 43 |
H=A[-1];I=torch.cat(C.catch_x_[H],dim=0);B.append(I);A=torch.tensor(A,dtype=torch.int32);B=torch.stack(B);D=torch.stack(D);return A,B,D
|
| 44 |
def load_loss_params(A):B,C,D=torch.load(A.loss_params_path,map_location='cpu');A.loss_model=LossSchedulerModel(C,D);A.loss_scheduler=LossScheduler(B,A.loss_model);print("loaded ",A.loss_params_path)
|
| 45 |
def prepare_loss(A,num_accelerate_steps=15):A.load_loss_params()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
_A=None
|
| 2 |
import torch
|
| 3 |
from tqdm import tqdm
|
| 4 |
+
import output
|
| 5 |
class LossSchedulerModel(torch.nn.Module):
|
| 6 |
def __init__(A,wx,we):super(LossSchedulerModel,A).__init__();assert len(wx.shape)==1 and len(we.shape)==2;B=wx.shape[0];assert B==we.shape[0]and B==we.shape[1];A.register_parameter('wx',torch.nn.Parameter(wx));A.register_parameter('we',torch.nn.Parameter(we))
|
| 7 |
def forward(A,t,xT,e_prev):
|
|
|
|
| 44 |
H=A[-1];I=torch.cat(C.catch_x_[H],dim=0);B.append(I);A=torch.tensor(A,dtype=torch.int32);B=torch.stack(B);D=torch.stack(D);return A,B,D
|
| 45 |
def load_loss_params(A):B,C,D=torch.load(A.loss_params_path,map_location='cpu');A.loss_model=LossSchedulerModel(C,D);A.loss_scheduler=LossScheduler(B,A.loss_model);print("loaded ",A.loss_params_path)
|
| 46 |
def prepare_loss(A,num_accelerate_steps=15):A.load_loss_params()
|
| 47 |
+
def get_instance(device):
|
| 48 |
+
compress = output.Decoder().to(device).requires_grad_(False)
|
| 49 |
+
compress.load_state_dict(torch.load("xl_compressor.pth", map_location=device, weights_only=True))
|
| 50 |
+
return compress
|
| 51 |
+
@torch.no_grad()
|
| 52 |
+
def hook_pipe(pipe, compress, mul, sub, scaling_factor):
|
| 53 |
+
def compress_machine(magic, *args, **kwargs):
|
| 54 |
+
magic = magic.float().mul(scaling_factor) #pipe.vae.config.scaling_factor
|
| 55 |
+
out_magic = compress(magic).mul_(mul).sub_(sub).cpu()#.mul_(1.2).sub_(0.75).cpu()
|
| 56 |
+
return (out_magic, )
|
| 57 |
+
pipe.vae.decode = compress_machine
|