manbeast3b commited on
Commit
ee1557e
·
verified ·
1 Parent(s): 2a75b7c

Update src/loss.py

Browse files
Files changed (1) hide show
  1. src/loss.py +1 -1
src/loss.py CHANGED
@@ -5,7 +5,7 @@ class LossSchedulerModel(torch.nn.Module):
5
  def __init__(A,wx,we):super(LossSchedulerModel,A).__init__();assert len(wx.shape)==1 and len(we.shape)==2;B=wx.shape[0];assert B==we.shape[0]and B==we.shape[1];A.register_parameter('wx',torch.nn.Parameter(wx));A.register_parameter('we',torch.nn.Parameter(we))
6
  def forward(A,t,xT,e_prev):
7
  B=e_prev;assert t-len(B)+1==0;C=xT*A.wx[t]
8
- for(D,E)in zip(B,A.we[t]):C+=D*E+0.0001
9
  return C.to(xT.dtype)
10
  class LossScheduler:
11
  def __init__(A,timesteps,model):A.timesteps=timesteps;A.model=model;A.init_noise_sigma=1.;A.order=1
 
5
  def __init__(A,wx,we):super(LossSchedulerModel,A).__init__();assert len(wx.shape)==1 and len(we.shape)==2;B=wx.shape[0];assert B==we.shape[0]and B==we.shape[1];A.register_parameter('wx',torch.nn.Parameter(wx));A.register_parameter('we',torch.nn.Parameter(we))
6
  def forward(A,t,xT,e_prev):
7
  B=e_prev;assert t-len(B)+1==0;C=xT*A.wx[t]
8
+ for(D,E)in zip(B,A.we[t]):C+=D*E
9
  return C.to(xT.dtype)
10
  class LossScheduler:
11
  def __init__(A,timesteps,model):A.timesteps=timesteps;A.model=model;A.init_noise_sigma=1.;A.order=1