| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import contextlib |
| | import warnings |
| | from functools import wraps |
| | from types import MappingProxyType |
| | from typing import ( |
| | TYPE_CHECKING, |
| | Any, |
| | Callable, |
| | Dict, |
| | Iterable, |
| | List, |
| | Mapping, |
| | Optional, |
| | TypeVar, |
| | ) |
| |
|
| | import numpy |
| | import torch |
| | from transformers import AutoConfig, PretrainedConfig |
| |
|
| |
|
| | T = TypeVar("T", bound="Callable") |
| |
|
| |
|
| | if TYPE_CHECKING: |
| | from compressed_tensors.compressors import ModelCompressor |
| |
|
| |
|
| | __all__ = [ |
| | "infer_compressor_from_model_config", |
| | "fix_fsdp_module_name", |
| | "tensor_follows_mask_structure", |
| | "replace_module", |
| | "is_compressed_tensors_config", |
| | "getattr_chain", |
| | "deprecated", |
| | "Aliasable", |
| | "combine_shards", |
| | "shard_tensor", |
| | "pack_bitmasks", |
| | "unpack_bitmasks", |
| | "patch_attr", |
| | "patch_attrs", |
| | "ParameterizedDefaultDict", |
| | "get_num_attn_heads", |
| | "get_num_kv_heads", |
| | "get_head_dim", |
| | ] |
| |
|
| | FSDP_WRAPPER_NAME = "_fsdp_wrapped_module" |
| |
|
| |
|
def infer_compressor_from_model_config(
    pretrained_model_name_or_path: str,
) -> Optional["ModelCompressor"]:
    """
    Given a path to a model config, extract a sparsity config if it exists and return
    the associated ModelCompressor

    :param pretrained_model_name_or_path: path to model config on disk or HF hub
    :return: matching compressor if config contains a sparsity config, else None
    """
    # local imports: keeps the compressors package out of module import time
    from compressed_tensors.compressors import ModelCompressor
    from compressed_tensors.config import CompressionConfig

    config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
    sparsity_config = ModelCompressor.parse_sparsity_config(config)
    if sparsity_config is None:
        return None

    # renamed from `format` to avoid shadowing the builtin
    compression_format = sparsity_config.get("format")
    sparsity_config = CompressionConfig.load_from_registry(
        compression_format, **sparsity_config
    )
    compressor = ModelCompressor.load_from_registry(
        compression_format, config=sparsity_config
    )
    return compressor
| |
|
| |
|
def fix_fsdp_module_name(name: str) -> str:
    """
    Strip FSDP wrapper prefixes from a module name.

    Handles the wrapper appearing mid-name (``"<wrapper>."``) as well as
    at the end of the name (``".<wrapper>"``).

    :param name: name to strip
    :return: stripped name
    """
    for wrapper_fragment in (FSDP_WRAPPER_NAME + ".", "." + FSDP_WRAPPER_NAME):
        name = name.replace(wrapper_fragment, "")
    return name
| |
|
| |
|
| | def tensor_follows_mask_structure(tensor, mask: str = "2:4") -> bool: |
| | """ |
| | :param tensor: tensor to check |
| | :param mask: mask structure to check for, in the format "n:m" |
| | :return: True if the tensor follows the mask structure, False otherwise. |
| | Note, some weights can incidentally be zero, so we check for |
| | atleast n zeros in each chunk of size m |
| | """ |
| |
|
| | n, m = tuple(map(int, mask.split(":"))) |
| | |
| | tensor = tensor.view(-1, m) |
| |
|
| | |
| | zero_counts = (tensor == 0).sum(dim=1) |
| |
|
| | |
| | |
| | |
| | if not torch.all(zero_counts >= n).item(): |
| | raise ValueError() |
| |
|
| | return True |
| |
|
| |
|
def replace_module(model: torch.nn.Module, name: str, new_module: torch.nn.Module):
    """
    Install `new_module` at the dotted path `name` inside `model`.

    :param model: root module containing the submodule to replace
    :param name: dotted path of the submodule, e.g. ``"layers.0.mlp"``;
        a plain name replaces a direct child of `model`
    :param new_module: module to place at that location
    """
    if "." not in name:
        setattr(model, name, new_module)
        return

    parent_name, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_name)
    setattr(parent, child_name, new_module)
