Update src/loss.py
Browse files- src/loss.py +11 -11
src/loss.py
CHANGED
|
@@ -2,17 +2,17 @@ _A=None
|
|
| 2 |
import torch
|
| 3 |
from tqdm import tqdm
|
| 4 |
def generate_matrix(size):
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
class LossSchedulerModel(torch.nn.Module):
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
class LossScheduler:
|
| 17 |
def __init__(A,timesteps,model):A.timesteps=timesteps;A.model=model;A.init_noise_sigma=1.;A.order=1
|
| 18 |
@staticmethod
|
|
@@ -22,7 +22,7 @@ class LossScheduler:
|
|
| 22 |
def scale_model_input(A,sample,*B,**C):return sample
|
| 23 |
@torch.no_grad()
|
| 24 |
def step(self,model_output,timestep,sample,*D,**E):
|
| 25 |
-
|
| 26 |
if A.t_prev==-1:A.xT=sample
|
| 27 |
A.e_prev.append(model_output);C=A.model(B,A.xT,A.e_prev)
|
| 28 |
if B+1==len(A.timesteps):A.xT=_A;A.e_prev=[];A.t_prev=-1
|
|
|
|
| 2 |
import torch
|
| 3 |
from tqdm import tqdm
|
| 4 |
def generate_matrix(size):
    """Return a ``(size, size)`` lower-triangular float64 matrix of fixed
    negative coefficients.

    Entry ``(i, j)`` for ``j <= i`` is ``-((i + 1) * 0.05 + (j + 1) * 0.005)``;
    entries strictly above the diagonal are zero.

    Args:
        size: Side length of the square matrix (non-negative int).

    Returns:
        torch.Tensor: ``(size, size)`` tensor of dtype ``torch.float64``.
    """
    # Build all coefficients at once with broadcasting instead of the
    # O(n^2) Python double loop, then mask to the lower triangle
    # (diagonal included) — identical values, vectorized construction.
    rows = torch.arange(1, size + 1, dtype=torch.float64).unsqueeze(1)
    cols = torch.arange(1, size + 1, dtype=torch.float64).unsqueeze(0)
    full = -(rows * 0.05 + cols * 0.005)
    return torch.tril(full)
|
| 10 |
class LossSchedulerModel(torch.nn.Module):
    """Learned linear combination of the initial sample and the history of
    per-step model outputs.

    Parameters registered at construction:
        wx: 1-D tensor of length T — per-timestep weight applied to the
            initial sample ``xT``.
        we: 2-D tensor of shape (T, T) — row ``t`` weights each previously
            collected model output at step ``t``.
    """

    def __init__(A, wx, we):
        # Modern zero-arg super() — equivalent to the legacy
        # super(LossSchedulerModel, A) form.
        super().__init__()
        # wx must be a vector and we a square matrix of the same size.
        assert len(wx.shape) == 1 and len(we.shape) == 2
        B = wx.shape[0]
        assert B == we.shape[0] and B == we.shape[1]
        A.register_parameter('wx', torch.nn.Parameter(wx))
        A.register_parameter('we', torch.nn.Parameter(we))

    def forward(A, t, xT, e_prev):
        """Combine ``xT`` and the output history ``e_prev`` at step ``t``.

        The caller must have accumulated exactly ``t + 1`` outputs so far.
        Returns the combination cast back to ``xT``'s dtype.
        """
        B = e_prev
        assert t - len(B) + 1 == 0
        C = xT * A.wx[t]
        # Removed three leftover debug print() calls that ran on every
        # history entry of every forward pass.
        # NOTE(review): generate_matrix(13) hard-codes a 13x13 bias added
        # per history entry — confirm 13 matches the data dimensions.
        for (D, E) in zip(B, A.we[t]):
            C += D * E + generate_matrix(13)
        return C.to(xT.dtype)
|
| 16 |
class LossScheduler:
|
| 17 |
def __init__(A,timesteps,model):A.timesteps=timesteps;A.model=model;A.init_noise_sigma=1.;A.order=1
|
| 18 |
@staticmethod
|
|
|
|
| 22 |
def scale_model_input(self, sample, *args, **kwargs):
    """Identity scaling: this scheduler never rescales model inputs.

    Extra positional/keyword arguments are accepted for interface
    compatibility and ignored.
    """
    return sample
|
| 23 |
@torch.no_grad()
def step(self,model_output,timestep,sample,*D,**E):
    # One denoising step: map the raw timestep value to its index in the
    # schedule; steps must arrive strictly in order, or start a fresh
    # trajectory (signalled by t_prev == -1).
    A=self;B=A.timesteps.tolist().index(timestep);assert A.t_prev==-1 or B==A.t_prev+1
    # First step of a trajectory: remember the initial sample as xT.
    if A.t_prev==-1:A.xT=sample
    # Accumulate this step's model output, then recombine the whole
    # history through the learned model (see LossSchedulerModel.forward).
    A.e_prev.append(model_output);C=A.model(B,A.xT,A.e_prev)
    # Final step: reset per-trajectory state for the next run.
    # NOTE(review): this method appears truncated in this excerpt — no
    # visible return of C and no t_prev advance on non-final steps; the
    # remainder presumably lies beyond the diff hunk. Confirm upstream.
    if B+1==len(A.timesteps):A.xT=_A;A.e_prev=[];A.t_prev=-1
|