| # Copyright (c) Microsoft Corporation. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # DeepSpeed Team | |
| from packaging import version as pkg_version | |
| import torch | |
def required_torch_version(min_version=None, max_version=None):
    """Report whether the installed torch version lies within the given bounds.

    Both bounds are inclusive: a torch version that compares equal to a bound
    is accepted. Each bound is converted with ``str()`` and parsed with
    ``pkg_version.parse``, so strings and numbers are both accepted.

    Args:
        min_version: Lowest acceptable version. A falsy value (e.g. ``None``)
            disables the lower-bound check.
        max_version: Highest acceptable version. A falsy value (e.g. ``None``)
            disables the upper-bound check.

    Returns:
        bool: ``True`` when ``torch.__version__`` satisfies every bound that
        was supplied, ``False`` otherwise.

    Raises:
        AssertionError: If neither ``min_version`` nor ``max_version`` is
            truthy.
    """
    assert min_version or max_version, "Must provide a min_version or max_version argument"

    installed = pkg_version.parse(torch.__version__)

    # Lower bound is checked first; the upper bound is only parsed when the
    # lower-bound check has not already rejected the installed version.
    too_old = bool(min_version) and installed < pkg_version.parse(str(min_version))
    if too_old:
        return False

    too_new = bool(max_version) and installed > pkg_version.parse(str(max_version))
    return not too_new
def register_grad_hook(param, hook):
    """Register ``hook`` on ``param`` using the API chosen by the torch version check.

    When ``required_torch_version(min_version=2.1)`` is true the hook is
    registered through ``param.register_post_accumulate_grad_hook``. Otherwise
    it is registered on the first entry of ``next_functions`` of the
    ``grad_fn`` produced by ``param.expand_as(param)``.

    Args:
        param: Object exposing ``register_post_accumulate_grad_hook`` and
            ``expand_as`` (presumably a ``torch`` parameter/tensor that
            requires grad -- confirm against callers).
        hook: Callable forwarded unchanged to the selected registration call.
            NOTE(review): the two registration APIs may invoke the hook with
            different arguments -- confirm that callers' hooks tolerate both.

    Returns:
        Whatever the selected registration call returns.
    """
    if not required_torch_version(min_version=2.1):
        # Fallback path: expanding the parameter onto itself yields a tensor
        # with a grad_fn, whose first next-function is the node hooked here.
        # NOTE(review): that node is referenced only by locals in this
        # function -- confirm whether callers must keep it alive for the hook
        # to persist.
        expanded = param.expand_as(param)
        accumulator = expanded.grad_fn.next_functions[0][0]
        return accumulator.register_hook(hook)

    return param.register_post_accumulate_grad_hook(hook)