| |
|
| |
|
def is_compressed_tensors_config(compression_config: Any) -> bool:
    """
    Returns True if CompressedTensorsConfig is available from transformers and
    compression_config is an instance of CompressedTensorsConfig

    See: https://github.com/huggingface/transformers/pull/31704
    """
    # narrow try body: only the import can raise ImportError here
    try:
        from transformers.utils.quantization_config import CompressedTensorsConfig
    except ImportError:
        return False

    return isinstance(compression_config, CompressedTensorsConfig)
| |
|
| |
|
def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any:
    """
    Resolve a dotted chain of attribute lookups starting from `obj`.

    :param obj: base object whose attributes are being retrieved
    :param chain_str: attribute names separated by `.`
    :param default: value returned when any attribute in the chain is missing;
        if omitted, an AttributeError is raised instead. May be passed
        positionally or by keyword
    """
    # sentinel distinguishes "no default given" from an explicit default=None
    _missing = object()
    if len(args) >= 1:
        default = args[0]
    else:
        default = kwargs.get("default", _missing)

    current = obj
    for attr_name in chain_str.split("."):
        if hasattr(current, attr_name):
            current = getattr(current, attr_name)
        elif default is not _missing:
            return default
        else:
            raise AttributeError(f"{current} object has no attribute {attr_name}")

    return current
| |
|
| |
|
def deprecated(
    future_name: Optional[str] = None, message: Optional[str] = None
) -> Callable[[T], T]:
    """
    Decorator to mark functions as deprecated

    :param future_name: name of the function to use instead; mentioned in the
        default deprecation message (docstring previously documented a
        nonexistent ``new_function`` parameter)
    :param message: deprecation message, replaces default deprecation message
    """

    def decorator(func: T) -> T:
        # compute the message per decorated function instead of mutating the
        # enclosing `message` (nonlocal): reusing one decorator instance on
        # several functions previously embedded the first function's name in
        # every warning
        if message is not None:
            warning_message = message
        else:
            warning_message = (
                f"{func.__name__} is deprecated and will be removed in a future release"
            )
            if future_name is not None:
                warning_message += f". Please use {future_name} instead."

        @wraps(func)
        def wrapped(*args, **kwargs):
            warnings.warn(warning_message, DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)

        return wrapped

    return decorator
| |
|
| |
|
class Aliasable:
    """
    A mixin for enums to allow aliasing of enum members

    Example:
    >>> class MyClass(Aliasable, int, Enum):
    >>>     ...
    """

    @staticmethod
    def get_aliases() -> Dict[str, str]:
        """Return the alias-value -> canonical-value mapping; subclasses override"""
        raise NotImplementedError()

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            aliases = self.get_aliases()
            return self.value == other.value or (
                aliases.get(self.value, self.value)
                == aliases.get(other.value, other.value)
            )
        else:
            # comparing against a raw value: resolve aliases on both sides
            aliases = self.get_aliases()
            self_value = aliases.get(self.value, self.value)
            other_value = aliases.get(other, other)
            return self_value == other_value

    def __hash__(self):
        # fix: this class exposes aliases via `get_aliases()`; `self.aliases`
        # does not exist and made every hash() call raise AttributeError.
        # Hash the canonical (de-aliased) value so aliased members hash
        # equally, keeping __hash__ consistent with __eq__
        canonical_value = self.get_aliases().get(self.value, self.value)
        return hash(canonical_value)
| |
|
| |
|
def shard_tensor(
    tensor: torch.Tensor, shard_sizes: List[int], dim: int = 0
) -> List[torch.Tensor]:
    """
    Split `tensor` into consecutive views along `dim`, one per entry of
    `shard_sizes`.

    :param tensor: The input tensor to shard.
    :param shard_sizes: List of sizes for each shard along the specified dimension.
    :param dim: The dimension along which to shard the tensor.
    :returns: A list of tensors sharded along the specified dimension.
    :raises ValueError: If the sum of shard_sizes does not match the
        size of the tensor along the given dimension.
    """
    if tensor.size(dim) != sum(shard_sizes):
        raise ValueError(
            "Sum of shard_sizes must equal the size of the tensor "
            "along the specified dimension."
        )

    # precompute the starting offset of each shard
    offsets = [0]
    for size in shard_sizes:
        offsets.append(offsets[-1] + size)

    return [
        tensor.narrow(dim, offset, size)
        for offset, size in zip(offsets, shard_sizes)
    ]
| |
|
| |
|
def combine_shards(shards, dim=0):
    """
    Combine decompressed shards along a given dimension using `narrow`.

    :param shards: List of decompressed shard tensors.
    :param dim: Dimension to combine along (default: 0).
    :return: Combined decompressed tensor.
    :raises ValueError: if `shards` is empty or the shards' dtypes differ.
    """
    if not shards:
        raise ValueError("The list of shards is empty.")

    first = shards[0]
    if any(shard.dtype != first.dtype for shard in shards):
        raise ValueError("All shards must have the same dtype.")

    # output shape matches the first shard except along `dim`, which is the
    # total of all shard extents
    combined_shape = list(first.shape)
    combined_shape[dim] = sum(shard.shape[dim] for shard in shards)
    combined = torch.zeros(combined_shape, dtype=first.dtype, device=first.device)

    # copy each shard into its slice of the output
    offset = 0
    for shard in shards:
        extent = shard.shape[dim]
        combined.narrow(dim, offset, extent).copy_(shard)
        offset += extent

    return combined
| |
|
| |
|
def pack_bitmasks(bytemasks: torch.Tensor) -> torch.Tensor:
    """
    Converts a bytemask tensor to a bitmask tensor to reduce memory. Shape RxC will be
    compressed to R x ceil(C/8)

    :param bytemasks: mask tensor where each byte corresponds to a weight
    :return: mask tensor where each bit corresponds to a weight
    """
    # .cpu() before .numpy(): numpy conversion requires CPU memory; this also
    # makes the function consistent with `unpack_bitmasks`, which already
    # moves its input to CPU
    packed_bits_numpy = numpy.packbits(
        bytemasks.cpu().numpy(), axis=-1, bitorder="little"
    )
    packed_bits_torch = torch.from_numpy(packed_bits_numpy)

    return packed_bits_torch
| |
|
| |
|
def unpack_bitmasks(
    packed_bitmasks: torch.Tensor, original_shape: List[int]
) -> torch.Tensor:
    """
    Converts a bitmask tensor back to a bytemask tensor for use during decompression

    :param packed_bitmasks: mask tensor where each bit corresponds to a weight
    :param original_shape: dense shape to decompress to
    :return: boolean mask of weights in the original dense shape
    """
    # only the first `row_len` bits of each packed row are real data; the
    # rest is padding from packing
    row_len = original_shape[-1]
    unpacked_bits = numpy.unpackbits(
        packed_bitmasks.cpu().numpy(),
        axis=-1,
        count=row_len,
        bitorder="little",
    )
    return torch.from_numpy(unpacked_bits.reshape(original_shape).astype(bool))
| |
|
| |
|
@contextlib.contextmanager
def patch_attr(base: object, attr: str, value: Any):
    """
    Temporarily set `base.attr` to `value`; the original state (including
    the attribute's absence) is restored on exit.

    :param base: object which has the attribute to patch
    :param attr: name of the attribute to patch
    :param value: used to replace original value

    Usage:
    >>> from types import SimpleNamespace
    >>> obj = SimpleNamespace()
    >>> with patch_attr(obj, "attribute", "value"):
    ...     assert obj.attribute == "value"
    >>> assert not hasattr(obj, "attribute")
    """
    # sentinel distinguishes "attribute absent" from "attribute set to None"
    _missing = object()
    original = getattr(base, attr, _missing)

    setattr(base, attr, value)
    try:
        yield
    finally:
        if original is _missing:
            delattr(base, attr)
        else:
            setattr(base, attr, original)
| |
|
| |
|
@contextlib.contextmanager
def patch_attrs(bases: Iterable[Any], attr: str, values: Iterable[Any]):
    """
    Same as `patch_attr` but for a list of objects to patch
    Patch attribute for a list of objects with list of values.
    Original values are restored upon exit

    :param bases: objects which have the attribute to patch
    :param attr: name of the attribute to patch
    :param values: used to replace original values. Must be same
        length as bases

    Usage:
    >>> from types import SimpleNamespace
    >>> obj1 = SimpleNamespace()
    >>> obj2 = SimpleNamespace()
    >>> with patch_attrs([obj1, obj2], "attribute", ["value1", "value2"]):
    ...     assert obj1.attribute == "value1"
    ...     assert obj2.attribute == "value2"
    >>> assert not hasattr(obj1, "attribute")
    >>> assert not hasattr(obj2, "attribute")
    """
    # the docstring example previously called `patch_attr` with a list, which
    # would patch the list object itself rather than each element
    with contextlib.ExitStack() as stack:
        for base, value in zip(bases, values):
            stack.enter_context(patch_attr(base, attr, value))
        yield
