natmin322 committed on
Commit
5299479
·
1 Parent(s): 6a339c3

fix: LoRA reinit, gradient checkpointing use_reentrant=False, checkpoint existence check, collator padding alignment

Browse files
root_gainlora/src/cl_collator.py CHANGED
@@ -112,9 +112,18 @@ class DataCollator:
112
  if self.text_only:
113
  model_inputs = {"inputs": sources, "labels": labels}
114
  else:
 
 
 
 
 
 
 
 
 
115
  model_inputs = self.tokenizer(
116
  sources,
117
- max_length=self.max_source_length,
118
  padding=self.padding,
119
  return_tensors=return_tensors,
120
  truncation=True,
@@ -123,7 +132,7 @@ class DataCollator:
123
  with self.tokenizer.as_target_tokenizer():
124
  labels = self.tokenizer(
125
  labels,
126
- max_length=self.max_target_length,
127
  padding=self.padding,
128
  return_tensors=return_tensors,
129
  truncation=True,
 
112
  if self.text_only:
113
  model_inputs = {"inputs": sources, "labels": labels}
114
  else:
115
+ # Ensure max_length is compatible with pad_to_multiple_of
116
+ _pad_mult = self.pad_to_multiple_of
117
+ _src_len = self.max_source_length
118
+ _tgt_len = self.max_target_length
119
+ if _pad_mult and _src_len and _src_len % _pad_mult != 0:
120
+ _src_len = ((_src_len + _pad_mult - 1) // _pad_mult) * _pad_mult
121
+ if _pad_mult and _tgt_len and _tgt_len % _pad_mult != 0:
122
+ _tgt_len = ((_tgt_len + _pad_mult - 1) // _pad_mult) * _pad_mult
123
+
124
  model_inputs = self.tokenizer(
125
  sources,
126
+ max_length=_src_len,
127
  padding=self.padding,
128
  return_tensors=return_tensors,
129
  truncation=True,
 
132
  with self.tokenizer.as_target_tokenizer():
133
  labels = self.tokenizer(
134
  labels,
135
+ max_length=_tgt_len,
136
  padding=self.padding,
137
  return_tensors=return_tensors,
138
  truncation=True,
root_gainlora/src/run_t5.py CHANGED
@@ -487,22 +487,36 @@ def main():
487
  model.persent = training_args.persent
488
  model.resize_token_embeddings(len(tokenizer))
489
 
 
 
 
 
 
 
 
 
 
 
 
490
  try:
491
  local_rank = int(os.environ['LOCAL_RANK'])
492
  device = torch.device(f"cuda:{local_rank}")
493
  except:
494
  device = torch.device(f"cuda:0")
495
  if model_args.load_checkpoint_from:
496
- print("----------Loading Previous Query Projection Layer----------")
497
- model.encoder.trans_input.load_state_dict(torch.load(model_args.load_checkpoint_from, map_location=device))
498
- if training_args.model_name in ['gainlora_inflora', 'gainlora_olora']:
499
- model.encoder.previous_trans_input.input_linear[0].data.copy_(torch.load(model_args.load_checkpoint_from, map_location=device)['0.weight'])
500
- model.encoder.previous_trans_input.output_linear[0].data.copy_(torch.load(model_args.load_checkpoint_from, map_location=device)['2.weight'])
501
- model.encoder.previous_trans_input.state_dict()
502
- if cur_task_id > 1:
503
- model.encoder.previous_trans_input.input_linear[1:].data.copy_(torch.load(model_args.load_checkpoint_from.replace('trans_input.pt', 'previous_trans_input.pt'), map_location=device)['input_linear'])
504
- model.encoder.previous_trans_input.output_linear[1:].data.copy_(torch.load(model_args.load_checkpoint_from.replace('trans_input.pt', 'previous_trans_input.pt'), map_location=device)['output_linear'])
505
- print("----------Loading Previous Query Projection Layer Done----------")
 
 
 
506
 
507
  if model_args.previous_lora_path:
508
  previous_lora_list = model_args.previous_lora_path.split(',')
@@ -707,7 +721,11 @@ def main():
707
  return result
708
  print(f"-----Gradient checkpointing: {training_args.gradient_checkpointing} -----")
709
  if training_args.gradient_checkpointing:
710
- model.gradient_checkpointing_enable()
 
 
 
 
711
 
712
  world_size = int(os.environ.get("WORLD_SIZE", 1))
713
  training_args.step_per_epoch = math.ceil(len(raw_datasets["train"]) / training_args.per_device_train_batch_size / world_size / training_args.gradient_accumulation_steps)
 
487
  model.persent = training_args.persent
488
  model.resize_token_embeddings(len(tokenizer))
489
 
490
+ # FIX: from_pretrained wraps model construction in no_init_weights() context,
491
+ # which replaces nn.init.kaiming_uniform_ with a no-op. This leaves lora_A
492
+ # as all zeros (from torch.zeros in constructor), making LoRA output = 0
493
+ # and all lora_B gradients = 0. Re-initialize lora_A here.
494
+ _n_reinit = 0
495
+ for _module in model.modules():
496
+ if hasattr(_module, 'lora_A') and hasattr(_module, 'lora_B') and hasattr(_module, 'reset_parameters'):
497
+ nn.init.kaiming_uniform_(_module.lora_A, a=math.sqrt(5))
498
+ _n_reinit += 1
499
+ print(f"[FIX] Re-initialized lora_A in {_n_reinit} LoRA layers with kaiming_uniform_")
500
+
501
  try:
502
  local_rank = int(os.environ['LOCAL_RANK'])
503
  device = torch.device(f"cuda:{local_rank}")
504
  except:
505
  device = torch.device(f"cuda:0")
506
  if model_args.load_checkpoint_from:
507
+ if not os.path.exists(model_args.load_checkpoint_from):
508
+ logger.warning(f"load_checkpoint_from not found: {model_args.load_checkpoint_from}, skipping load")
509
+ else:
510
+ print("----------Loading Previous Query Projection Layer----------")
511
+ model.encoder.trans_input.load_state_dict(torch.load(model_args.load_checkpoint_from, map_location=device))
512
+ if training_args.model_name in ['gainlora_inflora', 'gainlora_olora']:
513
+ model.encoder.previous_trans_input.input_linear[0].data.copy_(torch.load(model_args.load_checkpoint_from, map_location=device)['0.weight'])
514
+ model.encoder.previous_trans_input.output_linear[0].data.copy_(torch.load(model_args.load_checkpoint_from, map_location=device)['2.weight'])
515
+ model.encoder.previous_trans_input.state_dict()
516
+ if cur_task_id > 1:
517
+ model.encoder.previous_trans_input.input_linear[1:].data.copy_(torch.load(model_args.load_checkpoint_from.replace('trans_input.pt', 'previous_trans_input.pt'), map_location=device)['input_linear'])
518
+ model.encoder.previous_trans_input.output_linear[1:].data.copy_(torch.load(model_args.load_checkpoint_from.replace('trans_input.pt', 'previous_trans_input.pt'), map_location=device)['output_linear'])
519
+ print("----------Loading Previous Query Projection Layer Done----------")
520
 
521
  if model_args.previous_lora_path:
522
  previous_lora_list = model_args.previous_lora_path.split(',')
 
721
  return result
722
  print(f"-----Gradient checkpointing: {training_args.gradient_checkpointing} -----")
723
  if training_args.gradient_checkpointing:
724
+ # use_reentrant=False: don't require input requires_grad=True
725
+ # Recommended by PyTorch (passing use_reentrant explicitly will be required in future versions)
726
+ model.gradient_checkpointing_enable(
727
+ gradient_checkpointing_kwargs={"use_reentrant": False}
728
+ )
729
 
730
  world_size = int(os.environ.get("WORLD_SIZE", 1))
731
  training_args.step_per_epoch = math.ceil(len(raw_datasets["train"]) / training_args.per_device_train_batch_size / world_size / training_args.gradient_accumulation_steps)