# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import warnings
from functools import wraps
from types import MappingProxyType
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
TypeVar,
)
import numpy
import torch
from transformers import AutoConfig, PretrainedConfig
T = TypeVar("T", bound="Callable") # used by `deprecated`
if TYPE_CHECKING:
from compressed_tensors.compressors import ModelCompressor
# Public API surface of this utilities module
__all__ = [
    "infer_compressor_from_model_config",
    "fix_fsdp_module_name",
    "tensor_follows_mask_structure",
    "replace_module",
    "is_compressed_tensors_config",
    "getattr_chain",
    "deprecated",
    "Aliasable",
    "combine_shards",
    "shard_tensor",
    "pack_bitmasks",
    "unpack_bitmasks",
    "patch_attr",
    "patch_attrs",
    "ParameterizedDefaultDict",
    "get_num_attn_heads",
    "get_num_kv_heads",
    "get_head_dim",
]

# Token FSDP inserts into wrapped module names; stripped by `fix_fsdp_module_name`
FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
def infer_compressor_from_model_config(
    pretrained_model_name_or_path: str,
) -> Optional["ModelCompressor"]:  # noqa: F821
    """
    Given a path to a model config, extract a sparsity config if it exists and
    return the associated ModelCompressor

    :param pretrained_model_name_or_path: path to model config on disk or HF hub
    :return: matching compressor if config contains a sparsity config,
        otherwise None
    """
    # imported locally, presumably to avoid a circular import at module
    # load time — TODO confirm
    from compressed_tensors.compressors import ModelCompressor
    from compressed_tensors.config import CompressionConfig

    config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
    raw_sparsity_config = ModelCompressor.parse_sparsity_config(config)
    if raw_sparsity_config is None:
        return None

    # renamed from `format` to avoid shadowing the builtin
    compression_format = raw_sparsity_config.get("format")
    sparsity_config = CompressionConfig.load_from_registry(
        compression_format, **raw_sparsity_config
    )
    compressor = ModelCompressor.load_from_registry(
        compression_format, config=sparsity_config
    )
    return compressor
def fix_fsdp_module_name(name: str) -> str:
    """
    Remove FSDP wrapper prefixes from a module name.

    Accounts for the scenario where FSDP_WRAPPER_NAME appears at the end of
    the name as well as in the middle.

    :param name: name to strip
    :return: stripped name
    """
    # drop the token when followed by a dot (start/middle of the name),
    # then when preceded by one (end of the name)
    prefixed = FSDP_WRAPPER_NAME + "."
    suffixed = "." + FSDP_WRAPPER_NAME
    return name.replace(prefixed, "").replace(suffixed, "")
def tensor_follows_mask_structure(tensor, mask: str = "2:4") -> bool:
    """
    Check that a tensor is at least as sparse as an "n:m" structured mask.

    Note, some weights can incidentally be zero, so we check for at least
    n zeros in each chunk of size m rather than an exact count.

    :param tensor: tensor to check
    :param mask: mask structure to check for, in the format "n:m"
    :return: True if the tensor follows the mask structure
    :raises ValueError: if any chunk of size m contains fewer than n zeros
    """
    n, m = tuple(map(int, mask.split(":")))

    # reshape into chunks of size m; `reshape` (unlike `view`) also accepts
    # non-contiguous tensors
    chunks = tensor.reshape(-1, m)

    # count zeros per chunk; >= (not ==) because dense weights may also
    # incidentally be zero
    zero_counts = (chunks == 0).sum(dim=1)
    if not torch.all(zero_counts >= n).item():
        # was a bare ValueError(); include context for easier debugging
        raise ValueError(
            f"Tensor does not follow the {mask} mask structure: "
            f"found a chunk of {m} elements with fewer than {n} zeros"
        )

    return True
