manbeast3b committed on
Commit
ece2e8f
·
verified ·
1 Parent(s): 428c182

Update src/loss.py

Browse files
Files changed (1) hide show
  1. src/loss.py +11 -11
src/loss.py CHANGED
@@ -2,17 +2,17 @@ _A=None
2
  import torch
3
  from tqdm import tqdm
4
  def generate_matrix(size):
5
- matrix = torch.zeros((size, size), dtype=torch.float64)
6
- for i in range(size):
7
- for j in range(i + 1): # Ensure a triangular structure with non-zero values up to the diagonal
8
- matrix[i, j] = -((i + 1) * 0.05 + (j + 1) * 0.005)
9
- return matrix
10
  class LossSchedulerModel(torch.nn.Module):
11
- def __init__(A,wx,we):super(LossSchedulerModel,A).__init__();assert len(wx.shape)==1 and len(we.shape)==2;B=wx.shape[0];assert B==we.shape[0]and B==we.shape[1];A.register_parameter('wx',torch.nn.Parameter(wx));A.register_parameter('we',torch.nn.Parameter(we))
12
- def forward(A,t,xT,e_prev):
13
- B=e_prev;assert t-len(B)+1==0;C=xT*A.wx[t]
14
- for(D,E)in zip(B,A.we[t]):print(D.shape);print(E.shape);print(C.shape);C+=D*E+(generate_matrix(13))
15
- return C.to(xT.dtype)
16
  class LossScheduler:
17
  def __init__(A,timesteps,model):A.timesteps=timesteps;A.model=model;A.init_noise_sigma=1.;A.order=1
18
  @staticmethod
@@ -22,7 +22,7 @@ class LossScheduler:
22
  def scale_model_input(A,sample,*B,**C):return sample
23
  @torch.no_grad()
24
  def step(self,model_output,timestep,sample,*D,**E):
25
- A=self;B=A.timesteps.tolist().index(timestep);assert A.t_prev==-1 or B==A.t_prev+1
26
  if A.t_prev==-1:A.xT=sample
27
  A.e_prev.append(model_output);C=A.model(B,A.xT,A.e_prev)
28
  if B+1==len(A.timesteps):A.xT=_A;A.e_prev=[];A.t_prev=-1
 
2
  import torch
3
  from tqdm import tqdm
4
def generate_matrix(size):
    """Return a (size, size) float64 lower-triangular penalty matrix.

    Entry (i, j) with j <= i equals -((i + 1) * 0.05 + (j + 1) * 0.005);
    entries strictly above the diagonal are zero.

    Vectorized replacement for the original Python double loop: same
    values, computed in one C-level pass.
    """
    idx = torch.arange(1, size + 1, dtype=torch.float64)
    # Outer sum of the row penalty (0.05 * i) and column penalty (0.005 * j).
    dense = -(idx.unsqueeze(1) * 0.05 + idx.unsqueeze(0) * 0.005)
    # Keep only the lower triangle (diagonal included); the rest stays 0.
    return torch.tril(dense)
10
class LossSchedulerModel(torch.nn.Module):
    """Learned per-timestep mixer of the initial latent and past noise terms.

    Parameters
    ----------
    wx : 1-D tensor, one weight per timestep applied to the initial latent.
    we : 2-D square tensor (same side length as ``wx``); row ``t`` holds the
         weights applied to each previously predicted noise tensor.
    """

    def __init__(self, wx, we):
        super().__init__()
        # Shape contract: wx is (T,), we is (T, T) with a matching T.
        # (Asserts kept — callers may rely on AssertionError here.)
        assert len(wx.shape) == 1 and len(we.shape) == 2
        steps = wx.shape[0]
        assert steps == we.shape[0] and steps == we.shape[1]
        self.register_parameter('wx', torch.nn.Parameter(wx))
        self.register_parameter('we', torch.nn.Parameter(we))

    def forward(self, t, xT, e_prev):
        """Combine ``xT`` with the ``t + 1`` accumulated noise predictions.

        Returns the weighted sum cast back to ``xT``'s dtype.
        """
        # Exactly one noise prediction per elapsed timestep is required.
        assert t - len(e_prev) + 1 == 0
        out = xT * self.wx[t]
        # Hoisted out of the loop: the bias matrix is identical on every
        # iteration. NOTE(review): the size 13 is hard-coded; the addition
        # only works if the tensors broadcast against 13x13 — confirm upstream.
        bias = generate_matrix(13)
        # Debug print() calls from the original were removed; the returned
        # value is unchanged.
        for noise, weight in zip(e_prev, self.we[t]):
            out += noise * weight + bias
        return out.to(xT.dtype)
16
  class LossScheduler:
17
  def __init__(A,timesteps,model):A.timesteps=timesteps;A.model=model;A.init_noise_sigma=1.;A.order=1
18
  @staticmethod
 
22
  def scale_model_input(A,sample,*B,**C):return sample
23
  @torch.no_grad()
24
  def step(self,model_output,timestep,sample,*D,**E):
25
+ A=self;B=A.timesteps.tolist().index(timestep);assert A.t_prev==-1 or B==A.t_prev+1
26
  if A.t_prev==-1:A.xT=sample
27
  A.e_prev.append(model_output);C=A.model(B,A.xT,A.e_prev)
28
  if B+1==len(A.timesteps):A.xT=_A;A.e_prev=[];A.t_prev=-1