Update src/loss.py
Browse files- src/loss.py +1 -1
src/loss.py
CHANGED
|
@@ -41,5 +41,5 @@ class SchedulerWrapper:
|
|
| 41 |
A=sorted([A for A in C.catch_x],reverse=True);B,D=[],[]
|
| 42 |
for E in A:F=torch.cat(C.catch_x[E],dim=0);B.append(F);G=torch.cat(C.catch_e[E],dim=0);D.append(G)
|
| 43 |
H=A[-1];I=torch.cat(C.catch_x_[H],dim=0);B.append(I);A=torch.tensor(A,dtype=torch.int32);B=torch.stack(B);D=torch.stack(D);return A,B,D
|
| 44 |
-
def load_loss_params(self):
    """Load serialized loss-scheduler parameters and build the model/scheduler pair.

    Reads a 3-tuple from ``self.loss_params_path`` (always onto CPU), then
    stores the resulting objects on ``self.loss_model`` and
    ``self.loss_scheduler``.

    NOTE(review): the exact semantics of the three loaded values are not
    visible here — presumably (scheduler config, model arg 1, model arg 2);
    confirm against where the checkpoint is written.
    """
    sched_params, model_arg_a, model_arg_b = torch.load(
        self.loss_params_path, map_location='cpu'
    )
    self.loss_model = LossSchedulerModel(model_arg_a, model_arg_b)
    self.loss_scheduler = LossScheduler(sched_params, self.loss_model)
|
| 45 |
def prepare_loss(self, num_accelerate_steps=15):
    """Prepare the loss machinery by loading its parameters from disk.

    NOTE(review): ``num_accelerate_steps`` is accepted but never used here —
    kept for interface compatibility; confirm whether it should influence
    loading or is consumed by a caller.
    """
    self.load_loss_params()
|
|
|
|
| 41 |
A=sorted([A for A in C.catch_x],reverse=True);B,D=[],[]
|
| 42 |
for E in A:F=torch.cat(C.catch_x[E],dim=0);B.append(F);G=torch.cat(C.catch_e[E],dim=0);D.append(G)
|
| 43 |
H=A[-1];I=torch.cat(C.catch_x_[H],dim=0);B.append(I);A=torch.tensor(A,dtype=torch.int32);B=torch.stack(B);D=torch.stack(D);return A,B,D
|
| 44 |
+
def load_loss_params(self):
    """Load serialized loss-scheduler parameters and build the model/scheduler pair.

    Reads a 3-tuple from ``self.loss_params_path`` (always onto CPU), stores
    the resulting objects on ``self.loss_model`` and ``self.loss_scheduler``,
    and reports the path that was loaded.

    NOTE(review): the exact semantics of the three loaded values are not
    visible here — presumably (scheduler config, model arg 1, model arg 2);
    confirm against where the checkpoint is written.
    """
    sched_params, model_arg_a, model_arg_b = torch.load(
        self.loss_params_path, map_location='cpu'
    )
    self.loss_model = LossSchedulerModel(model_arg_a, model_arg_b)
    self.loss_scheduler = LossScheduler(sched_params, self.loss_model)
    print("loaded ", self.loss_params_path)
|
| 45 |
def prepare_loss(self, num_accelerate_steps=15):
    """Prepare the loss machinery by loading its parameters from disk.

    NOTE(review): ``num_accelerate_steps`` is accepted but never used here —
    kept for interface compatibility; confirm whether it should influence
    loading or is consumed by a caller.
    """
    self.load_loss_params()
|