# ViTeX-14B / diffsynth/core/gradient/gradient_checkpoint.py
# Bundled diffsynth library (no external repo dependency); commit bc8c4af (verified).
import torch
import warnings
# NOTE(review): this suppresses torch.utils.checkpoint's "None of the inputs have
# requires_grad" warning on the assumption that gradients flow through model
# parameters rather than inputs. With use_reentrant=True that warning signals a
# real problem (parameter grads are skipped) — confirm the checkpoint calls below
# use the non-reentrant implementation before relying on this filter.
warnings.filterwarnings("ignore", message=".*None of the inputs have requires_grad.*")
def create_custom_forward(module):
    """Return a plain callable that forwards all arguments to *module*.

    torch.utils.checkpoint expects a function rather than an nn.Module;
    this closure adapts one to the other without altering the call.
    """
    def _forward(*call_args, **call_kwargs):
        return module(*call_args, **call_kwargs)
    return _forward
def gradient_checkpoint_forward(
model,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
*args,
**kwargs,
):
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
model_output = torch.utils.checkpoint.checkpoint(
create_custom_forward(model),
*args,
**kwargs,
use_reentrant=True,
)
elif use_gradient_checkpointing:
model_output = torch.utils.checkpoint.checkpoint(
create_custom_forward(model),
*args,
**kwargs,
use_reentrant=True,
)
else:
model_output = model(*args, **kwargs)
return model_output