def replace_module(model: torch.nn.Module, name: str, new_module: torch.nn.Module):
    """
    Swap the submodule at a dot-separated path for a new module.

    :param model: root module containing the submodule to replace
    :param name: dot-separated path to the submodule (e.g. "layers.0.mlp")
    :param new_module: module to install at that path
    """
    # rpartition yields ("", "", name) when there is no dot, i.e. the target
    # is a direct child of the root model
    parent_name, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_name) if parent_name else model
    setattr(parent, child_name, new_module)
def is_compressed_tensors_config(compression_config: Any) -> bool:
    """
    Returns True if CompressedTensorsConfig is available from transformers and
    compression_config is an instance of CompressedTensorsConfig

    See: https://github.com/huggingface/transformers/pull/31704
    """
    # keep the try body minimal: only the import can raise ImportError here
    try:
        from transformers.utils.quantization_config import CompressedTensorsConfig
    except ImportError:
        # transformers too old (or absent) to provide the config class
        return False
    return isinstance(compression_config, CompressedTensorsConfig)
def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any:
    """
    Chain multiple getattr calls, separated by `.`

    :param obj: base object whose attributes are being retrieved
    :param chain_str: attribute names separated by `.`
    :param default: default value returned when an attribute is missing;
        if omitted, a missing attribute raises AttributeError
    """
    # resolve the optional default; a positional default takes precedence
    # over the `default` keyword
    missing = object()
    if args:
        default = args[0]
    elif "default" in kwargs:
        default = kwargs["default"]
    else:
        default = missing

    current = obj
    for attr_name in chain_str.split("."):
        nxt = getattr(current, attr_name, missing)
        if nxt is missing:
            if default is not missing:
                return default
            raise AttributeError(f"{current} object has no attribute {attr_name}")
        current = nxt

    return current
def deprecated(
    future_name: Optional[str] = None, message: Optional[str] = None
) -> Callable[[T], T]:
    """
    Decorator to mark functions as deprecated

    :param future_name: name of the function to use in place of the deprecated
        one; mentioned in the default deprecation message
    :param message: deprecation message, replaces default deprecation message
    """

    def decorator(func: T) -> T:
        nonlocal message

        # build the default message lazily so it can reference the decorated
        # function's name; a caller-provided message is used verbatim
        if message is None:
            message = (
                f"{func.__name__} is deprecated and will be removed in a future release"
            )
            if future_name is not None:
                message += f". Please use {future_name} instead."

        @wraps(func)
        def wrapped(*args, **kwargs):
            # stacklevel=2 attributes the warning to the caller, not this wrapper
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)

        return wrapped

    return decorator
class Aliasable:
    """
    A mixin for enums to allow aliasing of enum members

    Example:
        >>> class MyClass(Aliasable, int, Enum):
        >>>     ...
    """

    @staticmethod
    def get_aliases() -> Dict[str, str]:
        # mapping from alias value to canonical value; must be provided by
        # the subclass
        raise NotImplementedError()

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            aliases = self.get_aliases()
            return self.value == other.value or (
                aliases.get(self.value, self.value)
                == aliases.get(other.value, other.value)
            )
        else:
            # comparing against a raw value: resolve aliases on both sides
            aliases = self.get_aliases()
            self_value = aliases.get(self.value, self.value)
            other_value = aliases.get(other, other)
            return self_value == other_value

    def __hash__(self):
        # BUGFIX: previously read `self.aliases`, an attribute defined
        # nowhere; use the declared `get_aliases` API so aliased members
        # that compare equal also hash equal
        canonical_value = self.get_aliases().get(self.value, self.value)
        return hash(canonical_value)
def shard_tensor(
    tensor: torch.Tensor, shard_sizes: List[int], dim: int = 0
) -> List[torch.Tensor]:
    """
    Shards a tensor into a list of tensors along a given dimension.

    :param tensor: the input tensor to shard
    :param shard_sizes: size of each shard along the specified dimension
    :param dim: the dimension along which to shard the tensor
    :returns: a list of tensors (views) sharded along the specified dimension
    :raises ValueError: if the sum of shard_sizes does not match the size of
        the tensor along the given dimension
    """
    if sum(shard_sizes) != tensor.size(dim):
        raise ValueError(
            "Sum of shard_sizes must equal the size of the tensor "
            "along the specified dimension."
        )

    # precompute the starting offset of every shard, then slice via narrow
    offsets = [0]
    for size in shard_sizes[:-1]:
        offsets.append(offsets[-1] + size)

    return [
        tensor.narrow(dim, offset, size)
        for offset, size in zip(offsets, shard_sizes)
    ]
def combine_shards(shards, dim=0):
    """
    Combine decompressed shards along a given dimension using `narrow`.

    :param shards: list of decompressed shard tensors
    :param dim: dimension to combine along (default: 0)
    :return: combined decompressed tensor
    :raises ValueError: if no shards are given, or shard dtypes differ
    """
    if not shards:
        raise ValueError("The list of shards is empty.")

    # all shards must share a dtype for the copy below to be lossless
    if any(shard.dtype != shards[0].dtype for shard in shards):
        raise ValueError("All shards must have the same dtype.")

    # output shape matches the shards except along `dim`, which is summed
    out_shape = list(shards[0].shape)
    out_shape[dim] = sum(shard.shape[dim] for shard in shards)

    combined = torch.zeros(out_shape, dtype=shards[0].dtype, device=shards[0].device)

    # copy each shard into its slice of the output
    offset = 0
    for shard in shards:
        length = shard.shape[dim]
        combined.narrow(dim, offset, length).copy_(shard)
        offset += length

    return combined
def pack_bitmasks(bytemasks: torch.Tensor) -> torch.Tensor:
    """
    Converts a bytemask tensor to a bitmask tensor to reduce memory. Shape RxC
    will be compressed to R x ceil(C/8)

    :param bytemasks: mask tensor where each byte corresponds to a weight
    :return: mask tensor where each bit corresponds to a weight
    """
    # move to CPU before the numpy round-trip, mirroring `unpack_bitmasks`;
    # `.numpy()` alone raises on a CUDA tensor
    packed_bits_numpy = numpy.packbits(
        bytemasks.cpu().numpy(), axis=-1, bitorder="little"
    )
    return torch.from_numpy(packed_bits_numpy)
def unpack_bitmasks(
    packed_bitmasks: torch.Tensor, original_shape: List[int]
) -> torch.Tensor:
    """
    Converts a bitmask tensor back to a bytemask tensor for use during
    decompression

    :param packed_bitmasks: mask tensor where each bit corresponds to a weight
    :param original_shape: dense shape to decompress to
    :return: boolean mask of weights in the original dense shape
    """
    # expand the packed bits, clipping to the dense row length so any padding
    # bits in the final byte are discarded
    dense_bits = numpy.unpackbits(
        packed_bitmasks.cpu().numpy(),
        axis=-1,
        count=original_shape[-1],
        bitorder="little",
    )

    # cast to bool and restore the dense shape
    return torch.from_numpy(dense_bits.astype(bool).reshape(original_shape))
@contextlib.contextmanager
def patch_attr(base: object, attr: str, value: Any):
    """
    Patch the value of an object attribute. Original value is restored upon
    exit; an attribute that did not exist before is deleted on exit.

    :param base: object which has the attribute to patch
    :param attr: name of the attribute to patch
    :param value: used to replace original value

    Usage:
    >>> from types import SimpleNamespace
    >>> obj = SimpleNamespace()
    >>> with patch_attr(obj, "attribute", "value"):
    ...     assert obj.attribute == "value"
    >>> assert not hasattr(obj, "attribute")
    """
    # remember whether the attribute existed so exit can either restore or
    # delete it
    existed = hasattr(base, attr)
    saved = getattr(base, attr, None)

    setattr(base, attr, value)
    try:
        yield
    finally:
        if existed:
            setattr(base, attr, saved)
        else:
            delattr(base, attr)
