fix: LoRA reinit, gradient checkpointing use_reentrant=False, checkpoint existence check, collator padding alignment
Browse files- root_gainlora/src/cl_collator.py +11 -2
- root_gainlora/src/run_t5.py +29 -11
root_gainlora/src/cl_collator.py
CHANGED
|
@@ -112,9 +112,18 @@ class DataCollator:
|
|
| 112 |
if self.text_only:
|
| 113 |
model_inputs = {"inputs": sources, "labels": labels}
|
| 114 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
model_inputs = self.tokenizer(
|
| 116 |
sources,
|
| 117 |
-
max_length=
|
| 118 |
padding=self.padding,
|
| 119 |
return_tensors=return_tensors,
|
| 120 |
truncation=True,
|
|
@@ -123,7 +132,7 @@ class DataCollator:
|
|
| 123 |
with self.tokenizer.as_target_tokenizer():
|
| 124 |
labels = self.tokenizer(
|
| 125 |
labels,
|
| 126 |
-
max_length=
|
| 127 |
padding=self.padding,
|
| 128 |
return_tensors=return_tensors,
|
| 129 |
truncation=True,
|
|
|
|
| 112 |
if self.text_only:
|
| 113 |
model_inputs = {"inputs": sources, "labels": labels}
|
| 114 |
else:
|
| 115 |
+
# Ensure max_length is compatible with pad_to_multiple_of
|
| 116 |
+
_pad_mult = self.pad_to_multiple_of
|
| 117 |
+
_src_len = self.max_source_length
|
| 118 |
+
_tgt_len = self.max_target_length
|
| 119 |
+
if _pad_mult and _src_len and _src_len % _pad_mult != 0:
|
| 120 |
+
_src_len = ((_src_len + _pad_mult - 1) // _pad_mult) * _pad_mult
|
| 121 |
+
if _pad_mult and _tgt_len and _tgt_len % _pad_mult != 0:
|
| 122 |
+
_tgt_len = ((_tgt_len + _pad_mult - 1) // _pad_mult) * _pad_mult
|
| 123 |
+
|
| 124 |
model_inputs = self.tokenizer(
|
| 125 |
sources,
|
| 126 |
+
max_length=_src_len,
|
| 127 |
padding=self.padding,
|
| 128 |
return_tensors=return_tensors,
|
| 129 |
truncation=True,
|
|
|
|
| 132 |
with self.tokenizer.as_target_tokenizer():
|
| 133 |
labels = self.tokenizer(
|
| 134 |
labels,
|
| 135 |
+
max_length=_tgt_len,
|
| 136 |
padding=self.padding,
|
| 137 |
return_tensors=return_tensors,
|
| 138 |
truncation=True,
|
root_gainlora/src/run_t5.py
CHANGED
|
@@ -487,22 +487,36 @@ def main():
|
|
| 487 |
model.persent = training_args.persent
|
| 488 |
model.resize_token_embeddings(len(tokenizer))
|
| 489 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 490 |
try:
|
| 491 |
local_rank = int(os.environ['LOCAL_RANK'])
|
| 492 |
device = torch.device(f"cuda:{local_rank}")
|
| 493 |
except:
|
| 494 |
device = torch.device(f"cuda:0")
|
| 495 |
if model_args.load_checkpoint_from:
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
model.encoder.
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
model.encoder.previous_trans_input.
|
| 504 |
-
model.encoder.previous_trans_input.
|
| 505 |
-
|
|
|
|
|
|
|
|
|
|
| 506 |
|
| 507 |
if model_args.previous_lora_path:
|
| 508 |
previous_lora_list = model_args.previous_lora_path.split(',')
|
|
@@ -707,7 +721,11 @@ def main():
|
|
| 707 |
return result
|
| 708 |
print(f"-----Gradient checkpointing: {training_args.gradient_checkpointing} -----")
|
| 709 |
if training_args.gradient_checkpointing:
|
| 710 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 711 |
|
| 712 |
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
| 713 |
training_args.step_per_epoch = math.ceil(len(raw_datasets["train"]) / training_args.per_device_train_batch_size / world_size / training_args.gradient_accumulation_steps)
|
|
|
|
| 487 |
model.persent = training_args.persent
|
| 488 |
model.resize_token_embeddings(len(tokenizer))
|
| 489 |
|
| 490 |
+
# FIX: from_pretrained wraps model construction in no_init_weights() context,
|
| 491 |
+
# which replaces nn.init.kaiming_uniform_ with a no-op. This leaves lora_A
|
| 492 |
+
# as all zeros (from torch.zeros in constructor), making LoRA output = 0
|
| 493 |
+
# and all lora_B gradients = 0. Re-initialize lora_A here.
|
| 494 |
+
_n_reinit = 0
|
| 495 |
+
for _module in model.modules():
|
| 496 |
+
if hasattr(_module, 'lora_A') and hasattr(_module, 'lora_B') and hasattr(_module, 'reset_parameters'):
|
| 497 |
+
nn.init.kaiming_uniform_(_module.lora_A, a=math.sqrt(5))
|
| 498 |
+
_n_reinit += 1
|
| 499 |
+
print(f"[FIX] Re-initialized lora_A in {_n_reinit} LoRA layers with kaiming_uniform_")
|
| 500 |
+
|
| 501 |
try:
|
| 502 |
local_rank = int(os.environ['LOCAL_RANK'])
|
| 503 |
device = torch.device(f"cuda:{local_rank}")
|
| 504 |
except:
|
| 505 |
device = torch.device(f"cuda:0")
|
| 506 |
if model_args.load_checkpoint_from:
|
| 507 |
+
if not os.path.exists(model_args.load_checkpoint_from):
|
| 508 |
+
logger.warning(f"load_checkpoint_from not found: {model_args.load_checkpoint_from}, skipping load")
|
| 509 |
+
else:
|
| 510 |
+
print("----------Loading Previous Query Projection Layer----------")
|
| 511 |
+
model.encoder.trans_input.load_state_dict(torch.load(model_args.load_checkpoint_from, map_location=device))
|
| 512 |
+
if training_args.model_name in ['gainlora_inflora', 'gainlora_olora']:
|
| 513 |
+
model.encoder.previous_trans_input.input_linear[0].data.copy_(torch.load(model_args.load_checkpoint_from, map_location=device)['0.weight'])
|
| 514 |
+
model.encoder.previous_trans_input.output_linear[0].data.copy_(torch.load(model_args.load_checkpoint_from, map_location=device)['2.weight'])
|
| 515 |
+
model.encoder.previous_trans_input.state_dict()
|
| 516 |
+
if cur_task_id > 1:
|
| 517 |
+
model.encoder.previous_trans_input.input_linear[1:].data.copy_(torch.load(model_args.load_checkpoint_from.replace('trans_input.pt', 'previous_trans_input.pt'), map_location=device)['input_linear'])
|
| 518 |
+
model.encoder.previous_trans_input.output_linear[1:].data.copy_(torch.load(model_args.load_checkpoint_from.replace('trans_input.pt', 'previous_trans_input.pt'), map_location=device)['output_linear'])
|
| 519 |
+
print("----------Loading Previous Query Projection Layer Done----------")
|
| 520 |
|
| 521 |
if model_args.previous_lora_path:
|
| 522 |
previous_lora_list = model_args.previous_lora_path.split(',')
|
|
|
|
| 721 |
return result
|
| 722 |
print(f"-----Gradient checkpointing: {training_args.gradient_checkpointing} -----")
|
| 723 |
if training_args.gradient_checkpointing:
|
| 724 |
+
# use_reentrant=False: don't require input requires_grad=True
|
| 725 |
+
# Recommended by PyTorch 2.5+ (will be mandatory in future versions)
|
| 726 |
+
model.gradient_checkpointing_enable(
|
| 727 |
+
gradient_checkpointing_kwargs={"use_reentrant": False}
|
| 728 |
+
)
|
| 729 |
|
| 730 |
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
| 731 |
training_args.step_per_epoch = math.ceil(len(raw_datasets["train"]) / training_args.per_device_train_batch_size / world_size / training_args.gradient_accumulation_steps)
|