Update src/loss.py
Browse files- src/loss.py +1 -1
src/loss.py
CHANGED
|
@@ -5,7 +5,7 @@ class LossSchedulerModel(torch.nn.Module):
|
|
| 5 |
def __init__(A,wx,we):super(LossSchedulerModel,A).__init__();assert len(wx.shape)==1 and len(we.shape)==2;B=wx.shape[0];assert B==we.shape[0]and B==we.shape[1];A.register_parameter('wx',torch.nn.Parameter(wx));A.register_parameter('we',torch.nn.Parameter(we))
|
| 6 |
def forward(A,t,xT,e_prev):
|
| 7 |
B=e_prev;assert t-len(B)+1==0;C=xT*A.wx[t]
|
| 8 |
-
for(D,E)in zip(B,A.we[t]):C+=D*E
|
| 9 |
return C.to(xT.dtype)
|
| 10 |
class LossScheduler:
|
| 11 |
def __init__(A,timesteps,model):A.timesteps=timesteps;A.model=model;A.init_noise_sigma=1.;A.order=1
|
|
|
|
| 5 |
def __init__(A,wx,we):super(LossSchedulerModel,A).__init__();assert len(wx.shape)==1 and len(we.shape)==2;B=wx.shape[0];assert B==we.shape[0]and B==we.shape[1];A.register_parameter('wx',torch.nn.Parameter(wx));A.register_parameter('we',torch.nn.Parameter(we))
|
| 6 |
def forward(A,t,xT,e_prev):
|
| 7 |
B=e_prev;assert t-len(B)+1==0;C=xT*A.wx[t]
|
| 8 |
+
for(D,E)in zip(B,A.we[t]):C+=D*E
|
| 9 |
return C.to(xT.dtype)
|
| 10 |
class LossScheduler:
|
| 11 |
def __init__(A,timesteps,model):A.timesteps=timesteps;A.model=model;A.init_noise_sigma=1.;A.order=1
|