natmin322 commited on
Commit
f666767
·
1 Parent(s): 7517d8c
improve_gainlora/src/cl_trainer_specroute.py CHANGED
@@ -463,10 +463,13 @@ class SpecRoute_Trainer(Seq2SeqTrainer):
463
  # Handle both tensor and dict types (dict keys are chunk indices)
464
  if isinstance(mat, dict):
465
  for key in mat:
 
466
  if torch.isnan(mat[key]).any() or torch.isinf(mat[key]).any():
467
  mat[key] = torch.nan_to_num(mat[key], nan=0.0)
468
- elif torch.isnan(mat).any() or torch.isinf(mat).any():
469
- mat = torch.nan_to_num(mat, nan=0.0)
 
 
470
  reg_matrix.append(mat)
471
  i += 1
472
  if getattr(self.model.encoder, "routing_mode", "") == "learned":
@@ -500,10 +503,16 @@ class SpecRoute_Trainer(Seq2SeqTrainer):
500
  if has_nan_inf:
501
  print(f'[GPM] WARNING: {path} contains NaN/Inf. Cleaning to 0.')
502
  for key in mat:
 
503
  mat[key] = torch.nan_to_num(mat[key], nan=0.0, posinf=0.0, neginf=0.0)
504
- elif torch.isnan(mat).any() or torch.isinf(mat).any():
505
- print(f'[GPM] WARNING: {path} contains NaN/Inf. Cleaning to 0.')
506
- mat = torch.nan_to_num(mat, nan=0.0, posinf=0.0, neginf=0.0)
 
 
 
 
 
507
  reg_matrix.append(mat)
508
  i += 1
509
  if getattr(self.model.encoder, "routing_mode", "") == "learned":
 
463
  # Handle both tensor and dict types (dict keys are chunk indices)
464
  if isinstance(mat, dict):
465
  for key in mat:
466
+ mat[key] = mat[key].to('cuda:0') # Move to GPU
467
  if torch.isnan(mat[key]).any() or torch.isinf(mat[key]).any():
468
  mat[key] = torch.nan_to_num(mat[key], nan=0.0)
469
+ else:
470
+ mat = mat.to('cuda:0') # Move to GPU
471
+ if torch.isnan(mat).any() or torch.isinf(mat).any():
472
+ mat = torch.nan_to_num(mat, nan=0.0)
473
  reg_matrix.append(mat)
474
  i += 1
475
  if getattr(self.model.encoder, "routing_mode", "") == "learned":
 
503
  if has_nan_inf:
504
  print(f'[GPM] WARNING: {path} contains NaN/Inf. Cleaning to 0.')
505
  for key in mat:
506
+ mat[key] = mat[key].to('cuda:0') # Move to GPU
507
  mat[key] = torch.nan_to_num(mat[key], nan=0.0, posinf=0.0, neginf=0.0)
508
+ else:
509
+ for key in mat:
510
+ mat[key] = mat[key].to('cuda:0') # Move to GPU
511
+ else:
512
+ mat = mat.to('cuda:0') # Move to GPU
513
+ if torch.isnan(mat).any() or torch.isinf(mat).any():
514
+ print(f'[GPM] WARNING: {path} contains NaN/Inf. Cleaning to 0.')
515
+ mat = torch.nan_to_num(mat, nan=0.0, posinf=0.0, neginf=0.0)
516
  reg_matrix.append(mat)
517
  i += 1
518
  if getattr(self.model.encoder, "routing_mode", "") == "learned":