| |
|
| |
|
class ParameterizedDefaultDict(dict):
    """
    A `dict` that, like `collections.defaultdict`, fills in missing entries —
    but the missing key itself is passed to `default_factory` (a tuple key is
    unpacked as positional arguments).

    :param default_factory: function which takes a key as input and returns the
        corresponding default value
    """

    def __init__(self, default_factory: Callable[[Any], Any]):
        self.default_factory = default_factory
        # extra keyword arguments for the factory; temporarily replaced by `get`
        self._factory_kwargs = MappingProxyType({})

    def __missing__(self, key: Any) -> Any:
        # a tuple key becomes positional args; any other key is passed as-is
        factory_args = key if isinstance(key, tuple) else (key,)
        value = self.default_factory(*factory_args, **self._factory_kwargs)
        self[key] = value
        return value

    def get(self, *args, factory_kwargs: Mapping = MappingProxyType({})) -> Any:
        """
        Like `__getitem__` with the args tuple as key, but forwards extra
        keyword arguments to `default_factory` if the entry is missing.

        :param args: positional values whose tuple is used as the key
        :param factory_kwargs: keyword arguments to pass to `default_factory`
        :return: dictionary entry for the given key
        """
        # swap `_factory_kwargs` in for the duration of the lookup, then
        # restore the previous mapping (same effect as `patch_attr`)
        previous_kwargs = self._factory_kwargs
        self._factory_kwargs = factory_kwargs
        try:
            return self[args]
        finally:
            self._factory_kwargs = previous_kwargs
| |
|
| |
|
def get_num_attn_heads(config: PretrainedConfig) -> int:
    """
    Get the number of attention heads used by a model

    :param config: model config
    :return: num_attention_heads of model
    :raises ValueError: if the config defines neither `num_attention_heads`
        nor both `hidden_size` and `head_dim`
    """
    # sentinel-based lookups replace the hasattr chain; behavior is identical
    _missing = object()

    num_heads = getattr(config, "num_attention_heads", _missing)
    if num_heads is not _missing:
        return num_heads

    hidden_size = getattr(config, "hidden_size", _missing)
    head_dim = getattr(config, "head_dim", _missing)
    if hidden_size is not _missing and head_dim is not _missing:
        return hidden_size // head_dim

    raise ValueError(
        "Cannot determine num_attention_heads from config. Config must define "
        "either `num_attention_heads` or both `hidden_size` and `head_dim`. "
        f"{config}"
    )
| |
|
| |
|
def get_num_kv_heads(config: PretrainedConfig) -> int:
    """
    Get the number of key-value attention heads used by a model

    :param config: model config
    :return: num_key_value_heads of model
    :raises ValueError: if the config does not define `num_key_value_heads`
    """
    # guard clause: fail fast when the field is absent
    if not hasattr(config, "num_key_value_heads"):
        raise ValueError(
            "Cannot determine num_key_value_heads from config. Config must define "
            f"`num_key_value_heads`. {config}"
        )

    return config.num_key_value_heads
| |
|
| |
|
def get_head_dim(config: PretrainedConfig) -> int:
    """
    Get the number of dimensions used by the attention heads of a model

    :param config: model config
    :return: head_dim of model
    :raises ValueError: if the config defines neither `head_dim` nor both
        `hidden_size` and `num_attention_heads`
    """
    # sentinel-based lookups replace the hasattr chain; behavior is identical
    _missing = object()

    head_dim = getattr(config, "head_dim", _missing)
    if head_dim is not _missing:
        return head_dim

    hidden_size = getattr(config, "hidden_size", _missing)
    num_heads = getattr(config, "num_attention_heads", _missing)
    if hidden_size is not _missing and num_heads is not _missing:
        return hidden_size // num_heads

    raise ValueError(
        "Cannot determine head_dim from config. Config must define "
        "either `head_dim` or both `hidden_size` and `num_attention_heads`. "
        f"{config}"
    )
| |
|