Spaces:
Build error
Build error
| # -*- coding: utf-8 -*- | |
| # | |
| # This file is part of UniRig. | |
| # | |
| # This file is derived from https://github.com/NeuralCarver/Michelangelo | |
| # | |
| # Copyright (c) https://github.com/NeuralCarver/Michelangelo original authors | |
| # Copyright (c) 2025 VAST-AI-Research and contributors. | |
| # | |
| # This program is free software: you can redistribute it and/or modify | |
| # it under the terms of the GNU General Public License as published by | |
| # the Free Software Foundation, either version 3 of the License, or | |
| # (at your option) any later version. | |
| # | |
| # This program is distributed in the hope that it will be useful, | |
| # but WITHOUT ANY WARRANTY; without even the implied warranty of | |
| # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
| # GNU General Public License for more details. | |
| # | |
| # You should have received a copy of the GNU General Public License | |
| # along with this program. If not, see <https://www.gnu.org/licenses/>. | |
| """ | |
| Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124 | |
| """ | |
| import torch | |
| from typing import Callable, Iterable, Sequence, Union | |
| from packaging import version | |
def checkpoint(
    func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
    inputs: Sequence[torch.Tensor],
    params: Iterable[torch.Tensor],
    flag: bool,
    use_deepspeed: bool = False
):
    """
    Call `func` on `inputs`, optionally with gradient checkpointing.

    With checkpointing enabled the intermediate activations of `func` are not
    kept during the forward pass; they are recomputed in the backward pass,
    trading extra compute for lower memory use.

    :param func: the callable to evaluate.
    :param inputs: positional arguments that are unpacked into `func`.
    :param params: parameters `func` depends on without taking them as
                   arguments. They are only used by the built-in
                   (non-deepspeed) checkpointing path.
    :param flag: when False, `func(*inputs)` is called directly and no
                 checkpointing takes place.
    :param use_deepspeed: when True (and `flag` is True), delegate to
                          `deepspeed.checkpointing.checkpoint`.
    :return: whatever `func` returns.
    """
    # Checkpointing switched off: plain call, nothing else to do.
    if not flag:
        return func(*inputs)

    if use_deepspeed:
        # Imported lazily so deepspeed is only needed when this path is taken.
        import deepspeed
        return deepspeed.checkpointing.checkpoint(func, *inputs)

    # Inputs come first, parameters after; the input count tells
    # CheckpointFunction where to split the flat argument list again.
    num_inputs = len(inputs)
    flat_args = (*inputs, *params)
    return CheckpointFunction.apply(func, num_inputs, *flat_args)
class CheckpointFunction(torch.autograd.Function):
    """
    Autograd function that recomputes activations in the backward pass.

    `forward` evaluates the wrapped callable under `torch.no_grad()`;
    `backward` evaluates it again with gradients enabled and differentiates
    the recomputed outputs with respect to the inputs and extra parameters.
    """

    # NOTE: the two helpers below deliberately take no arguments. They are
    # called while the class body executes, to build the decorators that are
    # applied to `forward` and `backward`.
    def _get_fwd_decorator():
        """Return the `custom_fwd` decorator offered by the installed torch."""
        amp_custom_fwd = getattr(getattr(torch, "amp", None), "custom_fwd", None)
        if amp_custom_fwd is not None:
            # Device-agnostic API; needs an explicit device type.
            return amp_custom_fwd(device_type='cuda')
        # Older torch: only the CUDA-specific variant exists.
        return torch.cuda.amp.custom_fwd()

    def _get_bwd_decorator():
        """Return the `custom_bwd` decorator offered by the installed torch."""
        amp_custom_bwd = getattr(getattr(torch, "amp", None), "custom_bwd", None)
        if amp_custom_bwd is not None:
            return amp_custom_bwd(device_type='cuda')

        # Older torch: the decorator must be handed the function itself.
        def custom_bwd(bwd):
            return torch.cuda.amp.custom_bwd(bwd=bwd)
        return custom_bwd

    @staticmethod
    @_get_fwd_decorator()
    def forward(ctx, run_function, length, *args):
        """
        Run `run_function` on the first `length` entries of `args`.

        :param ctx: autograd context; stores the callable, inputs and
                    parameters for `backward`.
        :param run_function: the callable to evaluate.
        :param length: number of leading entries in `args` that are passed to
                       `run_function`; the remaining entries are parameters.
        :param args: the inputs followed by the parameters.
        :return: the output of `run_function`.
        """
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        # No graph is recorded here; backward() recomputes the activations.
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    @_get_bwd_decorator()
    def backward(ctx, *output_grads):
        """
        Recompute the forward pass and return gradients for `forward`'s args.

        :param ctx: autograd context populated by `forward`.
        :param output_grads: gradients with respect to the outputs of
                             `forward`.
        :return: `(None, None)` for `run_function` and `length`, followed by
                 the gradients for the inputs and then the parameters.
        """
        # Fresh leaf tensors so the recomputed graph starts at the inputs.
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        # Drop the references held for the recomputation.
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads