convitom committed on
Commit ·
6a13626
1
Parent(s): 8623f6d
- training/train.py +0 -23
- utils/checkpoint.py +6 -0
training/train.py
CHANGED
|
@@ -716,29 +716,6 @@ def run_stage2(model, train_cfg, model_cfg, spec, out_dir, logger,
|
|
| 716 |
))
|
| 717 |
|
| 718 |
if resume_from:
|
| 719 |
-
# ── temporary diagnostic: patch bnb load_state_dict to dump
|
| 720 |
-
# both current optimizer.param_groups and saved state_dict["param_groups"]
|
| 721 |
-
# right before the size check, so we can see the actual mismatch.
|
| 722 |
-
try:
|
| 723 |
-
import bitsandbytes.optim.optimizer as _bnb_opt
|
| 724 |
-
_orig_load = _bnb_opt.Optimizer8bit.load_state_dict
|
| 725 |
-
def _patched_load(opt_self, state_dict):
|
| 726 |
-
print("\n=== [DIAG] bnb load_state_dict at resume ===")
|
| 727 |
-
print(f"CURRENT optimizer.param_groups: {len(opt_self.param_groups)}")
|
| 728 |
-
for i, g in enumerate(opt_self.param_groups):
|
| 729 |
-
shapes = [tuple(p.shape) for p in g["params"][:3]]
|
| 730 |
-
print(f" group {i}: {len(g['params'])} params, "
|
| 731 |
-
f"wd={g.get('weight_decay')}, lr={g.get('lr')}, "
|
| 732 |
-
f"first3 shapes={shapes}")
|
| 733 |
-
print(f"SAVED state_dict['param_groups']: {len(state_dict['param_groups'])}")
|
| 734 |
-
for i, g in enumerate(state_dict['param_groups']):
|
| 735 |
-
print(f" saved group {i}: {len(g['params'])} params, "
|
| 736 |
-
f"wd={g.get('weight_decay')}, lr={g.get('lr')}")
|
| 737 |
-
print("=== [DIAG] end ===\n")
|
| 738 |
-
return _orig_load(opt_self, state_dict)
|
| 739 |
-
_bnb_opt.Optimizer8bit.load_state_dict = _patched_load
|
| 740 |
-
except Exception as _e:
|
| 741 |
-
logger.warning(f"[DIAG] failed to patch bnb load_state_dict: {_e}")
|
| 742 |
trainer.train(resume_from_checkpoint=resume_from)
|
| 743 |
else:
|
| 744 |
trainer.train()
|
|
|
|
| 716 |
))
|
| 717 |
|
| 718 |
if resume_from:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 719 |
trainer.train(resume_from_checkpoint=resume_from)
|
| 720 |
else:
|
| 721 |
trainer.train()
|
utils/checkpoint.py
CHANGED
|
@@ -114,9 +114,15 @@ def load_checkpoint(
|
|
| 114 |
)
|
| 115 |
if lora_dir.exists():
|
| 116 |
from peft import PeftModel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
model.llm = PeftModel.from_pretrained(
|
| 118 |
model.llm.base_model.model,
|
| 119 |
str(lora_dir),
|
|
|
|
| 120 |
)
|
| 121 |
print(f"[Checkpoint] LoRA adapters loaded ← {lora_dir}")
|
| 122 |
else:
|
|
|
|
| 114 |
)
|
| 115 |
if lora_dir.exists():
|
| 116 |
from peft import PeftModel
|
| 117 |
+
# is_trainable=True is REQUIRED on resume: PEFT defaults to
|
| 118 |
+
# inference mode (requires_grad=False on all LoRA params), which
|
| 119 |
+
# would shrink the trainable set to projection-only (~5 tensors)
|
| 120 |
+
# and break optimizer state loading with a param-group size
|
| 121 |
+
# mismatch against the saved 261-tensor stage-2 optimizer.
|
| 122 |
model.llm = PeftModel.from_pretrained(
|
| 123 |
model.llm.base_model.model,
|
| 124 |
str(lora_dir),
|
| 125 |
+
is_trainable=True,
|
| 126 |
)
|
| 127 |
print(f"[Checkpoint] LoRA adapters loaded ← {lora_dir}")
|
| 128 |
else:
|