Faaz committed on
Commit ·
34d0e72
1
Parent(s): 6e8f91c
Fix gradient checkpointing: enable_input_require_grads for PEFT without torch.compile
Browse files
src/training/mindi_trainer.py
CHANGED
|
@@ -302,6 +302,13 @@ class MINDITrainer:
|
|
| 302 |
if hasattr(base_model, "gradient_checkpointing_enable"):
|
| 303 |
base_model.gradient_checkpointing_enable()
|
| 304 |
print("[MINDITrainer] Gradient checkpointing enabled")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 305 |
|
| 306 |
# Optional torch.compile (works on ROCm)
|
| 307 |
if config.use_compile:
|
|
|
|
| 302 |
if hasattr(base_model, "gradient_checkpointing_enable"):
|
| 303 |
base_model.gradient_checkpointing_enable()
|
| 304 |
print("[MINDITrainer] Gradient checkpointing enabled")
|
| 305 |
+
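# Required for PEFT/LoRA + gradient checkpointing without torch.compile: with the base weights frozen, the embedding output must require grad so gradients can flow back through the checkpointed layers to the adapter weights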
|
| 306 |
+
if hasattr(self.model.llm, "enable_input_require_grads"):
|
| 307 |
+
self.model.llm.enable_input_require_grads()
|
| 308 |
+
else:
|
| 309 |
+
def _make_inputs_require_grad(module, input, output):
|
| 310 |
+
output.requires_grad_(True)
|
| 311 |
+
self.model.llm.get_input_embeddings().register_forward_hook(_make_inputs_require_grad)
|
| 312 |
|
| 313 |
# Optional torch.compile (works on ROCm)
|
| 314 |
if config.use_compile:
|