manbeast3b commited on
Commit
a63adad
·
verified ·
1 Parent(s): bb79c35

Update src/loss.py

Browse files
Files changed (1) hide show
  1. src/loss.py +1 -1
src/loss.py CHANGED
@@ -41,5 +41,5 @@ class SchedulerWrapper:
41
  A=sorted([A for A in C.catch_x],reverse=True);B,D=[],[]
42
  for E in A:F=torch.cat(C.catch_x[E],dim=0);B.append(F);G=torch.cat(C.catch_e[E],dim=0);D.append(G)
43
  H=A[-1];I=torch.cat(C.catch_x_[H],dim=0);B.append(I);A=torch.tensor(A,dtype=torch.int32);B=torch.stack(B);D=torch.stack(D);return A,B,D
44
- def load_loss_params(A):B,C,D=torch.load(A.loss_params_path,map_location='cpu');A.loss_model=LossSchedulerModel(C,D);A.loss_scheduler=LossScheduler(B,A.loss_model)
45
  def prepare_loss(A,num_accelerate_steps=15):A.load_loss_params()
 
41
  A=sorted([A for A in C.catch_x],reverse=True);B,D=[],[]
42
  for E in A:F=torch.cat(C.catch_x[E],dim=0);B.append(F);G=torch.cat(C.catch_e[E],dim=0);D.append(G)
43
  H=A[-1];I=torch.cat(C.catch_x_[H],dim=0);B.append(I);A=torch.tensor(A,dtype=torch.int32);B=torch.stack(B);D=torch.stack(D);return A,B,D
44
+ def load_loss_params(A):B,C,D=torch.load(A.loss_params_path,map_location='cpu');A.loss_model=LossSchedulerModel(C,D);A.loss_scheduler=LossScheduler(B,A.loss_model);print("loaded ",A.loss_params_path)
45
  def prepare_loss(A,num_accelerate_steps=15):A.load_loss_params()