rls
Browse files
improve_gainlora/src/cl_trainer_specroute.py
CHANGED
|
@@ -463,10 +463,13 @@ class SpecRoute_Trainer(Seq2SeqTrainer):
|
|
| 463 |
# Handle both tensor and dict types (dict keys are chunk indices)
|
| 464 |
if isinstance(mat, dict):
|
| 465 |
for key in mat:
|
|
|
|
| 466 |
if torch.isnan(mat[key]).any() or torch.isinf(mat[key]).any():
|
| 467 |
mat[key] = torch.nan_to_num(mat[key], nan=0.0)
|
| 468 |
-
|
| 469 |
-
mat =
|
|
|
|
|
|
|
| 470 |
reg_matrix.append(mat)
|
| 471 |
i += 1
|
| 472 |
if getattr(self.model.encoder, "routing_mode", "") == "learned":
|
|
@@ -500,10 +503,16 @@ class SpecRoute_Trainer(Seq2SeqTrainer):
|
|
| 500 |
if has_nan_inf:
|
| 501 |
print(f'[GPM] WARNING: {path} contains NaN/Inf. Cleaning to 0.')
|
| 502 |
for key in mat:
|
|
|
|
| 503 |
mat[key] = torch.nan_to_num(mat[key], nan=0.0, posinf=0.0, neginf=0.0)
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 507 |
reg_matrix.append(mat)
|
| 508 |
i += 1
|
| 509 |
if getattr(self.model.encoder, "routing_mode", "") == "learned":
|
|
|
|
| 463 |
# Handle both tensor and dict types (dict keys are chunk indices)
|
| 464 |
if isinstance(mat, dict):
|
| 465 |
for key in mat:
|
| 466 |
+
mat[key] = mat[key].to('cuda:0') # Move to GPU
|
| 467 |
if torch.isnan(mat[key]).any() or torch.isinf(mat[key]).any():
|
| 468 |
mat[key] = torch.nan_to_num(mat[key], nan=0.0)
|
| 469 |
+
else:
|
| 470 |
+
mat = mat.to('cuda:0') # Move to GPU
|
| 471 |
+
if torch.isnan(mat).any() or torch.isinf(mat).any():
|
| 472 |
+
mat = torch.nan_to_num(mat, nan=0.0)
|
| 473 |
reg_matrix.append(mat)
|
| 474 |
i += 1
|
| 475 |
if getattr(self.model.encoder, "routing_mode", "") == "learned":
|
|
|
|
| 503 |
if has_nan_inf:
|
| 504 |
print(f'[GPM] WARNING: {path} contains NaN/Inf. Cleaning to 0.')
|
| 505 |
for key in mat:
|
| 506 |
+
mat[key] = mat[key].to('cuda:0') # Move to GPU
|
| 507 |
mat[key] = torch.nan_to_num(mat[key], nan=0.0, posinf=0.0, neginf=0.0)
|
| 508 |
+
else:
|
| 509 |
+
for key in mat:
|
| 510 |
+
mat[key] = mat[key].to('cuda:0') # Move to GPU
|
| 511 |
+
else:
|
| 512 |
+
mat = mat.to('cuda:0') # Move to GPU
|
| 513 |
+
if torch.isnan(mat).any() or torch.isinf(mat).any():
|
| 514 |
+
print(f'[GPM] WARNING: {path} contains NaN/Inf. Cleaning to 0.')
|
| 515 |
+
mat = torch.nan_to_num(mat, nan=0.0, posinf=0.0, neginf=0.0)
|
| 516 |
reg_matrix.append(mat)
|
| 517 |
i += 1
|
| 518 |
if getattr(self.model.encoder, "routing_mode", "") == "learned":
|