@contextlib.contextmanager
def patch_attrs(bases: Iterable[Any], attr: str, values: Iterable[Any]):
    """
    Same as `patch_attr` but for a list of objects to patch.

    Patch one attribute on each object with the corresponding value; original
    values are restored upon exit.

    :param bases: objects which have the attribute to patch
    :param attr: name of the attribute to patch
    :param values: used to replace original values. Must be same length
        as bases

    Usage:
    >>> from types import SimpleNamespace
    >>> obj1 = SimpleNamespace()
    >>> obj2 = SimpleNamespace()
    >>> with patch_attr([obj1, obj2], "attribute", ["value1", "value2"]):
    ...     assert obj1.attribute == "value1"
    ...     assert obj2.attribute == "value2"
    >>> assert not hasattr(obj1, "attribute")
    >>> assert not hasattr(obj2, "attribute")
    """
    # ExitStack unwinds the individual patches in reverse order on exit
    with contextlib.ExitStack() as stack:
        for target, replacement in zip(bases, values):
            stack.enter_context(patch_attr(target, attr, replacement))
        yield
class ParameterizedDefaultDict(dict):
    """
    Similar to `collections.DefaultDict`, but upon fetching a key which is
    missing, the key is passed as arguments to the `default_factory`

    :param default_factory: function which takes a key as input and returns the
        corresponding default value
    """

    def __init__(self, default_factory: Callable[[Any], Any]):
        self.default_factory = default_factory
        # extra keyword arguments forwarded to the factory; temporarily
        # overridden by `get`
        self._factory_kwargs = MappingProxyType({})

    def __missing__(self, key: Any) -> Any:
        # tuple keys are splatted as positional factory arguments
        factory_args = key if isinstance(key, tuple) else (key,)
        value = self.default_factory(*factory_args, **self._factory_kwargs)
        self[key] = value
        return value

    def get(self, *args, factory_kwargs: Mapping = MappingProxyType({})) -> Any:
        """
        Similar to `__getitem__`, but allows passing kwargs to factory function

        :param \\*args: args whose tuple will value will be treated as key
        :param factory_kwargs: keyword arguments to pass to `default_factory`
        :return: dictionary entry for given key
        """
        with patch_attr(self, "_factory_kwargs", factory_kwargs):
            return self[args]
def get_num_attn_heads(config: PretrainedConfig) -> int:
    """
    Get the number of attention heads used by a model

    :param config: model config
    :return: num_attention_heads of model
    :raises ValueError: if the config defines neither `num_attention_heads`
        nor both `hidden_size` and `head_dim`
    """
    if hasattr(config, "num_attention_heads"):
        return config.num_attention_heads

    # fall back to deriving the head count from the hidden size
    if hasattr(config, "hidden_size") and hasattr(config, "head_dim"):
        return config.hidden_size // config.head_dim

    raise ValueError(
        "Cannot determine num_attention_heads from config. Config must define "
        "either `num_attention_heads` or both `hidden_size` and `head_dim`. "
        f"{config}"
    )
def get_num_kv_heads(config: PretrainedConfig) -> int:
    """
    Get the number of key-value attention heads used by a model

    :param config: model config
    :return: num_key_value_heads of model
    :raises ValueError: if the config does not define `num_key_value_heads`
    """
    # unlike the other attention helpers, this value cannot be derived from
    # other config fields, so the attribute is required
    if not hasattr(config, "num_key_value_heads"):
        raise ValueError(
            "Cannot determine num_key_value_heads from config. Config must define "
            f"`num_key_value_heads`. {config}"
        )
    return config.num_key_value_heads
def get_head_dim(config: PretrainedConfig) -> int:
    """
    Get the number of dimensions used by the attention heads of a model

    :param config: model config
    :return: head_dim of model
    :raises ValueError: if the config defines neither `head_dim` nor both
        `hidden_size` and `num_attention_heads`
    """
    if hasattr(config, "head_dim"):
        return config.head_dim

    # fall back to deriving the head dimension from the hidden size
    if hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
        return config.hidden_size // config.num_attention_heads

    raise ValueError(
        "Cannot determine head_dim from config. Config must define "
        "either `head_dim` or both `hidden_size` and `num_attention_heads`. "
        f"{config}"
    )