Harmony18090's picture
Add source batch 2/11
76f9669 verified
raw
history blame
23.5 kB
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities associated with offloading functionality provided by `accelerate`.
| ------------------------------------------------------------------------------------------------------ | # noqa: E501
| Operation | Without offloading support | With offloading support | # noqa: E501
| ---------- | -------------------------------------- | ------------------------------------------------ | # noqa: E501
| Add | module.register_parameter(name, param) | register_offload_parameter(module, name, param) | # noqa: E501
| Check | N/A | has_offloaded_params(module) | # noqa: E501
| Onload | N/A | with align_module_device(module) | # noqa: E501
| Update | module.name.data.copy_(new_data) | update_offload_parameter(module, name, new_data) | # noqa: E501
| Delete | del module.name | delete_offload_parameter(module, name) | # noqa: E501
| Add Module | module.register_module(name, child) | register_offload_module(module, name, child) | # noqa: E501
| Del Module | del module.name | delete_offload_module(module, name) | # noqa: E501
| ------------------------------------------------------------------------------------------------------ | # noqa: E501
"""
import contextlib
import warnings
from functools import wraps
from operator import attrgetter
from typing import Any, Callable, Dict, Iterable, Literal, Optional, Tuple, Union
import torch
from compressed_tensors.utils import patch_attr
try:
from accelerate.hooks import (
AlignDevicesHook,
add_hook_to_module,
attach_align_device_hook,
named_module_tensors,
remove_hook_from_module,
)
from accelerate.utils import (
OffloadedWeightsLoader,
PrefixedDataset,
find_tied_parameters,
set_module_tensor_to_device,
)
_has_accelerate = True
except ImportError:
_has_accelerate = False
AlignDevicesHook = None
add_hook_to_module = None
remove_hook_from_module = None
OffloadedWeightsLoader = None
PrefixedDataset = None
set_module_tensor_to_device = None
named_module_tensors = None
attach_align_device_hook = None
find_tied_parameters = None
__all__ = [
"get_execution_device",
"get_offloaded_device",
"update_parameter_data",
"register_offload_parameter",
"update_offload_parameter",
"delete_offload_parameter",
"has_offloaded_params",
"disable_hf_hook",
"disable_offload",
"align_modules",
"align_module_device",
"register_offload_module",
"delete_offload_module",
"offloaded_dispatch",
"disable_offloading",
"remove_dispatch",
"cast_to_device",
]
def check_accelerate(fallback: Any):
    """
    Build a decorator guarding functions which depend on `accelerate`.

    When `accelerate` was imported successfully, the decorated function is returned
    untouched. Otherwise it is replaced by a stand-in which ignores its arguments
    and either raises a `ValueError` (when `fallback` equals the string "error")
    or returns `fallback`.

    :param fallback: value returned by the stand-in, or "error" to raise instead
    :return: decorator implementing the behavior described above
    """

    def decorator(func: Callable[[Any], Any]):
        if _has_accelerate:
            return func

        def _raise_missing(*args, **kwargs):
            raise ValueError(
                "Please install `accelerate` in order to use this function"
            )

        def _return_fallback(*args, **kwargs):
            return fallback

        stand_in = _raise_missing if fallback == "error" else _return_fallback
        # copy name/docstring metadata of the guarded function onto the stand-in
        return wraps(func)(stand_in)

    return decorator
""" Candidates for Depreciation """
def get_offloaded_device(module: torch.nn.Module) -> torch.device:
    """
    Find the device which holds a module's parameters while they are offloaded

    :param module: module to check
    :return: device of the first tensor listed in the module's offload weights map,
        or the module's execution device if the module is not offloaded
    """
    if not has_offloaded_params(module):
        # if the module is not offloaded, then any added weights
        # should be placed on the module's execution device
        return get_execution_device(module)

    weights_map = module._hf_hook.weights_map
    first_key = list(weights_map.keys())[0]
    # the key is looked up in the underlying dataset rather than the weights map
    return weights_map.dataset[first_key].device
def update_parameter_data(
    module: torch.nn.Module, new_param_data: torch.Tensor, param_name: str
):
    """
    Thin wrapper around `update_offload_parameter` with a different argument order.
    Updates the data of an existing parameter and its offload dict, for both
    offloaded and non-offloaded modules

    :param module: module containing the parameter to update
    :param new_param_data: tensor to update parameter with
    :param param_name: name of module parameter to update
    """
    update_offload_parameter(module=module, name=param_name, data=new_param_data)
""" Candidates for Upstreaming """
def cast_to_device(device_spec: Union[int, str, torch.device]) -> torch.device:
    """
    Convert a device specification into a torch.device object.

    :param device_spec: Device index (int), device string such as "cuda:0" or
        "cpu", or torch.device object. Negative integers map to CPU.
    :return: torch.device corresponding to the given device specification.
        Values of any other type are returned unchanged.
    """
    if isinstance(device_spec, int):
        return torch.device(f"cuda:{device_spec}" if device_spec >= 0 else "cpu")

    if isinstance(device_spec, str):
        # normalize device strings so callers always receive a torch.device,
        # as promised by the return annotation
        return torch.device(device_spec)

    return device_spec
def get_execution_device(module: torch.nn.Module) -> torch.device:
    """
    Get the device which inputs should be moved to before module execution.
    Assume that modules execute in the same order as returned by `model.modules()`

    Submodules are visited in `module.modules()` order; the first one which either
    has offloaded parameters or directly owns a parameter decides the result.
    If no such submodule exists, a warning is emitted and CPU is returned.

    :param module: module to check, may be offloaded
    :return: onload device of module
    """
    for candidate in module.modules():
        if has_offloaded_params(candidate):
            hook_device = candidate._hf_hook.execution_device
            return cast_to_device(hook_device)

        # only the first directly-owned parameter is needed
        for direct_param in candidate.parameters(recurse=False):
            return direct_param.device

    warnings.warn(f"Unable to get execution device of {module}, falling back to CPU")
    return torch.device("cpu")
def register_offload_parameter(
    module: torch.nn.Module,
    name: str,
    parameter: torch.nn.Parameter,
    offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
):
    """
    Register a parameter to the given module which may be offloaded

    For offloaded modules, the hook's bookkeeping (`original_devices`,
    `weights_map` and, if present, `tied_params_map`) is extended as well, and the
    new parameter is moved to the meta device when none of the module's existing
    parameters were onloaded.

    :param module: maybe offloaded module
    :param name: name of newly registered parameter
    :param parameter: parameter being registered
    :param offload_device: device on which weight will be offloaded to. If None is
        provided, then infer device from parameters on module
    """
    # computed before registration so that the new parameter itself is not counted
    meta_device = torch.device("meta")
    has_onload = any(p.device != meta_device for p in module.parameters())
    module.register_parameter(name, parameter)

    if not has_offloaded_params(module):
        return

    # do everything AlignDevicesHook.init_hook does
    # https://github.com/huggingface/accelerate/blob/main/src/accelerate/hooks.py#L281
    hook: AlignDevicesHook = module._hf_hook
    assert hook.weights_map is not None

    # record the device of the parameter as it was registered
    hook.original_devices[name] = parameter.device

    # store the offloaded copy in the weights map
    offload_to_weights_map(hook.weights_map, name, parameter.data, offload_device)

    # mark the offloaded tensor's pointer in tied_params_map
    offloaded = hook.weights_map[name]
    if hook.tied_params_map is not None:
        hook.tied_params_map[offloaded.data_ptr()] = {}  # (1)

    # perform offloading
    if not has_onload:
        set_module_tensor_to_device(module, name, "meta")
def update_offload_parameter(
    module: torch.nn.Module,
    name: str,
    data: torch.Tensor,
    offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
):
    """
    Update the data of an existing parameter and its offload dict. Supports both
    parameters of offloaded modules and non-offloaded modules

    A shape mismatch between the existing parameter and `data` only produces a
    warning; the update is still attempted.

    :param module: module containing the parameter to update
    :param name: name of module parameter to update
    :param data: tensor to update parameter with
    :param offload_device: device on which weight will be offloaded to. If None is
        provided, then infer device from parameters on module
    """
    target: torch.nn.Parameter = getattr(module, name)

    current_shape, incoming_shape = target.data.shape, data.shape
    if current_shape != incoming_shape:
        warnings.warn(
            f"Shape of parameter being updated {current_shape} does not match shape "
            f"of update data {incoming_shape}"
        )

    # a parameter on the meta device has no storage to copy into
    is_onloaded = target.device != torch.device("meta")
    if is_onloaded and data is not target.data:
        target.data.copy_(data)

    # keep the offloaded copy in sync
    if has_offloaded_params(module):
        offload_to_weights_map(
            module._hf_hook.weights_map, name, data, offload_device
        )
def delete_offload_parameter(module: torch.nn.Module, name: str):
    """
    Delete a parameter from a module which may be offloaded. For offloaded modules
    the matching entry of the hook's weights map is deleted as well

    :param module: maybe offloaded module
    :param name: name of parameter being deleted
    """
    delattr(module, name)

    if not has_offloaded_params(module):
        return

    delete_from_weights_map(module._hf_hook.weights_map, name)
@check_accelerate(fallback=contextlib.nullcontext())
@contextlib.contextmanager
def disable_hf_hook(module: torch.nn.Module):
    """
    Context manager which temporarily removes the `_hf_hook` attached to a module
    and to each of its submodules. The removed hooks are re-attached on exit, even
    if the body of the `with` block raises

    :param module: root module whose hooks (and submodule hooks) are disabled
    """
    hooks = {}

    def collect_hooks(submodule):
        if hasattr(submodule, "_hf_hook"):
            hooks[submodule] = submodule._hf_hook
            remove_hook_from_module(submodule)

    try:
        module.apply(collect_hooks)
        yield
    finally:
        # restore every hook collected so far, regardless of how the block exited
        for submodule, hook in hooks.items():
            add_hook_to_module(submodule, hook)
@check_accelerate(fallback=None)
def offload_to_weights_map(
    weights_map: Union[PrefixedDataset, Dict, OffloadedWeightsLoader],
    key: str,
    value: torch.Tensor,
    offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
):
    """
    Helper function which implements offloaded item assignment for PrefixedDataset,
    OffloadedWeightsLoader, and Dict types.

    :param weights_map: weight map to be updated with offload information
    :param key: key used to identify weight location
    :param value: weight being offloaded
    :param offload_device: device on which weight will be offloaded to. If None is
        provided, then infer device from parameters in weights_map
    :raises ValueError: if disk offloading is requested for a PrefixedDataset or
        dict, or if the offload device cannot be inferred from an empty dict
    :raises NotImplementedError: if the weights map uses disk offloading or is of
        an unsupported type
    """
    if isinstance(weights_map, PrefixedDataset):
        if offload_device == "disk":
            raise ValueError(f"Cannot offload to disk with type {type(weights_map)}")

        # delegate to the wrapped dataset using the fully prefixed key
        dataset = weights_map.dataset
        key = f"{weights_map.prefix}{key}"
        offload_to_weights_map(dataset, key, value, offload_device)

    elif isinstance(weights_map, OffloadedWeightsLoader):
        if len(weights_map.index) > 0 or offload_device == "disk":
            raise NotImplementedError(
                "Updating weights_map with disk offloading is not implemented yet"
            )

        offload_to_weights_map(weights_map.state_dict, key, value, offload_device)

        # only list the key once its tensor has actually been stored, so that a
        # failed update never leaves a key without a backing tensor
        if key not in weights_map.all_keys:
            weights_map.all_keys.append(key)

    elif isinstance(weights_map, dict):
        if offload_device == "disk":
            raise ValueError(f"Cannot offload to disk with type {type(weights_map)}")

        # infer offload device
        if offload_device is None:
            if key in weights_map:
                offload_device = weights_map[key].device
            else:
                tens = next(iter(weights_map.values()), None)
                if tens is None:
                    raise ValueError(
                        "Cannot infer offload device from empty weights_map"
                    )
                offload_device = tens.device

        weights_map[key] = value.to(device=offload_device)

    else:
        raise NotImplementedError(
            "Updating offload data not implemented for weights_map of type "
            f"{type(weights_map)}"
        )
@check_accelerate(fallback=None)
def delete_from_weights_map(
    weights_map: Union[PrefixedDataset, Dict, OffloadedWeightsLoader],
    key: str,
):
    """
    Helper function which implements offloaded item deletion for PrefixedDataset,
    OffloadedWeightsLoader, and Dict types.

    :param weights_map: weight map to delete the entry from
    :param key: key identifying the weight to delete
    :raises NotImplementedError: if the weights map uses disk offloading or is of
        an unsupported type
    """
    if isinstance(weights_map, PrefixedDataset):
        # delegate to the wrapped dataset using the fully prefixed key
        dataset = weights_map.dataset
        key = f"{weights_map.prefix}{key}"
        delete_from_weights_map(dataset, key)

    elif isinstance(weights_map, OffloadedWeightsLoader):
        if len(weights_map.index) > 0:
            raise NotImplementedError(
                "Delete from weights_map with disk offloading is not implemented yet"
            )

        delete_from_weights_map(weights_map.state_dict, key)

        # mirror `offload_to_weights_map`, which appends new keys to `all_keys`:
        # a deleted weight must not remain listed
        if key in weights_map.all_keys:
            weights_map.all_keys.remove(key)

    elif isinstance(weights_map, dict):
        del weights_map[key]

    else:
        raise NotImplementedError(
            "Deleting offload data not implemented for weights_map of type "
            f"{type(weights_map)}"
        )
@check_accelerate(fallback=contextlib.nullcontext())
@contextlib.contextmanager
def disable_offload(module: torch.nn.Module):
    """
    Context manager to disable module onloading and offloading. Parameters will stay on
    their current device

    Modules without offloaded parameters are left untouched. For offloaded modules,
    the hook's `offload` flag is cleared inside the context and set again on exit,
    even if the body of the `with` block raises.

    :param module: module to disable offloading for
    """
    if not has_offloaded_params(module):
        yield
        return

    module._hf_hook.offload = False
    try:
        yield
    finally:
        # re-enable offloading regardless of how the block exited
        module._hf_hook.offload = True
@check_accelerate(fallback=contextlib.nullcontext())
@contextlib.contextmanager
def align_modules(
    modules: Union[torch.nn.Module, Iterable[torch.nn.Module]],
    execution_device: Optional[torch.device] = None,
):
    """
    Context manager for onloading modules to a device, and disabling onload and offload
    attempts triggered by forward calls. Used for sequential onloading of layers

    :param modules: `torch.nn.Module` or iterable of `torch.nn.Module`s to onload
    :param execution_device: device to onload to
    """
    # accept a single module as shorthand for a one-element collection
    if isinstance(modules, torch.nn.Module):
        modules = (modules,)

    with contextlib.ExitStack() as active_contexts:
        for target in modules:
            onloaded = align_module_device(target, execution_device)
            active_contexts.enter_context(onloaded)
            # disable redundant onloading
            active_contexts.enter_context(disable_offload(target))
        yield
def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.Module):
    """
    Register a submodule with offloading if the parent module is offloaded

    If `base` has offloaded parameters, the tensors of `module` selected by
    `named_module_tensors(module, include_buffers=True, recurse=False)` are copied
    to the offload device of `base` and stored in the weights map of the hook on
    `base` under the key "<name>.<tensor name>". If that hook does not place
    submodules, `module` additionally receives its own `AlignDevicesHook` which
    reads from the parent's underlying dataset under the prefix
    "<parent prefix><name>.". In every case, `module` is finally attached to
    `base` via `register_module`.

    :param base: module to attach submodule to
    :param name: name of submodule
    :param module: submodule to attach
    """
    if has_offloaded_params(base):
        hook: AlignDevicesHook = base._hf_hook
        assert hook.offload
        assert hook.weights_map is not None

        # offloading kwargs for submodule
        place_submodules = False
        offload_buffers = True

        # copy device offloading arguments from parent
        current_device = next(base.parameters()).device  # assume base has parameters
        offload_device = get_offloaded_device(base)

        # offload parameters to weights map
        for param_name, param in named_module_tensors(
            module, include_buffers=offload_buffers, recurse=place_submodules
        ):
            offloaded = param.to(offload_device)
            if hook.tied_params_map is not None:
                hook.tied_params_map[offloaded.data_ptr()] = {}  # (1)
            offload_to_weights_map(hook.weights_map, f"{name}.{param_name}", offloaded)

            # if the parent places submodules, offload here
            if hook.place_submodules:
                set_module_tensor_to_device(module, param_name, current_device)

        # if the parent does not place submodules, then add a hook
        # parameters are offloaded by `add_hook_to_module`
        if not hook.place_submodules:
            # NOTE(review): assumes `hook.weights_map` is a PrefixedDataset, since
            # its `dataset` and `prefix` attributes are read here
            weights_map = PrefixedDataset(
                hook.weights_map.dataset, prefix=f"{hook.weights_map.prefix}{name}."
            )

            # the submodule hook reuses the parent's execution device, offload flag
            # and tied params map
            submodule_hook = AlignDevicesHook(
                execution_device=hook.execution_device,
                offload=hook.offload,
                io_same_device=False,
                weights_map=weights_map,
                offload_buffers=offload_buffers,
                place_submodules=place_submodules,
                skip_keys=None,
                tied_params_map=hook.tied_params_map,
            )
            add_hook_to_module(module, submodule_hook)

    # attach the submodule whether or not the parent is offloaded
    base.register_module(name, module)
def delete_offload_module(base: torch.nn.Module, name: str):
    """
    Delete a submodule from a model which may contain offloading

    Every parameter of the submodule, including those of nested submodules, is
    deleted through `delete_offload_parameter` before the submodule itself is
    removed from `base`.

    :param base: parent module to delete submodule from
    :param name: name of submodule on parent
    """
    module: torch.nn.Module = getattr(base, name)

    # delete each parameter through the module which directly owns it: recursive
    # parameter names are dotted (e.g. "child.weight") and cannot be resolved by
    # `delattr` on the parent module
    for owner in list(module.modules()):
        for param_name, _ in list(owner.named_parameters(recurse=False)):
            delete_offload_parameter(owner, param_name)

    delattr(base, name)
@check_accelerate(fallback="error")
def offloaded_dispatch(
    module: torch.nn.Module,
    execution_device: torch.device,
    offload_device: Union[torch.device, Literal["disk"]] = torch.device("cpu"),
) -> torch.nn.Module:
    """
    Unlike `dispatch_model`, this function forces a module (and its submodules) to
    offload all parameters and replace them with meta tensors, utilizing the
    `AlignDevicesHook` to control onloading and offloading.

    :param module: module containing parameters to offload
    :param execution_device: device that modules will onload and execute on
    :param offload_device: device that module parameters will offload to
    :return: module with offloading device hooks
    :raises NotImplementedError: if `offload_device` is "disk"
    """
    if offload_device == "disk":
        raise NotImplementedError("Disk offloading is not currently supported")

    # remove any existing hooks
    remove_dispatch(module)

    # weights map: every state dict entry, moved to the offload device
    offloaded_state = {
        key: tensor.to(offload_device) for key, tensor in module.state_dict().items()
    }
    weights_map = OffloadedWeightsLoader(
        state_dict=offloaded_state, device=offload_device
    )

    # tied params map: one empty entry per data pointer of each tied parameter
    tied_params_map = {
        attrgetter(param_name)(module).data_ptr(): {}
        for group in find_tied_parameters(module)
        for param_name in group
    }

    # recursively attaches hooks to all submodules
    attach_align_device_hook(
        module,
        execution_device=execution_device,
        offload=True,
        weights_map=weights_map,
        tied_params_map=tied_params_map,
    )

    # when saving a model, `PretrainedModel.save_pretrained` will only
    # onload weights if the following requirements are met
    # if (
    #     hasattr(self, "hf_device_map")
    #     and len(set(self.hf_device_map.values())) > 1
    #     and ("cpu" in self.hf_device_map.values()
    #          or "disk" in self.hf_device_map.values())
    # ):
    # because this function always offloads, disregard actual devices and
    # always use `cpu` and `cuda:0` to guarantee this condition passes
    module.hf_device_map = {"fake_offload": "cpu", "fake_exec": "cuda:0"}

    return module
def remove_dispatch(module: torch.nn.Module) -> torch.nn.Module:
    """
    Remove any existing dispatches from module: hooks are removed recursively, a
    `hf_device_map` attribute is deleted if present, and the module is moved to CPU

    :param module: module which may be dispatched with hf hooks
    :return: module without dispatch
    """
    remove_hook_from_module(module, recurse=True)

    has_device_map = hasattr(module, "hf_device_map")
    if has_device_map:
        del module.hf_device_map

    module.to("cpu")
    return module
@contextlib.contextmanager
def disable_offloading():
    """
    Keep modules onloaded and disable offloading until this context exits.
    Affects modules which have been hooked with accelerate's `AlignDevicesHook`

    Inside the context, `AlignDevicesHook.pre_forward` is patched so that the first
    call for a module runs the original `pre_forward` and then clears the hook's
    `offload` flag. On exit, including when the body of the `with` block raises,
    each affected hook has its flag restored, the module's direct parameters are
    written back through `update_offload_parameter`, and `post_forward` is called.
    """
    original_pre_forward = AlignDevicesHook.pre_forward
    onloaded_modules: Dict[torch.nn.Module, Tuple[AlignDevicesHook, bool]] = dict()

    # onload once and disable any future onloading/offloading steps
    def keep_onload_pre_forward(self: AlignDevicesHook, module, *args, **kwargs):
        ret = original_pre_forward(self, module, *args, **kwargs)
        if module not in onloaded_modules:
            # remember the original flag so it can be restored on exit
            onloaded_modules[module] = (self, self.offload)
            self.offload = False
        return ret

    try:
        # use the patched pre_forward function within the context
        with patch_attr(AlignDevicesHook, "pre_forward", keep_onload_pre_forward):
            yield
    finally:
        # manually offload all modules that were onloaded
        # update any parameters which may have changed
        # (runs on error too, so hooks are never left with offloading disabled)
        for module, (hook, offload) in onloaded_modules.items():
            hook.offload = offload
            for name, param in module.named_parameters(recurse=False):
                update_offload_parameter(module, name, param.data)
            hook.post_forward(module, None)
""" Upstreamed Functions """
# introduced in accelerate v1.1.0
@check_accelerate(fallback=False)
def has_offloaded_params(module: torch.nn.Module) -> bool:
    """
    Checks if a module has offloaded parameters by checking if the given module has a
    AlignDevicesHook attached with offloading enabled

    Args:
        module (`torch.nn.Module`): The module to check for an offload hook.

    Returns:
        bool: `True` if the module has an offload hook and offloading is enabled,
        `False` otherwise.
    """
    # a missing `_hf_hook` attribute yields None, which fails the isinstance check
    attached_hook = getattr(module, "_hf_hook", None)
    return isinstance(attached_hook, AlignDevicesHook) and attached_hook.offload
# introduced in accelerate v1.1.0
@check_accelerate(fallback=contextlib.nullcontext())
@contextlib.contextmanager
def align_module_device(
    module: torch.nn.Module, execution_device: Optional[torch.device] = None
):
    """
    Context manager that moves a module's parameters to the specified execution device.

    Three cases are handled:
    - the module has offloaded parameters: the hook's `pre_forward` is called on
      entry and `post_forward` on exit, optionally with the hook's
      `execution_device` overridden for the duration of the context
    - the module is not offloaded and `execution_device` is given: each directly
      owned parameter is passed to `set_module_tensor_to_device` with
      `execution_device` on entry and with its recorded original device on exit
    - otherwise: nothing is changed

    Args:
        module (`torch.nn.Module`):
            Module with parameters to align.
        execution_device (`torch.device`, *optional*):
            If provided, overrides the module's execution device within the context.
            Otherwise, use hook execution device or pass
    """
    if has_offloaded_params(module):
        # temporarily override the hook's execution device, remembering the old one
        if execution_device is not None:
            original_device = module._hf_hook.execution_device
            module._hf_hook.execution_device = execution_device

        try:
            module._hf_hook.pre_forward(module)
            yield
        finally:
            # post_forward runs before the hook's execution device is restored
            module._hf_hook.post_forward(module, None)
            if execution_device is not None:
                module._hf_hook.execution_device = original_device

    elif execution_device is not None:
        # record the current device of each directly owned parameter
        devices = {
            name: param.device for name, param in module.named_parameters(recurse=False)
        }
        try:
            for name in devices:
                set_module_tensor_to_device(module, name, execution_device)
            yield
        finally:
            # put each parameter back on the device recorded above
            for name, device in devices.items():
                set_module_tensor_to_device(module, name, device)

    else:
        # not offloaded and no device requested: nothing to align
        yield
# (1): Since we cannot know which pointers are shared when we add parameters in an
# online way, assume that all pointers are shared. This has virtually no runtime cost