natmin322 commited on
Commit
7517d8c
·
1 Parent(s): 9be56eb
improve_gainlora/src/cl_trainer_specroute.py CHANGED
@@ -460,7 +460,12 @@ class SpecRoute_Trainer(Seq2SeqTrainer):
460
  path = os.path.join(last_task_path, "reg_{}.pt".format(i))
461
  if os.path.exists(path):
462
  mat = torch.load(path, map_location='cpu')
463
- if torch.isnan(mat).any() or torch.isinf(mat).any():
 
 
 
 
 
464
  mat = torch.nan_to_num(mat, nan=0.0)
465
  reg_matrix.append(mat)
466
  i += 1
@@ -489,7 +494,14 @@ class SpecRoute_Trainer(Seq2SeqTrainer):
489
  if hasattr(module, 'get_feature'):
490
  path = os.path.join(base_path, "reg_{}.pt".format(i))
491
  mat = torch.load(path, map_location='cpu')
492
- if torch.isnan(mat).any() or torch.isinf(mat).any():
 
 
 
 
 
 
 
493
  print(f'[GPM] WARNING: {path} contains NaN/Inf. Cleaning to 0.')
494
  mat = torch.nan_to_num(mat, nan=0.0, posinf=0.0, neginf=0.0)
495
  reg_matrix.append(mat)
 
460
  path = os.path.join(last_task_path, "reg_{}.pt".format(i))
461
  if os.path.exists(path):
462
  mat = torch.load(path, map_location='cpu')
463
+ # Handle both tensor and dict types (dict keys are chunk indices)
464
+ if isinstance(mat, dict):
465
+ for key in mat:
466
+ if torch.isnan(mat[key]).any() or torch.isinf(mat[key]).any():
467
+ mat[key] = torch.nan_to_num(mat[key], nan=0.0)
468
+ elif torch.isnan(mat).any() or torch.isinf(mat).any():
469
  mat = torch.nan_to_num(mat, nan=0.0)
470
  reg_matrix.append(mat)
471
  i += 1
 
494
  if hasattr(module, 'get_feature'):
495
  path = os.path.join(base_path, "reg_{}.pt".format(i))
496
  mat = torch.load(path, map_location='cpu')
497
+ # Handle both tensor and dict types (dict keys are chunk indices)
498
+ if isinstance(mat, dict):
499
+ has_nan_inf = any(torch.isnan(mat[key]).any() or torch.isinf(mat[key]).any() for key in mat)
500
+ if has_nan_inf:
501
+ print(f'[GPM] WARNING: {path} contains NaN/Inf. Cleaning to 0.')
502
+ for key in mat:
503
+ mat[key] = torch.nan_to_num(mat[key], nan=0.0, posinf=0.0, neginf=0.0)
504
+ elif torch.isnan(mat).any() or torch.isinf(mat).any():
505
  print(f'[GPM] WARNING: {path} contains NaN/Inf. Cleaning to 0.')
506
  mat = torch.nan_to_num(mat, nan=0.0, posinf=0.0, neginf=0.0)
507
  reg_matrix.append(mat)