convitom committed on
Commit
6a13626
·
1 Parent(s): 8623f6d
Files changed (2) hide show
  1. training/train.py +0 -23
  2. utils/checkpoint.py +6 -0
training/train.py CHANGED
@@ -716,29 +716,6 @@ def run_stage2(model, train_cfg, model_cfg, spec, out_dir, logger,
716
  ))
717
 
718
  if resume_from:
719
- # ── temporary diagnostic: patch bnb load_state_dict to dump
720
- # both current optimizer.param_groups and saved state_dict["param_groups"]
721
- # right before the size check, so we can see the actual mismatch.
722
- try:
723
- import bitsandbytes.optim.optimizer as _bnb_opt
724
- _orig_load = _bnb_opt.Optimizer8bit.load_state_dict
725
- def _patched_load(opt_self, state_dict):
726
- print("\n=== [DIAG] bnb load_state_dict at resume ===")
727
- print(f"CURRENT optimizer.param_groups: {len(opt_self.param_groups)}")
728
- for i, g in enumerate(opt_self.param_groups):
729
- shapes = [tuple(p.shape) for p in g["params"][:3]]
730
- print(f" group {i}: {len(g['params'])} params, "
731
- f"wd={g.get('weight_decay')}, lr={g.get('lr')}, "
732
- f"first3 shapes={shapes}")
733
- print(f"SAVED state_dict['param_groups']: {len(state_dict['param_groups'])}")
734
- for i, g in enumerate(state_dict['param_groups']):
735
- print(f" saved group {i}: {len(g['params'])} params, "
736
- f"wd={g.get('weight_decay')}, lr={g.get('lr')}")
737
- print("=== [DIAG] end ===\n")
738
- return _orig_load(opt_self, state_dict)
739
- _bnb_opt.Optimizer8bit.load_state_dict = _patched_load
740
- except Exception as _e:
741
- logger.warning(f"[DIAG] failed to patch bnb load_state_dict: {_e}")
742
  trainer.train(resume_from_checkpoint=resume_from)
743
  else:
744
  trainer.train()
 
716
  ))
717
 
718
  if resume_from:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
719
  trainer.train(resume_from_checkpoint=resume_from)
720
  else:
721
  trainer.train()
utils/checkpoint.py CHANGED
@@ -114,9 +114,15 @@ def load_checkpoint(
114
  )
115
  if lora_dir.exists():
116
  from peft import PeftModel
 
 
 
 
 
117
  model.llm = PeftModel.from_pretrained(
118
  model.llm.base_model.model,
119
  str(lora_dir),
 
120
  )
121
  print(f"[Checkpoint] LoRA adapters loaded ← {lora_dir}")
122
  else:
 
114
  )
115
  if lora_dir.exists():
116
  from peft import PeftModel
117
+ # is_trainable=True is REQUIRED on resume: PEFT defaults to
118
+ # inference mode (requires_grad=False on all LoRA params), which
119
+ # would shrink the trainable set to projection-only (~5 tensors)
120
+ # and break optimizer state loading with a param-group size
121
+ # mismatch against the saved 261-tensor stage-2 optimizer.
122
  model.llm = PeftModel.from_pretrained(
123
  model.llm.base_model.model,
124
  str(lora_dir),
125
+ is_trainable=True,
126
  )
127
  print(f"[Checkpoint] LoRA adapters loaded ← {lora_dir}")
128
  else: