rls
Browse files
improve_gainlora/src/cl_trainer_specroute.py
CHANGED
|
@@ -460,7 +460,12 @@ class SpecRoute_Trainer(Seq2SeqTrainer):
|
|
| 460 |
path = os.path.join(last_task_path, "reg_{}.pt".format(i))
|
| 461 |
if os.path.exists(path):
|
| 462 |
mat = torch.load(path, map_location='cpu')
|
| 463 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 464 |
mat = torch.nan_to_num(mat, nan=0.0)
|
| 465 |
reg_matrix.append(mat)
|
| 466 |
i += 1
|
|
@@ -489,7 +494,14 @@ class SpecRoute_Trainer(Seq2SeqTrainer):
|
|
| 489 |
if hasattr(module, 'get_feature'):
|
| 490 |
path = os.path.join(base_path, "reg_{}.pt".format(i))
|
| 491 |
mat = torch.load(path, map_location='cpu')
|
| 492 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 493 |
print(f'[GPM] WARNING: {path} contains NaN/Inf. Cleaning to 0.')
|
| 494 |
mat = torch.nan_to_num(mat, nan=0.0, posinf=0.0, neginf=0.0)
|
| 495 |
reg_matrix.append(mat)
|
|
|
|
| 460 |
path = os.path.join(last_task_path, "reg_{}.pt".format(i))
|
| 461 |
if os.path.exists(path):
|
| 462 |
mat = torch.load(path, map_location='cpu')
|
| 463 |
+
# Handle both tensor and dict types (dict keys are chunk indices)
|
| 464 |
+
if isinstance(mat, dict):
|
| 465 |
+
for key in mat:
|
| 466 |
+
if torch.isnan(mat[key]).any() or torch.isinf(mat[key]).any():
|
| 467 |
+
mat[key] = torch.nan_to_num(mat[key], nan=0.0)
|
| 468 |
+
elif torch.isnan(mat).any() or torch.isinf(mat).any():
|
| 469 |
mat = torch.nan_to_num(mat, nan=0.0)
|
| 470 |
reg_matrix.append(mat)
|
| 471 |
i += 1
|
|
|
|
| 494 |
if hasattr(module, 'get_feature'):
|
| 495 |
path = os.path.join(base_path, "reg_{}.pt".format(i))
|
| 496 |
mat = torch.load(path, map_location='cpu')
|
| 497 |
+
# Handle both tensor and dict types (dict keys are chunk indices)
|
| 498 |
+
if isinstance(mat, dict):
|
| 499 |
+
has_nan_inf = any(torch.isnan(mat[key]).any() or torch.isinf(mat[key]).any() for key in mat)
|
| 500 |
+
if has_nan_inf:
|
| 501 |
+
print(f'[GPM] WARNING: {path} contains NaN/Inf. Cleaning to 0.')
|
| 502 |
+
for key in mat:
|
| 503 |
+
mat[key] = torch.nan_to_num(mat[key], nan=0.0, posinf=0.0, neginf=0.0)
|
| 504 |
+
elif torch.isnan(mat).any() or torch.isinf(mat).any():
|
| 505 |
print(f'[GPM] WARNING: {path} contains NaN/Inf. Cleaning to 0.')
|
| 506 |
mat = torch.nan_to_num(mat, nan=0.0, posinf=0.0, neginf=0.0)
|
| 507 |
reg_matrix.append(mat)
|