Faaz committed on
Commit
34d0e72
·
1 Parent(s): 6e8f91c

Fix gradient checkpointing: enable_input_require_grads for PEFT without torch.compile

Browse files
Files changed (1) hide show
  1. src/training/mindi_trainer.py +7 -0
src/training/mindi_trainer.py CHANGED
@@ -302,6 +302,13 @@ class MINDITrainer:
302
  if hasattr(base_model, "gradient_checkpointing_enable"):
303
  base_model.gradient_checkpointing_enable()
304
  print("[MINDITrainer] Gradient checkpointing enabled")
 
 
 
 
 
 
 
305
 
306
  # Optional torch.compile (works on ROCm)
307
  if config.use_compile:
 
302
  if hasattr(base_model, "gradient_checkpointing_enable"):
303
  base_model.gradient_checkpointing_enable()
304
  print("[MINDITrainer] Gradient checkpointing enabled")
305
+ # Required for PEFT/LoRA + gradient checkpointing without torch.compile
306
+ if hasattr(self.model.llm, "enable_input_require_grads"):
307
+ self.model.llm.enable_input_require_grads()
308
+ else:
309
+ def _make_inputs_require_grad(module, input, output):
310
+ output.requires_grad_(True)
311
+ self.model.llm.get_input_embeddings().register_forward_hook(_make_inputs_require_grad)
312
 
313
  # Optional torch.compile (works on ROCm)
314
  if config.use_compile: