Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/compressed_tensors/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/__pycache__/base.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/__pycache__/version.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/compressors/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/compressors/__pycache__/base.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/compressors/__pycache__/helpers.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/compressors/model_compressors/__init__.py +17 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/compressors/model_compressors/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/compressors/model_compressors/__pycache__/model_compressor.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/compressors/model_compressors/model_compressor.py +466 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_quantized_compressors/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_quantized_compressors/__pycache__/marlin_24.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/config/__init__.py +19 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/config/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/config/__pycache__/base.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/config/__pycache__/dense.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/config/__pycache__/sparse_24_bitmask.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/config/__pycache__/sparse_bitmask.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/config/base.py +111 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/config/dense.py +36 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/config/sparse_24_bitmask.py +40 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/config/sparse_bitmask.py +36 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/linear/__init__.py +13 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/linear/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/linear/__pycache__/compressed_linear.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/linear/compressed_linear.py +89 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/registry/__init__.py +17 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/registry/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/registry/__pycache__/registry.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/registry/registry.py +360 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/utils/__init__.py +21 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/utils/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/utils/__pycache__/helpers.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/utils/__pycache__/offload.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/utils/__pycache__/permutations_24.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/utils/__pycache__/permute.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/utils/__pycache__/safetensors_load.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/utils/__pycache__/semi_structured_conversions.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/utils/helpers.py +326 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/utils/offload.py +404 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/utils/permutations_24.py +65 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/utils/permute.py +70 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/utils/safetensors_load.py +306 -0
- .venv/lib/python3.11/site-packages/compressed_tensors/utils/semi_structured_conversions.py +342 -0
- .venv/lib/python3.11/site-packages/dotenv/__init__.py +49 -0
- .venv/lib/python3.11/site-packages/dotenv/__main__.py +6 -0
- .venv/lib/python3.11/site-packages/dotenv/__pycache__/__main__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/dotenv/__pycache__/cli.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/dotenv/__pycache__/ipython.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/dotenv/__pycache__/main.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/compressed_tensors/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (447 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/compressed_tensors/__pycache__/base.cpython-311.pyc
ADDED
|
Binary file (487 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/compressed_tensors/__pycache__/version.cpython-311.pyc
ADDED
|
Binary file (1.26 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/compressed_tensors/compressors/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (435 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/compressed_tensors/compressors/__pycache__/base.cpython-311.pyc
ADDED
|
Binary file (7.69 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/compressed_tensors/compressors/__pycache__/helpers.cpython-311.pyc
ADDED
|
Binary file (6.34 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/compressed_tensors/compressors/model_compressors/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing,
|
| 10 |
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# flake8: noqa
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
from .model_compressor import *
|
.venv/lib/python3.11/site-packages/compressed_tensors/compressors/model_compressors/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (264 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/compressed_tensors/compressors/model_compressors/__pycache__/model_compressor.cpython-311.pyc
ADDED
|
Binary file (19.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/compressed_tensors/compressors/model_compressors/model_compressor.py
ADDED
|
@@ -0,0 +1,466 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing,
|
| 10 |
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
import logging
|
| 17 |
+
import operator
|
| 18 |
+
import os
|
| 19 |
+
import re
|
| 20 |
+
from contextlib import contextmanager
|
| 21 |
+
from copy import deepcopy
|
| 22 |
+
from typing import TYPE_CHECKING, Any, Dict, Optional, Set, TypeVar, Union
|
| 23 |
+
|
| 24 |
+
import compressed_tensors
|
| 25 |
+
import torch
|
| 26 |
+
import transformers
|
| 27 |
+
from compressed_tensors.base import (
|
| 28 |
+
COMPRESSION_VERSION_NAME,
|
| 29 |
+
QUANTIZATION_CONFIG_NAME,
|
| 30 |
+
QUANTIZATION_METHOD_NAME,
|
| 31 |
+
SPARSITY_CONFIG_NAME,
|
| 32 |
+
)
|
| 33 |
+
from compressed_tensors.compressors.base import BaseCompressor
|
| 34 |
+
from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
|
| 35 |
+
from compressed_tensors.quantization import (
|
| 36 |
+
DEFAULT_QUANTIZATION_METHOD,
|
| 37 |
+
QuantizationConfig,
|
| 38 |
+
QuantizationStatus,
|
| 39 |
+
apply_quantization_config,
|
| 40 |
+
load_pretrained_quantization,
|
| 41 |
+
)
|
| 42 |
+
from compressed_tensors.quantization.lifecycle import expand_sparse_target_names
|
| 43 |
+
from compressed_tensors.quantization.quant_args import QuantizationArgs
|
| 44 |
+
from compressed_tensors.quantization.utils import (
|
| 45 |
+
is_module_quantized,
|
| 46 |
+
iter_named_leaf_modules,
|
| 47 |
+
)
|
| 48 |
+
from compressed_tensors.utils import get_safetensors_folder, update_parameter_data
|
| 49 |
+
from compressed_tensors.utils.helpers import (
|
| 50 |
+
fix_fsdp_module_name,
|
| 51 |
+
is_compressed_tensors_config,
|
| 52 |
+
)
|
| 53 |
+
from torch import Tensor
|
| 54 |
+
from torch.nn import Module
|
| 55 |
+
from tqdm import tqdm
|
| 56 |
+
from transformers import AutoConfig
|
| 57 |
+
from transformers.file_utils import CONFIG_NAME
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
__all__ = ["ModelCompressor", "map_modules_to_quant_args"]
|
| 61 |
+
|
| 62 |
+
_LOGGER: logging.Logger = logging.getLogger(__name__)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
if TYPE_CHECKING:
|
| 66 |
+
# dummy type if not available from transformers
|
| 67 |
+
CompressedTensorsConfig = TypeVar("CompressedTensorsConfig")
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class ModelCompressor:
|
| 71 |
+
"""
|
| 72 |
+
Handles compression and decompression of a model with a sparsity config and/or
|
| 73 |
+
quantization config.
|
| 74 |
+
|
| 75 |
+
Compression LifeCycle
|
| 76 |
+
- compressor = ModelCompressor.from_pretrained_model(model)
|
| 77 |
+
- compressed_state_dict = compressor.compress(model, state_dict)
|
| 78 |
+
- compressor.quantization_compressor.compress(model, state_dict)
|
| 79 |
+
- compressor.sparsity_compressor.compress(model, state_dict)
|
| 80 |
+
- model.save_pretrained(output_dir, state_dict=compressed_state_dict)
|
| 81 |
+
- compressor.update_config(output_dir)
|
| 82 |
+
|
| 83 |
+
Decompression LifeCycle
|
| 84 |
+
- compressor = ModelCompressor.from_pretrained(comp_model_path)
|
| 85 |
+
- model = AutoModel.from_pretrained(comp_model_path)
|
| 86 |
+
- compressor.decompress(comp_model_path, model)
|
| 87 |
+
- compressor.sparsity_compressor.decompress(comp_model_path, model)
|
| 88 |
+
- compressor.quantization_compressor.decompress(comp_model_path, model)
|
| 89 |
+
|
| 90 |
+
:param sparsity_config: config specifying sparsity compression parameters
|
| 91 |
+
:param quantization_config: config specifying quantization compression parameters
|
| 92 |
+
"""
|
| 93 |
+
|
| 94 |
+
@classmethod
|
| 95 |
+
def from_pretrained(
|
| 96 |
+
cls,
|
| 97 |
+
pretrained_model_name_or_path: str,
|
| 98 |
+
**kwargs,
|
| 99 |
+
) -> Optional["ModelCompressor"]:
|
| 100 |
+
"""
|
| 101 |
+
Given a path to a model config, extract the sparsity and/or quantization
|
| 102 |
+
configs and load a ModelCompressor
|
| 103 |
+
|
| 104 |
+
:param pretrained_model_name_or_path: path to model config on disk or HF hub
|
| 105 |
+
:return: compressor for the configs, or None if model is not compressed
|
| 106 |
+
"""
|
| 107 |
+
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
| 108 |
+
compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
|
| 109 |
+
return cls.from_compression_config(compression_config)
|
| 110 |
+
|
| 111 |
+
@classmethod
|
| 112 |
+
def from_compression_config(
|
| 113 |
+
cls,
|
| 114 |
+
compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"],
|
| 115 |
+
):
|
| 116 |
+
"""
|
| 117 |
+
:param compression_config:
|
| 118 |
+
A compression or quantization config
|
| 119 |
+
|
| 120 |
+
The type is one of the following:
|
| 121 |
+
1. A Dict found under either "quantization_config" or "compression_config"
|
| 122 |
+
keys in the config.json
|
| 123 |
+
2. A CompressedTensorsConfig found under key "quantization_config" in HF
|
| 124 |
+
model config
|
| 125 |
+
:return: compressor for the configs, or None if model is not compressed
|
| 126 |
+
"""
|
| 127 |
+
if compression_config is None:
|
| 128 |
+
return None
|
| 129 |
+
|
| 130 |
+
sparsity_config = cls.parse_sparsity_config(compression_config)
|
| 131 |
+
quantization_config = cls.parse_quantization_config(compression_config)
|
| 132 |
+
if sparsity_config is None and quantization_config is None:
|
| 133 |
+
return None
|
| 134 |
+
|
| 135 |
+
if sparsity_config is not None:
|
| 136 |
+
format = sparsity_config.get("format")
|
| 137 |
+
sparsity_config = SparsityCompressionConfig.load_from_registry(
|
| 138 |
+
format, **sparsity_config
|
| 139 |
+
)
|
| 140 |
+
if quantization_config is not None:
|
| 141 |
+
quantization_config = QuantizationConfig.model_validate(quantization_config)
|
| 142 |
+
|
| 143 |
+
return cls(
|
| 144 |
+
sparsity_config=sparsity_config, quantization_config=quantization_config
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
@classmethod
|
| 148 |
+
def from_pretrained_model(
|
| 149 |
+
cls,
|
| 150 |
+
model: Module,
|
| 151 |
+
sparsity_config: Union[SparsityCompressionConfig, str, None] = None,
|
| 152 |
+
quantization_format: Optional[str] = None,
|
| 153 |
+
) -> Optional["ModelCompressor"]:
|
| 154 |
+
"""
|
| 155 |
+
Given a pytorch model and optional sparsity and/or quantization configs,
|
| 156 |
+
load the appropriate compressors
|
| 157 |
+
|
| 158 |
+
:param model: pytorch model to target for compression
|
| 159 |
+
:param sparsity_config: a filled in sparsity config or string corresponding
|
| 160 |
+
to a sparsity compression algorithm
|
| 161 |
+
:param quantization_format: string corresponding to a quantization compression
|
| 162 |
+
algorithm
|
| 163 |
+
:return: compressor for the configs, or None if model is not compressed
|
| 164 |
+
"""
|
| 165 |
+
quantization_config = QuantizationConfig.from_pretrained(
|
| 166 |
+
model, format=quantization_format
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
if isinstance(sparsity_config, str): # we passed in a sparsity format
|
| 170 |
+
sparsity_config = SparsityCompressionConfig.load_from_registry(
|
| 171 |
+
sparsity_config
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
if sparsity_config is None and quantization_config is None:
|
| 175 |
+
return None
|
| 176 |
+
|
| 177 |
+
return cls(
|
| 178 |
+
sparsity_config=sparsity_config, quantization_config=quantization_config
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
@staticmethod
|
| 182 |
+
def parse_sparsity_config(
|
| 183 |
+
compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
|
| 184 |
+
) -> Union[Dict[str, Any], None]:
|
| 185 |
+
"""
|
| 186 |
+
Parse sparsity config from quantization/compression config. Sparsity
|
| 187 |
+
config is nested inside q/c config
|
| 188 |
+
|
| 189 |
+
:param compression_config: quantization/compression config
|
| 190 |
+
:return: sparsity config
|
| 191 |
+
"""
|
| 192 |
+
if compression_config is None:
|
| 193 |
+
return None
|
| 194 |
+
|
| 195 |
+
if is_compressed_tensors_config(compression_config):
|
| 196 |
+
s_config = compression_config.sparsity_config
|
| 197 |
+
return s_config.model_dump() if s_config is not None else None
|
| 198 |
+
|
| 199 |
+
return compression_config.get(SPARSITY_CONFIG_NAME, None)
|
| 200 |
+
|
| 201 |
+
@staticmethod
|
| 202 |
+
def parse_quantization_config(
|
| 203 |
+
compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
|
| 204 |
+
) -> Union[Dict[str, Any], None]:
|
| 205 |
+
"""
|
| 206 |
+
Parse quantization config from quantization/compression config. The
|
| 207 |
+
quantization are all the fields that are not the sparsity config or
|
| 208 |
+
metadata fields
|
| 209 |
+
|
| 210 |
+
:param compression_config: quantization/compression config
|
| 211 |
+
:return: quantization config without sparsity config or metadata fields
|
| 212 |
+
"""
|
| 213 |
+
if compression_config is None:
|
| 214 |
+
return None
|
| 215 |
+
|
| 216 |
+
if is_compressed_tensors_config(compression_config):
|
| 217 |
+
q_config = compression_config.quantization_config
|
| 218 |
+
return q_config.model_dump() if q_config is not None else None
|
| 219 |
+
|
| 220 |
+
quantization_config = deepcopy(compression_config)
|
| 221 |
+
quantization_config.pop(SPARSITY_CONFIG_NAME, None)
|
| 222 |
+
|
| 223 |
+
# some fields are required, even if a qconfig is not present
|
| 224 |
+
# pop them off and if nothing remains, then there is no qconfig
|
| 225 |
+
quant_method = quantization_config.pop(QUANTIZATION_METHOD_NAME, None)
|
| 226 |
+
_ = quantization_config.pop(COMPRESSION_VERSION_NAME, None)
|
| 227 |
+
|
| 228 |
+
if len(quantization_config) == 0:
|
| 229 |
+
return None
|
| 230 |
+
|
| 231 |
+
# replace popped off values
|
| 232 |
+
# note that version is discarded for now
|
| 233 |
+
if quant_method is not None:
|
| 234 |
+
quantization_config[QUANTIZATION_METHOD_NAME] = quant_method
|
| 235 |
+
|
| 236 |
+
return quantization_config
|
| 237 |
+
|
| 238 |
+
def __init__(
|
| 239 |
+
self,
|
| 240 |
+
sparsity_config: Optional[SparsityCompressionConfig] = None,
|
| 241 |
+
quantization_config: Optional[QuantizationConfig] = None,
|
| 242 |
+
):
|
| 243 |
+
self.sparsity_config = sparsity_config
|
| 244 |
+
self.quantization_config = quantization_config
|
| 245 |
+
self.sparsity_compressor = None
|
| 246 |
+
self.quantization_compressor = None
|
| 247 |
+
|
| 248 |
+
if sparsity_config is not None:
|
| 249 |
+
self.sparsity_compressor = BaseCompressor.load_from_registry(
|
| 250 |
+
sparsity_config.format, config=sparsity_config
|
| 251 |
+
)
|
| 252 |
+
if quantization_config is not None:
|
| 253 |
+
self.quantization_compressor = BaseCompressor.load_from_registry(
|
| 254 |
+
quantization_config.format, config=quantization_config
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
def compress(
|
| 258 |
+
self, model: Module, state_dict: Optional[Dict[str, Tensor]] = None
|
| 259 |
+
) -> Dict[str, Tensor]:
|
| 260 |
+
"""
|
| 261 |
+
Compresses a dense state dict or model with sparsity and/or quantization
|
| 262 |
+
|
| 263 |
+
:param model: uncompressed model to compress
|
| 264 |
+
:param state_dict: optional uncompressed state_dict to insert into model
|
| 265 |
+
:return: compressed state dict
|
| 266 |
+
"""
|
| 267 |
+
if state_dict is None:
|
| 268 |
+
state_dict = model.state_dict()
|
| 269 |
+
|
| 270 |
+
compressed_state_dict = state_dict
|
| 271 |
+
|
| 272 |
+
quantized_modules_to_args: Dict[
|
| 273 |
+
str, QuantizationArgs
|
| 274 |
+
] = map_modules_to_quant_args(model)
|
| 275 |
+
|
| 276 |
+
if self.quantization_compressor is not None:
|
| 277 |
+
compressed_state_dict = self.quantization_compressor.compress(
|
| 278 |
+
state_dict, names_to_scheme=quantized_modules_to_args
|
| 279 |
+
)
|
| 280 |
+
if self.quantization_config.format != CompressionFormat.dense.value:
|
| 281 |
+
self.quantization_config.quantization_status = (
|
| 282 |
+
QuantizationStatus.COMPRESSED
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
if self.sparsity_compressor is not None:
|
| 286 |
+
sparse_compression_targets: Set[str] = expand_sparse_target_names(
|
| 287 |
+
model=model,
|
| 288 |
+
targets=self.sparsity_config.targets,
|
| 289 |
+
ignore=self.sparsity_config.ignore,
|
| 290 |
+
)
|
| 291 |
+
compressed_state_dict = self.sparsity_compressor.compress(
|
| 292 |
+
compressed_state_dict,
|
| 293 |
+
compression_targets=sparse_compression_targets,
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
# HACK: Override the dtype_byte_size function in transformers to
|
| 297 |
+
# support float8 types. Fix is posted upstream
|
| 298 |
+
# https://github.com/huggingface/transformers/pull/30488
|
| 299 |
+
transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size
|
| 300 |
+
|
| 301 |
+
return compressed_state_dict
|
| 302 |
+
|
| 303 |
+
def decompress(self, model_path: str, model: Module):
|
| 304 |
+
"""
|
| 305 |
+
Overwrites the weights in model with weights decompressed from model_path
|
| 306 |
+
|
| 307 |
+
:param model_path: path to compressed weights
|
| 308 |
+
:param model: pytorch model to load decompressed weights into
|
| 309 |
+
"""
|
| 310 |
+
model_path = get_safetensors_folder(model_path)
|
| 311 |
+
sparse_decompressed = False
|
| 312 |
+
|
| 313 |
+
if (
|
| 314 |
+
self.sparsity_compressor is not None
|
| 315 |
+
and self.sparsity_config.format != CompressionFormat.dense.value
|
| 316 |
+
):
|
| 317 |
+
# Sparse decompression is applied on the model_path
|
| 318 |
+
dense_gen = self.sparsity_compressor.decompress(model_path)
|
| 319 |
+
self._replace_weights(dense_gen, model)
|
| 320 |
+
setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
|
| 321 |
+
sparse_decompressed = True
|
| 322 |
+
|
| 323 |
+
if self.quantization_compressor is not None:
|
| 324 |
+
# Temporarily set quantization status to FROZEN to prevent
|
| 325 |
+
# quantization during apply_quantization_config. This ensures
|
| 326 |
+
# that the dtypes of the weights are not unintentionally updated.
|
| 327 |
+
# The status is restored after quantization params are loaded.
|
| 328 |
+
with override_quantization_status(
|
| 329 |
+
self.quantization_config, QuantizationStatus.FROZEN
|
| 330 |
+
):
|
| 331 |
+
names_to_scheme = apply_quantization_config(
|
| 332 |
+
model, self.quantization_config
|
| 333 |
+
)
|
| 334 |
+
load_pretrained_quantization(model, model_path)
|
| 335 |
+
|
| 336 |
+
model_path_or_state_dict = (
|
| 337 |
+
model.state_dict() if sparse_decompressed else model_path
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
dense_gen = self.quantization_compressor.decompress(
|
| 341 |
+
model_path_or_state_dict, names_to_scheme=names_to_scheme
|
| 342 |
+
)
|
| 343 |
+
self._replace_weights(dense_gen, model)
|
| 344 |
+
|
| 345 |
+
def freeze_quantization_status(module):
|
| 346 |
+
module.quantization_status = QuantizationStatus.FROZEN
|
| 347 |
+
|
| 348 |
+
model.apply(freeze_quantization_status)
|
| 349 |
+
setattr(model, QUANTIZATION_CONFIG_NAME, self.quantization_config)
|
| 350 |
+
|
| 351 |
+
def update_config(self, save_directory: str):
|
| 352 |
+
"""
|
| 353 |
+
Update the model config located at save_directory with compression configs
|
| 354 |
+
for sparsity and/or quantization
|
| 355 |
+
|
| 356 |
+
:param save_directory: path to a folder containing a HF model config
|
| 357 |
+
"""
|
| 358 |
+
if self.quantization_config is None and self.sparsity_config is None:
|
| 359 |
+
return
|
| 360 |
+
|
| 361 |
+
config_file_path = os.path.join(save_directory, CONFIG_NAME)
|
| 362 |
+
if not os.path.exists(config_file_path):
|
| 363 |
+
_LOGGER.warning(
|
| 364 |
+
f"Could not find a valid model config file in "
|
| 365 |
+
f"{save_directory}. Compression config will not be saved."
|
| 366 |
+
)
|
| 367 |
+
return
|
| 368 |
+
|
| 369 |
+
with open(config_file_path, "r") as config_file:
|
| 370 |
+
config_data = json.load(config_file)
|
| 371 |
+
|
| 372 |
+
# required metadata whenever a quantization or sparsity config is present
|
| 373 |
+
# overwrite previous config and version if already existing
|
| 374 |
+
config_data[QUANTIZATION_CONFIG_NAME] = {}
|
| 375 |
+
config_data[QUANTIZATION_CONFIG_NAME][
|
| 376 |
+
COMPRESSION_VERSION_NAME
|
| 377 |
+
] = compressed_tensors.__version__
|
| 378 |
+
if self.quantization_config is not None:
|
| 379 |
+
self.quantization_config.quant_method = DEFAULT_QUANTIZATION_METHOD
|
| 380 |
+
else:
|
| 381 |
+
config_data[QUANTIZATION_CONFIG_NAME][
|
| 382 |
+
QUANTIZATION_METHOD_NAME
|
| 383 |
+
] = DEFAULT_QUANTIZATION_METHOD
|
| 384 |
+
|
| 385 |
+
# quantization and sparsity configs
|
| 386 |
+
if self.quantization_config is not None:
|
| 387 |
+
quant_config_data = self.quantization_config.model_dump()
|
| 388 |
+
config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data
|
| 389 |
+
if self.sparsity_config is not None:
|
| 390 |
+
sparsity_config_data = self.sparsity_config.model_dump()
|
| 391 |
+
config_data[QUANTIZATION_CONFIG_NAME][
|
| 392 |
+
SPARSITY_CONFIG_NAME
|
| 393 |
+
] = sparsity_config_data
|
| 394 |
+
|
| 395 |
+
with open(config_file_path, "w") as config_file:
|
| 396 |
+
json.dump(config_data, config_file, indent=2, sort_keys=True)
|
| 397 |
+
|
| 398 |
+
def _replace_weights(self, dense_weight_generator, model: Module):
|
| 399 |
+
"""
|
| 400 |
+
Replace the weights of the model with the
|
| 401 |
+
provided dense weights.
|
| 402 |
+
|
| 403 |
+
This method iterates over the dense_weight_generator and
|
| 404 |
+
updates the corresponding weights in the model. If a parameter
|
| 405 |
+
name does not exist in the model, it will be skipped.
|
| 406 |
+
|
| 407 |
+
:param dense_weight_generator (generator): A generator that yields
|
| 408 |
+
tuples of (name, data), where 'name' is the parameter name and
|
| 409 |
+
'data' is the updated param data
|
| 410 |
+
:param model: The model whose weights are to be updated.
|
| 411 |
+
"""
|
| 412 |
+
for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
|
| 413 |
+
split_name = name.split(".")
|
| 414 |
+
prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
|
| 415 |
+
module = operator.attrgetter(prefix)(model)
|
| 416 |
+
if hasattr(module, param_name):
|
| 417 |
+
update_parameter_data(module, data, param_name)
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def map_modules_to_quant_args(model: Module) -> Dict[str, QuantizationArgs]:
|
| 421 |
+
"""
|
| 422 |
+
Given a pytorch model, map out the submodule name (usually linear layers)
|
| 423 |
+
to the QuantizationArgs
|
| 424 |
+
|
| 425 |
+
:param model: pytorch model
|
| 426 |
+
"""
|
| 427 |
+
quantized_modules_to_args = {}
|
| 428 |
+
for name, submodule in iter_named_leaf_modules(model):
|
| 429 |
+
if is_module_quantized(submodule):
|
| 430 |
+
if submodule.quantization_scheme.weights is not None:
|
| 431 |
+
name = fix_fsdp_module_name(name)
|
| 432 |
+
quantized_modules_to_args[name] = submodule.quantization_scheme.weights
|
| 433 |
+
|
| 434 |
+
return quantized_modules_to_args
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
# HACK: Override the dtype_byte_size function in transformers to support float8 types
|
| 438 |
+
# Fix is posted upstream https://github.com/huggingface/transformers/pull/30488
|
| 439 |
+
def new_dtype_byte_size(dtype):
|
| 440 |
+
if dtype == torch.bool:
|
| 441 |
+
return 1 / 8
|
| 442 |
+
bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
|
| 443 |
+
if bit_search is None:
|
| 444 |
+
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
|
| 445 |
+
bit_size = int(bit_search.groups()[0])
|
| 446 |
+
return bit_size // 8
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
@contextmanager
|
| 450 |
+
def override_quantization_status(
|
| 451 |
+
config: QuantizationConfig, status: QuantizationStatus
|
| 452 |
+
):
|
| 453 |
+
"""
|
| 454 |
+
Within this context, the quantization status will be set to the
|
| 455 |
+
supplied status. After the context exits, the original status
|
| 456 |
+
will be restored.
|
| 457 |
+
|
| 458 |
+
:param config: the quantization config to override
|
| 459 |
+
:param status: the status to temporarily set
|
| 460 |
+
"""
|
| 461 |
+
original_status = config.quantization_status
|
| 462 |
+
config.quantization_status = status
|
| 463 |
+
try:
|
| 464 |
+
yield
|
| 465 |
+
finally:
|
| 466 |
+
config.quantization_status = original_status
|
.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_quantized_compressors/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (300 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_quantized_compressors/__pycache__/marlin_24.cpython-311.pyc
ADDED
|
Binary file (11.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/compressed_tensors/config/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing,
|
| 10 |
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# flake8: noqa
|
| 16 |
+
from .base import *
|
| 17 |
+
from .dense import *
|
| 18 |
+
from .sparse_24_bitmask import *
|
| 19 |
+
from .sparse_bitmask import *
|
.venv/lib/python3.11/site-packages/compressed_tensors/config/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (331 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/compressed_tensors/config/__pycache__/base.cpython-311.pyc
ADDED
|
Binary file (4.34 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/compressed_tensors/config/__pycache__/dense.cpython-311.pyc
ADDED
|
Binary file (1.36 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/compressed_tensors/config/__pycache__/sparse_24_bitmask.cpython-311.pyc
ADDED
|
Binary file (1.48 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/compressed_tensors/config/__pycache__/sparse_bitmask.cpython-311.pyc
ADDED
|
Binary file (1.35 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/compressed_tensors/config/base.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing,
|
| 10 |
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from enum import Enum, unique
|
| 16 |
+
from typing import List, Optional
|
| 17 |
+
|
| 18 |
+
from compressed_tensors.registry import RegistryMixin
|
| 19 |
+
from pydantic import BaseModel
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
__all__ = ["SparsityCompressionConfig", "CompressionFormat", "SparsityStructure"]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@unique
|
| 26 |
+
class CompressionFormat(Enum):
|
| 27 |
+
dense = "dense"
|
| 28 |
+
sparse_bitmask = "sparse-bitmask"
|
| 29 |
+
sparse_24_bitmask = "sparse-24-bitmask"
|
| 30 |
+
int_quantized = "int-quantized"
|
| 31 |
+
float_quantized = "float-quantized"
|
| 32 |
+
naive_quantized = "naive-quantized"
|
| 33 |
+
pack_quantized = "pack-quantized"
|
| 34 |
+
marlin_24 = "marlin-24"
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@unique
|
| 38 |
+
class SparsityStructure(Enum):
|
| 39 |
+
"""
|
| 40 |
+
An enumeration to represent different sparsity structures.
|
| 41 |
+
|
| 42 |
+
Attributes
|
| 43 |
+
----------
|
| 44 |
+
TWO_FOUR : str
|
| 45 |
+
Represents a 2:4 sparsity structure.
|
| 46 |
+
ZERO_ZERO : str
|
| 47 |
+
Represents a 0:0 sparsity structure.
|
| 48 |
+
UNSTRUCTURED : str
|
| 49 |
+
Represents an unstructured sparsity structure.
|
| 50 |
+
|
| 51 |
+
Examples
|
| 52 |
+
--------
|
| 53 |
+
>>> SparsityStructure('2:4')
|
| 54 |
+
<SparsityStructure.TWO_FOUR: '2:4'>
|
| 55 |
+
|
| 56 |
+
>>> SparsityStructure('unstructured')
|
| 57 |
+
<SparsityStructure.UNSTRUCTURED: 'unstructured'>
|
| 58 |
+
|
| 59 |
+
>>> SparsityStructure('2:4') == SparsityStructure.TWO_FOUR
|
| 60 |
+
True
|
| 61 |
+
|
| 62 |
+
>>> SparsityStructure('UNSTRUCTURED') == SparsityStructure.UNSTRUCTURED
|
| 63 |
+
True
|
| 64 |
+
|
| 65 |
+
>>> SparsityStructure(None) == SparsityStructure.UNSTRUCTURED
|
| 66 |
+
True
|
| 67 |
+
|
| 68 |
+
>>> SparsityStructure('invalid')
|
| 69 |
+
Traceback (most recent call last):
|
| 70 |
+
...
|
| 71 |
+
ValueError: invalid is not a valid SparsityStructure
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
TWO_FOUR = "2:4"
|
| 75 |
+
UNSTRUCTURED = "unstructured"
|
| 76 |
+
ZERO_ZERO = "0:0"
|
| 77 |
+
|
| 78 |
+
def __new__(cls, value):
|
| 79 |
+
obj = object.__new__(cls)
|
| 80 |
+
obj._value_ = value.lower() if value is not None else value
|
| 81 |
+
return obj
|
| 82 |
+
|
| 83 |
+
@classmethod
|
| 84 |
+
def _missing_(cls, value):
|
| 85 |
+
# Handle None and case-insensitive values
|
| 86 |
+
if value is None:
|
| 87 |
+
return cls.UNSTRUCTURED
|
| 88 |
+
for member in cls:
|
| 89 |
+
if member.value == value.lower():
|
| 90 |
+
return member
|
| 91 |
+
raise ValueError(f"{value} is not a valid {cls.__name__}")
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class SparsityCompressionConfig(RegistryMixin, BaseModel):
|
| 95 |
+
"""
|
| 96 |
+
Base data class for storing sparsity compression parameters
|
| 97 |
+
|
| 98 |
+
:param format: name of compression format
|
| 99 |
+
:param targets: List of layer names or layer types that aren't sparse and should
|
| 100 |
+
be ignored during compression. By default, assume all layers are targeted
|
| 101 |
+
:param ignore: List of layer names (unique) to ignore from targets. Defaults to None
|
| 102 |
+
:param global_sparsity: average sparsity of the entire model
|
| 103 |
+
:param sparsity_structure: structure of the sparsity, such as
|
| 104 |
+
"unstructured", "2:4", "8:16" etc
|
| 105 |
+
"""
|
| 106 |
+
|
| 107 |
+
format: str
|
| 108 |
+
targets: Optional[List[str]] = None
|
| 109 |
+
ignore: Optional[List[str]] = None
|
| 110 |
+
global_sparsity: Optional[float] = 0.0
|
| 111 |
+
sparsity_structure: Optional[str] = "unstructured"
|
.venv/lib/python3.11/site-packages/compressed_tensors/config/dense.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing,
|
| 10 |
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import Optional
|
| 16 |
+
|
| 17 |
+
from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
__all__ = ["DenseSparsityConfig"]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@SparsityCompressionConfig.register(name=CompressionFormat.dense.value)
|
| 24 |
+
class DenseSparsityConfig(SparsityCompressionConfig):
|
| 25 |
+
"""
|
| 26 |
+
Identity configuration for storing a sparse model in
|
| 27 |
+
an uncompressed dense format
|
| 28 |
+
|
| 29 |
+
:param global_sparsity: average sparsity of the entire model
|
| 30 |
+
:param sparsity_structure: structure of the sparsity, such as
|
| 31 |
+
"unstructured", "2:4", "8:16" etc
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
format: str = CompressionFormat.dense.value
|
| 35 |
+
global_sparsity: Optional[float] = 0.0
|
| 36 |
+
sparsity_structure: Optional[str] = "unstructured"
|
.venv/lib/python3.11/site-packages/compressed_tensors/config/sparse_24_bitmask.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing,
|
| 10 |
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import Optional
|
| 16 |
+
|
| 17 |
+
from compressed_tensors.config import (
|
| 18 |
+
CompressionFormat,
|
| 19 |
+
SparsityCompressionConfig,
|
| 20 |
+
SparsityStructure,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
__all__ = ["Sparse24BitMaskConfig"]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@SparsityCompressionConfig.register(name=CompressionFormat.sparse_24_bitmask.value)
|
| 28 |
+
class Sparse24BitMaskConfig(SparsityCompressionConfig):
|
| 29 |
+
"""
|
| 30 |
+
Configuration for storing a 24 sparse model using
|
| 31 |
+
bytemask compression
|
| 32 |
+
|
| 33 |
+
:param global_sparsity: average sparsity of the entire model
|
| 34 |
+
:param sparsity_structure: structure of the sparsity, should always be
|
| 35 |
+
"2:4" for this compression format
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
format: str = CompressionFormat.sparse_24_bitmask.value
|
| 39 |
+
global_sparsity: Optional[float] = 0.0
|
| 40 |
+
sparsity_structure: Optional[str] = SparsityStructure.TWO_FOUR.value
|
.venv/lib/python3.11/site-packages/compressed_tensors/config/sparse_bitmask.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing,
|
| 10 |
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import Optional
|
| 16 |
+
|
| 17 |
+
from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
__all__ = ["BitmaskConfig"]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@SparsityCompressionConfig.register(name=CompressionFormat.sparse_bitmask.value)
|
| 24 |
+
class BitmaskConfig(SparsityCompressionConfig):
|
| 25 |
+
"""
|
| 26 |
+
Configuration for storing a sparse model using
|
| 27 |
+
bitmask compression
|
| 28 |
+
|
| 29 |
+
:param global_sparsity: average sparsity of the entire model
|
| 30 |
+
:param sparsity_structure: structure of the sparsity, such as
|
| 31 |
+
"unstructured", "2:4", "8:16" etc
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
format: str = CompressionFormat.sparse_bitmask.value
|
| 35 |
+
global_sparsity: Optional[float] = 0.0
|
| 36 |
+
sparsity_structure: Optional[str] = "unstructured"
|
.venv/lib/python3.11/site-packages/compressed_tensors/linear/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing,
|
| 10 |
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
.venv/lib/python3.11/site-packages/compressed_tensors/linear/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (198 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/compressed_tensors/linear/__pycache__/compressed_linear.cpython-311.pyc
ADDED
|
Binary file (3.63 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/compressed_tensors/linear/compressed_linear.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing,
|
| 10 |
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import Dict, Tuple
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
from compressed_tensors.compressors.base import BaseCompressor
|
| 19 |
+
from compressed_tensors.quantization import (
|
| 20 |
+
QuantizationScheme,
|
| 21 |
+
QuantizationStatus,
|
| 22 |
+
initialize_module_for_quantization,
|
| 23 |
+
)
|
| 24 |
+
from torch import Tensor
|
| 25 |
+
from torch.nn import Parameter
|
| 26 |
+
from torch.nn.functional import linear
|
| 27 |
+
from torch.nn.modules import Linear
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class CompressedLinear(Linear):
|
| 31 |
+
"""
|
| 32 |
+
Wrapper module for running a compressed forward pass of a quantized Linear module.
|
| 33 |
+
The wrapped layer will decompressed on each forward call.
|
| 34 |
+
|
| 35 |
+
:param module: dense linear module to replace
|
| 36 |
+
:param quantization_scheme: quantization config for the module to wrap
|
| 37 |
+
:param quantization_format: compression format module is stored as
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
@classmethod
|
| 41 |
+
@torch.no_grad()
|
| 42 |
+
def from_linear(
|
| 43 |
+
cls,
|
| 44 |
+
module: Linear,
|
| 45 |
+
quantization_scheme: QuantizationScheme,
|
| 46 |
+
quantization_format: str,
|
| 47 |
+
):
|
| 48 |
+
module.__class__ = CompressedLinear
|
| 49 |
+
module.compressor = BaseCompressor.load_from_registry(quantization_format)
|
| 50 |
+
device = next(module.parameters()).device
|
| 51 |
+
|
| 52 |
+
# this will initialize all the scales and zero points
|
| 53 |
+
initialize_module_for_quantization(
|
| 54 |
+
module, quantization_scheme, force_zero_point=False
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
# get the shape and dtype of compressed parameters
|
| 58 |
+
compression_params: Dict[str, Tuple] = module.compressor.compression_param_info(
|
| 59 |
+
module.weight.shape, quantization_scheme.weights
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
# no need for this once quantization is initialized, will be replaced
|
| 63 |
+
# with the compressed parameter
|
| 64 |
+
delattr(module, "weight")
|
| 65 |
+
|
| 66 |
+
# populate compressed weights and quantization parameters
|
| 67 |
+
for name, (shape, dtype) in compression_params.items():
|
| 68 |
+
param = Parameter(
|
| 69 |
+
torch.empty(shape, device=device, dtype=dtype), requires_grad=False
|
| 70 |
+
)
|
| 71 |
+
module.register_parameter(name, param)
|
| 72 |
+
|
| 73 |
+
# mark module as compressed
|
| 74 |
+
module.quantization_status = QuantizationStatus.COMPRESSED
|
| 75 |
+
|
| 76 |
+
# handles case where forward is wrapped in new_forward by accelerate hooks
|
| 77 |
+
if hasattr(module, "_old_forward"):
|
| 78 |
+
module._old_forward = CompressedLinear.forward.__get__(
|
| 79 |
+
module, CompressedLinear
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
return module
|
| 83 |
+
|
| 84 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 85 |
+
"""
|
| 86 |
+
Decompresses the weight, then runs the wrapped forward pass
|
| 87 |
+
"""
|
| 88 |
+
uncompressed_weight = self.compressor.decompress_module(self)
|
| 89 |
+
return linear(input, uncompressed_weight, self.bias)
|
.venv/lib/python3.11/site-packages/compressed_tensors/registry/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa
|
| 2 |
+
|
| 3 |
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing,
|
| 12 |
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
from .registry import *
|
.venv/lib/python3.11/site-packages/compressed_tensors/registry/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (235 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/compressed_tensors/registry/__pycache__/registry.cpython-311.pyc
ADDED
|
Binary file (14.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/compressed_tensors/registry/registry.py
ADDED
|
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing,
|
| 10 |
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Universal registry to support registration and loading of child classes and plugins
|
| 17 |
+
of neuralmagic utilities
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import importlib
|
| 21 |
+
from collections import defaultdict
|
| 22 |
+
from typing import Any, Dict, List, Optional, Type, Union
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
__all__ = [
|
| 26 |
+
"RegistryMixin",
|
| 27 |
+
"register",
|
| 28 |
+
"get_from_registry",
|
| 29 |
+
"registered_names",
|
| 30 |
+
"registered_aliases",
|
| 31 |
+
"standardize_lookup_name",
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
_ALIAS_REGISTRY: Dict[Type, Dict[str, str]] = defaultdict(dict)
|
| 36 |
+
_REGISTRY: Dict[Type, Dict[str, Any]] = defaultdict(dict)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def standardize_lookup_name(name: str) -> str:
|
| 40 |
+
"""
|
| 41 |
+
Standardize the given name for lookup in the registry.
|
| 42 |
+
This will replace all underscores and spaces with hyphens and
|
| 43 |
+
convert the name to lowercase.
|
| 44 |
+
|
| 45 |
+
example:
|
| 46 |
+
```
|
| 47 |
+
standardize_lookup_name("Foo_bar baz") == "foo-bar-baz"
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
:param name: name to standardize
|
| 51 |
+
:return: standardized name
|
| 52 |
+
"""
|
| 53 |
+
return name.replace("_", "-").replace(" ", "-").lower()
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def standardize_alias_name(
|
| 57 |
+
name: Union[None, str, List[str]]
|
| 58 |
+
) -> Union[None, str, List[str]]:
|
| 59 |
+
if name is None:
|
| 60 |
+
return None
|
| 61 |
+
elif isinstance(name, str):
|
| 62 |
+
return standardize_lookup_name(name)
|
| 63 |
+
else: # isinstance(name, list)
|
| 64 |
+
return [standardize_lookup_name(n) for n in name]
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class RegistryMixin:
|
| 68 |
+
"""
|
| 69 |
+
Universal registry to support registration and loading of child classes and plugins
|
| 70 |
+
of neuralmagic utilities.
|
| 71 |
+
|
| 72 |
+
Classes that require a registry or plugins may add the `RegistryMixin` and use
|
| 73 |
+
`register` and `load` as the main entrypoints for adding new implementations and
|
| 74 |
+
loading requested values from its registry.
|
| 75 |
+
|
| 76 |
+
If a class should only have its child classes in its registry, the class should
|
| 77 |
+
set the static attribute `registry_requires_subclass` to True
|
| 78 |
+
|
| 79 |
+
example
|
| 80 |
+
```python
|
| 81 |
+
class Dataset(RegistryMixin):
|
| 82 |
+
pass
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# register with default name
|
| 86 |
+
@Dataset.register()
|
| 87 |
+
class ImageNetDataset(Dataset):
|
| 88 |
+
pass
|
| 89 |
+
|
| 90 |
+
# load as "ImageNetDataset"
|
| 91 |
+
imagenet = Dataset.load("ImageNetDataset")
|
| 92 |
+
|
| 93 |
+
# register with custom name
|
| 94 |
+
@Dataset.register(name="cifar-dataset")
|
| 95 |
+
class Cifar(Dataset):
|
| 96 |
+
pass
|
| 97 |
+
|
| 98 |
+
Note: the name will be standardized for lookup in the registry.
|
| 99 |
+
For example, if a class is registered as "cifar_dataset" or
|
| 100 |
+
"cifar dataset", it will be stored as "cifar-dataset". The user
|
| 101 |
+
will be able to load the class with any of the three name variants.
|
| 102 |
+
|
| 103 |
+
# register with multiple aliases
|
| 104 |
+
@Dataset.register(alias=["cifar-10-dataset", "cifar_100_dataset"])
|
| 105 |
+
class Cifar(Dataset):
|
| 106 |
+
pass
|
| 107 |
+
|
| 108 |
+
# load as "cifar-dataset"
|
| 109 |
+
cifar = Dataset.load_from_registry("cifar-dataset")
|
| 110 |
+
|
| 111 |
+
# load from custom file that implements a dataset
|
| 112 |
+
mnist = Dataset.load_from_registry("/path/to/mnnist_dataset.py:MnistDataset")
|
| 113 |
+
```
|
| 114 |
+
"""
|
| 115 |
+
|
| 116 |
+
# set to True in child class to add check that registered/retrieved values
|
| 117 |
+
# implement the class it is registered to
|
| 118 |
+
registry_requires_subclass: bool = False
|
| 119 |
+
|
| 120 |
+
@classmethod
|
| 121 |
+
def register(
|
| 122 |
+
cls, name: Optional[str] = None, alias: Union[List[str], str, None] = None
|
| 123 |
+
):
|
| 124 |
+
"""
|
| 125 |
+
Decorator for registering a value (ie class or function) wrapped by this
|
| 126 |
+
decorator to the base class (class that .register is called from)
|
| 127 |
+
|
| 128 |
+
:param name: name or list of names to register the wrapped value as,
|
| 129 |
+
defaults to value.__name__
|
| 130 |
+
:param alias: alias or list of aliases to register the wrapped value as,
|
| 131 |
+
defaults to None
|
| 132 |
+
:return: register decorator
|
| 133 |
+
"""
|
| 134 |
+
|
| 135 |
+
def decorator(value: Any):
|
| 136 |
+
cls.register_value(value, name=name, alias=alias)
|
| 137 |
+
return value
|
| 138 |
+
|
| 139 |
+
return decorator
|
| 140 |
+
|
| 141 |
+
@classmethod
|
| 142 |
+
def register_value(
|
| 143 |
+
cls, value: Any, name: str, alias: Union[str, List[str], None] = None
|
| 144 |
+
):
|
| 145 |
+
"""
|
| 146 |
+
Registers the given value to the class `.register_value` is called from
|
| 147 |
+
:param value: value to register
|
| 148 |
+
:param name: name to register the wrapped value as,
|
| 149 |
+
defaults to value.__name__
|
| 150 |
+
:param alias: alias or list of aliases to register the wrapped value as,
|
| 151 |
+
defaults to None
|
| 152 |
+
"""
|
| 153 |
+
register(
|
| 154 |
+
parent_class=cls,
|
| 155 |
+
value=value,
|
| 156 |
+
name=name,
|
| 157 |
+
alias=alias,
|
| 158 |
+
require_subclass=cls.registry_requires_subclass,
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
@classmethod
|
| 162 |
+
def load_from_registry(cls, name: str, **constructor_kwargs) -> object:
|
| 163 |
+
"""
|
| 164 |
+
:param name: name of registered class to load
|
| 165 |
+
:param constructor_kwargs: arguments to pass to the constructor retrieved
|
| 166 |
+
from the registry
|
| 167 |
+
:return: loaded object registered to this class under the given name,
|
| 168 |
+
constructed with the given kwargs. Raises error if the name is
|
| 169 |
+
not found in the registry
|
| 170 |
+
"""
|
| 171 |
+
constructor = cls.get_value_from_registry(name=name)
|
| 172 |
+
return constructor(**constructor_kwargs)
|
| 173 |
+
|
| 174 |
+
@classmethod
|
| 175 |
+
def get_value_from_registry(cls, name: str):
|
| 176 |
+
"""
|
| 177 |
+
:param name: name to retrieve from the registry
|
| 178 |
+
:return: value from retrieved the registry for the given name, raises
|
| 179 |
+
error if not found
|
| 180 |
+
"""
|
| 181 |
+
return get_from_registry(
|
| 182 |
+
parent_class=cls,
|
| 183 |
+
name=name,
|
| 184 |
+
require_subclass=cls.registry_requires_subclass,
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
@classmethod
|
| 188 |
+
def registered_names(cls) -> List[str]:
|
| 189 |
+
"""
|
| 190 |
+
:return: list of all names registered to this class
|
| 191 |
+
"""
|
| 192 |
+
return registered_names(cls)
|
| 193 |
+
|
| 194 |
+
@classmethod
|
| 195 |
+
def registered_aliases(cls) -> List[str]:
|
| 196 |
+
"""
|
| 197 |
+
:return: list of all aliases registered to this class
|
| 198 |
+
"""
|
| 199 |
+
return registered_aliases(cls)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def register(
|
| 203 |
+
parent_class: Type,
|
| 204 |
+
value: Any,
|
| 205 |
+
name: Optional[str] = None,
|
| 206 |
+
alias: Union[List[str], str, None] = None,
|
| 207 |
+
require_subclass: bool = False,
|
| 208 |
+
):
|
| 209 |
+
"""
|
| 210 |
+
:param parent_class: class to register the name under
|
| 211 |
+
:param value: the value to register
|
| 212 |
+
:param name: name to register the wrapped value as, defaults to value.__name__
|
| 213 |
+
:param alias: alias or list of aliases to register the wrapped value as,
|
| 214 |
+
defaults to None
|
| 215 |
+
:param require_subclass: require that value is a subclass of the class this
|
| 216 |
+
method is called from
|
| 217 |
+
"""
|
| 218 |
+
if name is None:
|
| 219 |
+
# default name
|
| 220 |
+
name = value.__name__
|
| 221 |
+
|
| 222 |
+
name = standardize_lookup_name(name)
|
| 223 |
+
alias = standardize_alias_name(alias)
|
| 224 |
+
register_alias(name=name, alias=alias, parent_class=parent_class)
|
| 225 |
+
|
| 226 |
+
if require_subclass:
|
| 227 |
+
_validate_subclass(parent_class, value)
|
| 228 |
+
|
| 229 |
+
if name in _REGISTRY[parent_class]:
|
| 230 |
+
# name already exists - raise error if two different values are attempting
|
| 231 |
+
# to share the same name
|
| 232 |
+
registered_value = _REGISTRY[parent_class][name]
|
| 233 |
+
if registered_value is not value:
|
| 234 |
+
raise RuntimeError(
|
| 235 |
+
f"Attempting to register name {name} as {value} "
|
| 236 |
+
f"however {name} has already been registered as {registered_value}"
|
| 237 |
+
)
|
| 238 |
+
else:
|
| 239 |
+
_REGISTRY[parent_class][name] = value
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def get_from_registry(
|
| 243 |
+
parent_class: Type, name: str, require_subclass: bool = False
|
| 244 |
+
) -> Any:
|
| 245 |
+
"""
|
| 246 |
+
:param parent_class: class that the name is registered under
|
| 247 |
+
:param name: name to retrieve from the registry of the class
|
| 248 |
+
:param require_subclass: require that value is a subclass of the class this
|
| 249 |
+
method is called from
|
| 250 |
+
:return: value from retrieved the registry for the given name, raises
|
| 251 |
+
error if not found
|
| 252 |
+
"""
|
| 253 |
+
name = standardize_lookup_name(name)
|
| 254 |
+
|
| 255 |
+
if ":" in name:
|
| 256 |
+
# user specifying specific module to load and value to import
|
| 257 |
+
module_path, value_name = name.split(":")
|
| 258 |
+
retrieved_value = _import_and_get_value_from_module(module_path, value_name)
|
| 259 |
+
else:
|
| 260 |
+
# look up name in alias registry
|
| 261 |
+
name = _ALIAS_REGISTRY[parent_class].get(name, name)
|
| 262 |
+
# look up name in registry
|
| 263 |
+
retrieved_value = _REGISTRY[parent_class].get(name)
|
| 264 |
+
if retrieved_value is None:
|
| 265 |
+
raise KeyError(
|
| 266 |
+
f"Unable to find {name} registered under type {parent_class}.\n"
|
| 267 |
+
f"Registered values for {parent_class}: "
|
| 268 |
+
f"{registered_names(parent_class)}\n"
|
| 269 |
+
f"Registered aliases for {parent_class}: "
|
| 270 |
+
f"{registered_aliases(parent_class)}"
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
if require_subclass:
|
| 274 |
+
_validate_subclass(parent_class, retrieved_value)
|
| 275 |
+
|
| 276 |
+
return retrieved_value
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def registered_names(parent_class: Type) -> List[str]:
    """
    :param parent_class: class to look up the registry of
    :return: all names registered to the given class
    """
    # dicts iterate over their keys, so unpacking yields the same list
    return [*_REGISTRY[parent_class]]
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def registered_aliases(parent_class: Type) -> List[str]:
    """
    :param parent_class: class to look up the registry of
    :return: all aliases registered to the given class
    """
    # the alias map's keys include canonical names; report only true aliases
    canonical_names = set(registered_names(parent_class))
    return list(set(_ALIAS_REGISTRY[parent_class]) - canonical_names)
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def register_alias(
    name: str, parent_class: Type, alias: Union[str, List[str], None] = None
):
    """
    Updates the mapping from the alias(es) to the given name.
    If the alias is None, the name is used as the alias.

    :param name: name that the alias refers to
    :param parent_class: class that the name is registered under
    :param alias: single alias or list of aliases that
        refer to the name, defaults to None
    :raises KeyError: if an alias duplicates the name itself or is already
        registered under the parent class
    """
    if alias is not None:
        # copy so a caller-provided list is never mutated by the append below
        alias = list(alias) if isinstance(alias, list) else [alias]
    else:
        alias = []

    if name in alias:
        raise KeyError(
            f"Attempting to register alias {name}, "
            f"that is identical to the standardized name: {name}."
        )
    alias.append(name)

    for alias_name in alias:
        if alias_name in _ALIAS_REGISTRY[parent_class]:
            raise KeyError(
                f"Attempting to register alias {alias_name} as {name} "
                f"however {alias_name} has already been registered as "
                # fixed: previously indexed _ALIAS_REGISTRY (keyed by class)
                # with the alias string, raising KeyError instead of this
                # intended message
                f"{_ALIAS_REGISTRY[parent_class][alias_name]}"
            )
        _ALIAS_REGISTRY[parent_class][alias_name] = name
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def _import_and_get_value_from_module(module_path: str, value_name: str) -> Any:
    """
    Load the python file at ``module_path`` as a module and return its
    ``value_name`` attribute.

    :param module_path: filesystem path of the python file to load
    :param value_name: attribute to read from the loaded module
    :return: the attribute's value
    :raises RuntimeError: if the file cannot be loaded as a module or the
        attribute is not defined in it
    """
    # load module
    spec = importlib.util.spec_from_file_location(
        f"plugin_module_for_{value_name}", module_path
    )
    if spec is None or spec.loader is None:
        # spec_from_file_location returns None for unloadable paths
        raise RuntimeError(f"Unable to load module from {module_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    # get value from module; use a sentinel so legitimately falsy attributes
    # (0, "", False, None) are still returned instead of treated as missing
    _missing = object()
    value = getattr(module, value_name, _missing)

    if value is _missing:
        raise RuntimeError(
            f"Unable to find attribute {value_name} in module {module_path}"
        )
    return value
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def _validate_subclass(parent_class: Type, child_class: Type):
    """Raise ValueError unless ``child_class`` subclasses ``parent_class``."""
    if issubclass(child_class, parent_class):
        return
    raise ValueError(
        f"class {child_class} is not a subclass of the class it is "
        f"registered for: {parent_class}."
    )
|
.venv/lib/python3.11/site-packages/compressed_tensors/utils/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing,
|
| 10 |
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# flake8: noqa
|
| 15 |
+
|
| 16 |
+
from .helpers import *
|
| 17 |
+
from .offload import *
|
| 18 |
+
from .permutations_24 import *
|
| 19 |
+
from .permute import *
|
| 20 |
+
from .safetensors_load import *
|
| 21 |
+
from .semi_structured_conversions import *
|
.venv/lib/python3.11/site-packages/compressed_tensors/utils/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (413 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/compressed_tensors/utils/__pycache__/helpers.cpython-311.pyc
ADDED
|
Binary file (13.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/compressed_tensors/utils/__pycache__/offload.cpython-311.pyc
ADDED
|
Binary file (17.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/compressed_tensors/utils/__pycache__/permutations_24.cpython-311.pyc
ADDED
|
Binary file (3.01 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/compressed_tensors/utils/__pycache__/permute.cpython-311.pyc
ADDED
|
Binary file (2.58 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/compressed_tensors/utils/__pycache__/safetensors_load.cpython-311.pyc
ADDED
|
Binary file (12.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/compressed_tensors/utils/__pycache__/semi_structured_conversions.cpython-311.pyc
ADDED
|
Binary file (13.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/compressed_tensors/utils/helpers.py
ADDED
|
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing,
|
| 10 |
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import warnings
|
| 16 |
+
from functools import wraps
|
| 17 |
+
from typing import Any, Callable, Dict, List, Optional
|
| 18 |
+
|
| 19 |
+
import numpy
|
| 20 |
+
import torch
|
| 21 |
+
from transformers import AutoConfig
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
__all__ = [
|
| 25 |
+
"infer_compressor_from_model_config",
|
| 26 |
+
"fix_fsdp_module_name",
|
| 27 |
+
"tensor_follows_mask_structure",
|
| 28 |
+
"replace_module",
|
| 29 |
+
"is_compressed_tensors_config",
|
| 30 |
+
"getattr_chain",
|
| 31 |
+
"deprecated",
|
| 32 |
+
"Aliasable",
|
| 33 |
+
"combine_shards",
|
| 34 |
+
"shard_tensor",
|
| 35 |
+
"pack_bitmasks",
|
| 36 |
+
"unpack_bitmasks",
|
| 37 |
+
]
|
| 38 |
+
|
| 39 |
+
FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def infer_compressor_from_model_config(
    pretrained_model_name_or_path: str,
) -> Optional["ModelCompressor"]:  # noqa: F821
    """
    Given a path to a model config, extract a sparsity config if it exists and return
    the associated ModelCompressor

    :param pretrained_model_name_or_path: path to model config on disk or HF hub
    :return: matching compressor if config contains a sparsity config
    """
    # imported lazily to avoid a circular dependency at module import time
    from compressed_tensors.compressors import ModelCompressor
    from compressed_tensors.config import CompressionConfig

    hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
    raw_sparsity = ModelCompressor.parse_sparsity_config(hf_config)
    if raw_sparsity is None:
        return None

    compression_format = raw_sparsity.get("format")
    parsed_config = CompressionConfig.load_from_registry(
        compression_format, **raw_sparsity
    )
    return ModelCompressor.load_from_registry(compression_format, config=parsed_config)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# TODO: There is already the same function in
|
| 67 |
+
# SparseML, should be moved to a shared location
|
| 68 |
+
# in the future
|
| 69 |
+
def fix_fsdp_module_name(name: str) -> str:
    """
    Remove FSDP wrapper prefixes from a module name.
    Accounts for scenario where FSDP_WRAPPER_NAME is
    at the end of the name, as well as in the middle.

    :param name: name to strip
    :return: stripped name
    """
    # drop the wrapper when it appears as a leading/interior segment
    # ("wrapper.") and when it appears as a trailing segment (".wrapper")
    stripped = name.replace(FSDP_WRAPPER_NAME + ".", "")
    return stripped.replace("." + FSDP_WRAPPER_NAME, "")
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def tensor_follows_mask_structure(tensor, mask: str = "2:4") -> bool:
    """
    Check that ``tensor`` satisfies an "n:m" sparsity structure: every
    contiguous chunk of m values contains at least n zeros.

    :param tensor: tensor to check
    :param mask: mask structure to check for, in the format "n:m"
    :return: True if the tensor follows the mask structure
    :raises ValueError: if the tensor does not follow the mask structure.
        Note, some weights can incidentally be zero, so we only require
        at least n zeros in each chunk of size m
    """
    n, m = tuple(map(int, mask.split(":")))

    # Reshape the tensor into chunks of size m
    tensor = tensor.view(-1, m)

    # Count the number of zeros in each chunk; ">=" rather than "==" since
    # some weights can incidentally be zero
    zero_counts = (tensor == 0).sum(dim=1)
    if not torch.all(zero_counts >= n).item():
        # fixed: previously raised a bare ValueError() with no message
        raise ValueError(f"Tensor does not follow the {mask} sparsity structure")

    return True
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def replace_module(model: torch.nn.Module, name: str, new_module: torch.nn.Module):
    """
    Swap the submodule at dotted path ``name`` within ``model`` for ``new_module``.

    :param model: root module containing the submodule to replace
    :param name: dotted path of the submodule, e.g. "encoder.layer.0"
    :param new_module: module to install at that path
    """
    # rpartition yields ("", "", name) when there is no dot, in which case
    # the parent is the root model itself
    parent_name, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_name) if parent_name else model
    setattr(parent, child_name, new_module)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def is_compressed_tensors_config(compression_config: Any) -> bool:
    """
    Returns True if CompressedTensorsConfig is available from transformers and
    compression_config is an instance of CompressedTensorsConfig

    See: https://github.com/huggingface/transformers/pull/31704
    """
    try:
        from transformers.utils.quantization_config import CompressedTensorsConfig
    except ImportError:
        # transformers is absent or too old to ship CompressedTensorsConfig
        return False
    return isinstance(compression_config, CompressedTensorsConfig)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any:
    """
    Chain multiple getattr calls, separated by `.`

    :param obj: base object whose attributes are being retrieved
    :param chain_str: attribute names separated by `.`
    :param default: default value, throw error otherwise
    """
    # a positional third argument wins over the "default" keyword
    if len(args) >= 1:
        has_default, default = True, args[0]
    elif "default" in kwargs:
        has_default, default = True, kwargs["default"]
    else:
        has_default, default = False, None

    res = obj
    for attr_name in chain_str.split("."):
        if not hasattr(res, attr_name):
            if not has_default:
                raise AttributeError(f"{res} object has no attribute {attr_name}")
            return default
        res = getattr(res, attr_name)

    return res
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def deprecated(future_name: Optional[str] = None, message: Optional[str] = None):
    """
    Decorator to mark functions as deprecated

    :param future_name: name of the function to use instead, mentioned in the
        default deprecation message
    :param message: deprecation message, replaces the default deprecation message
    """

    def decorator(func: Callable[[Any], Any]):
        # compute the warning text per decorated function instead of mutating
        # the factory's closure via ``nonlocal``; previously the first
        # function's generated message (and its __name__) leaked into every
        # later use of the same decorator instance
        warning_text = message
        if warning_text is None:
            warning_text = (
                f"{func.__name__} is deprecated and will be removed in a future release"
            )
            if future_name is not None:
                warning_text += f". Please use {future_name} instead."

        @wraps(func)
        def wrapped(*args, **kwargs):
            warnings.warn(warning_text, DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)

        return wrapped

    return decorator
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
class Aliasable:
    """
    A mixin for enums to allow aliasing of enum members

    Example:
    >>> class MyClass(Aliasable, int, Enum):
    >>>     ...
    """

    @staticmethod
    def get_aliases() -> Dict[str, str]:
        # subclasses must return a mapping of alias value -> canonical value
        raise NotImplementedError()

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            aliases = self.get_aliases()
            # equal if raw values match or both resolve to the same canonical
            return self.value == other.value or (
                aliases.get(self.value, self.value)
                == aliases.get(other.value, other.value)
            )
        else:
            # comparing against a raw value: resolve both sides through aliases
            aliases = self.get_aliases()
            self_value = aliases.get(self.value, self.value)
            other_value = aliases.get(other, other)
            return self_value == other_value

    def __hash__(self):
        # fixed: previously read the nonexistent attribute ``self.aliases``
        # (the class defines ``get_aliases()``), so hashing any instance
        # raised AttributeError; hash on the canonical value so that aliased
        # members collide as __eq__ requires
        canonical_value = self.get_aliases().get(self.value, self.value)
        return hash(canonical_value)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def shard_tensor(
    tensor: torch.Tensor, shard_sizes: List[int], dim: int = 0
) -> List[torch.Tensor]:
    """
    Shards a tensor into a list of tensors along a given dimension.

    raises: ValueError: If the sum of shard_sizes does not match the
        size of the tensor along the given dimension.

    :param tensor: The input tensor to shard.
    :param shard_sizes: List of sizes for each shard along the specified dimension.
    :param dim: The dimension along which to shard the tensor.
    :returns: A list of tensors sharded along the specified dimension.
    """
    if sum(shard_sizes) != tensor.size(dim):
        raise ValueError(
            "Sum of shard_sizes must equal the size of the tensor "
            "along the specified dimension."
        )

    # precompute each shard's starting offset along dim
    offsets = []
    running = 0
    for size in shard_sizes:
        offsets.append(running)
        running += size

    # narrow produces zero-copy views over the original storage
    return [
        tensor.narrow(dim, offset, size)
        for offset, size in zip(offsets, shard_sizes)
    ]
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def combine_shards(shards, dim=0):
    """
    Combine decompressed shards along a given dimension using `narrow`.

    :param shards: List of decompressed shard tensors.
    :param dim: Dimension to combine along (default: 0).
    :return: Combined decompressed tensor.
    """
    if not shards:
        raise ValueError("The list of shards is empty.")

    # all shards must agree on dtype before they can share one output buffer
    if len({shard.dtype for shard in shards}) > 1:
        raise ValueError("All shards must have the same dtype.")

    # output shape matches the first shard except along the combined dim
    out_shape = list(shards[0].shape)
    out_shape[dim] = sum(shard.shape[dim] for shard in shards)
    result = torch.zeros(out_shape, dtype=shards[0].dtype, device=shards[0].device)

    # copy each shard into its slice of the output via zero-copy narrow views
    offset = 0
    for shard in shards:
        length = shard.shape[dim]
        result.narrow(dim, offset, length).copy_(shard)
        offset += length

    return result
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def pack_bitmasks(bytemasks: torch.Tensor) -> torch.Tensor:
    """
    Converts a bytemask tensor to a bitmask tensor to reduce memory. Shape RxC will be
    compressed to R x ceil(C/8)

    :param bytemasks: mask tensor where each byte corresponds to a weight
    :return: mask tensor where each bit corresponds to a weight
    """
    # numpy performs the bit packing; "little" bit order keeps column 0 in
    # bit 0 so unpack_bitmasks can invert the transform exactly
    packed = numpy.packbits(bytemasks.numpy(), axis=-1, bitorder="little")
    return torch.from_numpy(packed)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def unpack_bitmasks(
    packed_bitmasks: torch.Tensor, original_shape: List[int]
) -> torch.Tensor:
    """
    Converts a bitmask tensor back to a bytemask tensor for use during decompression

    :param packed_bitmasks: mask tensor where each bit corresponds to a weight
    :param original_shape: dense shape to decompress to
    :return: boolean mask of weights in the original dense shape
    """
    # expand each byte back into bits; count= truncates the padding bits
    # that packbits added to fill the final byte of each row
    unpacked = numpy.unpackbits(
        packed_bitmasks.cpu().numpy(),
        axis=-1,
        count=original_shape[-1],
        bitorder="little",
    )

    # restore the dense shape and boolean dtype
    return torch.from_numpy(unpacked.reshape(original_shape).astype(bool))
|
.venv/lib/python3.11/site-packages/compressed_tensors/utils/offload.py
ADDED
|
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing,
|
| 10 |
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
Utilities associated with offloading functionality provided by `accelerate`.
|
| 16 |
+
|
| 17 |
+
| ----------------------------------------------------------------------------------------------------- | # noqa: E501
|
| 18 |
+
| Operation | Without offloading support | With offloading support | # noqa: E501
|
| 19 |
+
| --------- | -------------------------------------- | ------------------------------------------------ | # noqa: E501
|
| 20 |
+
| Add | module.register_parameter(name, param) | register_offload_parameter(module, name, param) | # noqa: E501
|
| 21 |
+
| Check | N/A | has_offloaded_params(module) | # noqa: E501
|
| 22 |
+
| Onload | N/A | with align_module_device(module) | # noqa: E501
|
| 23 |
+
| Update | module.name.data.copy_(new_data) | update_offload_parameter(module, name, new_data) | # noqa: E501
|
| 24 |
+
| Delete | del module.name | delete_offload_parameter(module, name) | # noqa: E501
|
| 25 |
+
| ----------------------------------------------------------------------------------------------------- | # noqa: E501
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
import contextlib
|
| 29 |
+
from functools import wraps
|
| 30 |
+
from typing import Any, Callable, Dict, Literal, Optional, Union
|
| 31 |
+
|
| 32 |
+
import torch
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
try:
|
| 36 |
+
from accelerate.hooks import (
|
| 37 |
+
AlignDevicesHook,
|
| 38 |
+
add_hook_to_module,
|
| 39 |
+
remove_hook_from_module,
|
| 40 |
+
)
|
| 41 |
+
from accelerate.utils import (
|
| 42 |
+
OffloadedWeightsLoader,
|
| 43 |
+
PrefixedDataset,
|
| 44 |
+
set_module_tensor_to_device,
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
_has_accelerate = True
|
| 48 |
+
except ImportError:
|
| 49 |
+
_has_accelerate = False
|
| 50 |
+
AlignDevicesHook = None
|
| 51 |
+
add_hook_to_module = None
|
| 52 |
+
remove_hook_from_module = None
|
| 53 |
+
OffloadedWeightsLoader = None
|
| 54 |
+
PrefixedDataset = None
|
| 55 |
+
set_module_tensor_to_device = None
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
__all__ = [
|
| 59 |
+
"is_module_offloaded",
|
| 60 |
+
"get_execution_device",
|
| 61 |
+
"get_offloaded_device",
|
| 62 |
+
"update_prefix_dict",
|
| 63 |
+
"update_parameter_data",
|
| 64 |
+
"register_offload_parameter",
|
| 65 |
+
"update_offload_parameter",
|
| 66 |
+
"delete_offload_parameter",
|
| 67 |
+
"has_offloaded_params",
|
| 68 |
+
"disable_hf_hook",
|
| 69 |
+
"align_module_device",
|
| 70 |
+
]
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def check_accelerate(fallback: Any):
    """
    Decorator factory: when `accelerate` is unavailable, replace the decorated
    function with a stub that returns ``fallback``; otherwise leave the
    function untouched.
    """

    def decorator(func: Callable[[Any], Any]):
        if _has_accelerate:
            return func

        @wraps(func)
        def fallback_fn(*args, **kwargs):
            return fallback

        return fallback_fn

    return decorator
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
""" Candidates for Depreciation """
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@check_accelerate(fallback=False)
def is_module_offloaded(module: torch.nn.Module) -> bool:
    # deprecated alias for has_offloaded_params; when accelerate is not
    # installed the check_accelerate decorator makes this return False
    return has_offloaded_params(module)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def get_execution_device(module: torch.nn.Module) -> torch.device:
    """
    :param module: module to check
    :return: device module is loaded onto during forward pass
    """
    # offloaded modules record their execution device on the accelerate hook
    if has_offloaded_params(module):
        return module._hf_hook.execution_device

    # offload only gets set for leaf modules; a module whose parameters sit
    # on "meta" falls back to its hook's execution device as well
    first_param_device = next(module.parameters()).device
    if first_param_device.type != "meta":
        return first_param_device

    return module._hf_hook.execution_device
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def get_offloaded_device(module: torch.nn.Module) -> torch.device:
    """
    :param module: module to check
    :return: device module is offloaded to onto after forward pass
    """
    if not has_offloaded_params(module):
        return next(module.parameters()).device

    # look at the hook's underlying dataset to find where the first
    # offloaded tensor actually lives
    first_key = list(module._hf_hook.weights_map.keys())[0]
    prefix_dataset = module._hf_hook.weights_map.dataset
    return prefix_dataset[first_key].device
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
@check_accelerate(fallback=None)
def update_prefix_dict(module: torch.nn.Module, key: str, data: torch.Tensor):
    """
    Updates the offloaded state dict for a given module. Parameter named key is replaced
    by data. This is necessary because parameter updates for offloaded modules do not
    persist automatically between loads. This function only affects the offloaded
    state dict and not the current state of the loaded module.

    :param module: module containing the parameter to update
    :param key: name of parameter to update
    :param data: tensor to update parameter with in the offloaded state dict
    """
    if not has_offloaded_params(module):
        raise ValueError("Prefix dict is only applicable to offloaded modules")

    offload_to_weights_map(module._hf_hook.weights_map, key, data)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def update_parameter_data(
    module: torch.nn.Module, new_param_data: torch.Tensor, param_name: str
):
    """
    Update the data of an existing parameter and its offload dict. Supports both
    parameters of offloaded modules and non-offloaded modules

    :param module: module containing the parameter to update
    :param new_param_data: tensor to update parameter with
    :param param_name: name of module parameter to update
    """
    # thin compatibility wrapper; note the argument order differs from
    # update_offload_parameter (there the name comes before the data)
    update_offload_parameter(module, param_name, new_param_data)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
""" Candidates for Upstreaming """
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def register_offload_parameter(
    module: torch.nn.Module,
    name: str,
    parameter: torch.nn.Parameter,
    offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
):
    """
    Register a parameter to the given module which may be offloaded

    :param module: maybe offloaded module
    :param name: name of newly registered parameter
    :param parameter: parameter being registered
    :param offload_device: device on which weight will be offloaded to. If None is
        provided, then infer device from parameters on module
    """
    # record, before registering, whether any existing parameter is onloaded;
    # if every parameter was on "meta", the new one should end up there too
    has_onload = any(p.device != torch.device("meta") for p in module.parameters())
    module.register_parameter(name, parameter)

    if has_offloaded_params(module):
        # mirror the new parameter into the hook's offloaded weights map
        weights_map = module._hf_hook.weights_map
        offload_to_weights_map(weights_map, name, parameter.data, offload_device)
        if not has_onload:
            # keep the module fully offloaded: move the fresh parameter to meta
            set_module_tensor_to_device(module, name, "meta")
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def update_offload_parameter(
    module: torch.nn.Module,
    name: str,
    data: Optional[torch.Tensor],
    offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
):
    """
    Update the data of an existing parameter and its offload dict. Supports both
    parameters of offloaded modules and non-offloaded modules

    :param module: module containing the parameter to update
    :param name: name of module parameter to update
    :param data: tensor to update parameter with
    :param offload_device: device on which weight will be offloaded to. If None is
        provided, then infer device from parameters on module
    """
    # NOTE(review): ``data`` is annotated Optional but is dereferenced
    # unconditionally below — passing None raises AttributeError; confirm
    # whether a None guard is intended
    param = getattr(module, name)
    data = data.to(param.dtype)

    # copy data into onloaded parameter if applicable
    # NOTE(review): this compares a torch.device against the string "meta";
    # confirm the minimum supported torch handles device-vs-str equality,
    # otherwise ``param.device.type != "meta"`` would be the safe spelling
    if param.device != "meta":
        param.data.copy_(data)

    # update offload dict
    if has_offloaded_params(module):
        weights_map = module._hf_hook.weights_map
        offload_to_weights_map(weights_map, name, data, offload_device)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def delete_offload_parameter(module: torch.nn.Module, name: str):
    """
    Remove a parameter (and any offloaded copy of it) from a module.

    :param module: maybe offloaded module
    :param name: name of parameter being deleted
    """
    # drop the onloaded/meta parameter from the module itself
    delattr(module, name)

    # if the module is offloaded, also drop the weight from its offload map
    if not has_offloaded_params(module):
        return
    delete_from_weights_map(module._hf_hook.weights_map, name)
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
@check_accelerate(fallback=contextlib.nullcontext())
@contextlib.contextmanager
def disable_hf_hook(module: torch.nn.Module):
    """
    Context manager which temporarily detaches every accelerate hook found on
    `module` and its submodules, re-attaching all of them on exit.
    """
    removed = {}

    def _strip(submodule):
        # record and remove the hook so the module behaves as plain torch
        if hasattr(submodule, "_hf_hook"):
            removed[submodule] = submodule._hf_hook
            remove_hook_from_module(submodule)

    module.apply(_strip)

    yield

    # restore every hook exactly where it was
    for submodule, hook in removed.items():
        add_hook_to_module(submodule, hook)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
@check_accelerate(fallback=None)
def offload_to_weights_map(
    weights_map: Union[PrefixedDataset, Dict, OffloadedWeightsLoader],
    key: str,
    value: torch.Tensor,
    offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
):
    """
    Helper function which implements offloaded item assignment for PrefixedDataset,
    OffloadedWeightsLoader, and Dict types.

    :param weights_map: weight map to be updated with offload information
    :param key: key used to identify weight location
    :param value: weight being offloaded
    :param offload_device: device on which weight will be offloaded to. If None is
        provided, then infer device from parameters in weights_map
    """
    if isinstance(weights_map, PrefixedDataset):
        if offload_device == "disk":
            raise ValueError(f"Cannot offload to disk with type {type(weights_map)}")

        # unwrap the prefixed dataset and recurse on the underlying mapping
        full_key = f"{weights_map.prefix}{key}"
        offload_to_weights_map(weights_map.dataset, full_key, value, offload_device)

    elif isinstance(weights_map, OffloadedWeightsLoader):
        if key not in weights_map.all_keys:
            weights_map.all_keys.append(key)

        # an empty index means nothing is disk-offloaded; delegate to state_dict
        if len(weights_map.index) <= 0 and offload_device != "disk":
            offload_to_weights_map(weights_map.state_dict, key, value, offload_device)
        else:
            raise NotImplementedError(
                "Updating weights_map with disk offloading is not implemented yet"
            )

    elif isinstance(weights_map, dict):
        if offload_device == "disk":
            raise ValueError(f"Cannot offload to disk with type {type(weights_map)}")

        # infer offload device from an existing entry, or from any entry at all
        if offload_device is None:
            existing = weights_map.get(key)
            if existing is not None:
                offload_device = existing.device
            else:
                sample = next(iter(weights_map.values()), None)
                if sample is None:
                    raise ValueError(
                        "Cannot infer offload device from empty weights_map"
                    )
                offload_device = sample.device

        weights_map[key] = value.to(device=offload_device)

    else:
        raise NotImplementedError(
            "Updating offload data not implemented for weights_map of type "
            f"{type(weights_map)}"
        )
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
@check_accelerate(fallback=None)
def delete_from_weights_map(
    weights_map: Union[PrefixedDataset, Dict, OffloadedWeightsLoader],
    key: str,
):
    """
    Remove `key` from a weights map, handling PrefixedDataset,
    OffloadedWeightsLoader, and plain dict containers.

    :param weights_map: weights map to delete the entry from
    :param key: key identifying the weight to remove
    """
    if isinstance(weights_map, PrefixedDataset):
        # recurse on the wrapped dataset using the fully-prefixed key
        delete_from_weights_map(weights_map.dataset, f"{weights_map.prefix}{key}")

    elif isinstance(weights_map, OffloadedWeightsLoader):
        # a non-empty index means entries live on disk, which is unsupported
        if len(weights_map.index) > 0:
            raise NotImplementedError(
                "Delete from weights_map with disk offloading is not implemented yet"
            )
        delete_from_weights_map(weights_map.state_dict, key)

    elif isinstance(weights_map, dict):
        del weights_map[key]

    else:
        raise NotImplementedError(
            "Updating offload data not implemented for weights_map of type "
            f"{type(weights_map)}"
        )
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
""" Upstreamed Functions """
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
# introduced in accelerate v1.1.0
@check_accelerate(fallback=False)
def has_offloaded_params(module: torch.nn.Module) -> bool:
    """
    Check whether a module has offloaded parameters, i.e. whether it carries an
    `AlignDevicesHook` with offloading enabled.

    Args:
        module (`torch.nn.Module`): The module to check for an offload hook.

    Returns:
        bool: `True` if the module has an offload hook and offloading is enabled,
        `False` otherwise.
    """
    hook = getattr(module, "_hf_hook", None)
    return isinstance(hook, AlignDevicesHook) and hook.offload
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
# introduced in accelerate v1.1.0
@check_accelerate(fallback=contextlib.nullcontext())
@contextlib.contextmanager
def align_module_device(
    module: torch.nn.Module, execution_device: Optional[torch.device] = None
):
    """
    Context manager that moves a module's parameters to the specified execution
    device for the duration of the context.

    Args:
        module (`torch.nn.Module`):
            Module with parameters to align.
        execution_device (`torch.device`, *optional*):
            If provided, overrides the module's execution device within the
            context. Otherwise, use hook execution device or pass
    """
    if has_offloaded_params(module):
        # offloaded module: onload via its accelerate hook, optionally
        # overriding the hook's execution device within the context
        if execution_device is not None:
            original_device = module._hf_hook.execution_device
            module._hf_hook.execution_device = execution_device

        try:
            module._hf_hook.pre_forward(module)
            yield
        finally:
            module._hf_hook.post_forward(module, None)
            if execution_device is not None:
                module._hf_hook.execution_device = original_device

    elif execution_device is not None:
        # non-offloaded module: move direct parameters by hand, restore after
        restore = {
            name: param.device
            for name, param in module.named_parameters(recurse=False)
        }
        try:
            for name in restore:
                set_module_tensor_to_device(module, name, execution_device)
            yield
        finally:
            for name, device in restore.items():
                set_module_tensor_to_device(module, name, device)

    else:
        # nothing to align
        yield
|
.venv/lib/python3.11/site-packages/compressed_tensors/utils/permutations_24.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing,
|
| 10 |
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
import numpy
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
__all__ = ["get_permutations_24"]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Precompute permutations for Marlin24 weight and scale shuffling
|
| 24 |
+
# Originally implemented in nm-vllm/vllm/model_executor/layers/quantization/utils/marlin_24_perms.py # noqa: E501
|
| 25 |
+
#
|
| 26 |
+
# Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight
|
| 27 |
+
# data so that it is compatible with the tensor-core format that is described here:
|
| 28 |
+
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
|
| 29 |
+
#
|
| 30 |
+
# As a result of this reordering, the vector loads inside the kernel will get the data
|
| 31 |
+
# as it is needed for tensor-core (without the need to use ldmatrix instructions)
|
| 32 |
+
def get_permutations_24(num_bits):
    """
    Build the Marlin24 permutation tensors for the given weight bit-width.

    :param num_bits: quantization bit-width, must be 4 or 8
    :return: tuple of (weight permutation tensor, grouped scale permutation
        list, per-channel scale permutation list)
    """
    perm_entries = []
    for thread in range(32):
        col = thread // 4
        col_o = col // 2
        base = []
        # order matters: iterate blocks in the outer loop, rows in the inner
        for block in (0, 1):
            rows = (
                2 * (thread % 4),
                2 * (thread % 4) + 1,
                2 * (thread % 4 + 4),
                2 * (thread % 4 + 4) + 1,
            )
            for row in rows:
                base.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block)
        for offset in range(4):
            perm_entries.extend(entry + offset for entry in base)
    perm = numpy.array(perm_entries)

    # interleave so values that pack into the same output word are adjacent
    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))

    perm = torch.from_numpy(
        perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    )

    # grouped scales are shuffled; per-channel scales keep their natural order
    scale_perm = [i * 8 + j for i in range(8) for j in (0, 4, 1, 5, 2, 6, 3, 7)]
    scale_perm_single = [8 * i + j for i in range(8) for j in range(8)]
    return perm, scale_perm, scale_perm_single
|
.venv/lib/python3.11/site-packages/compressed_tensors/utils/permute.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing,
|
| 10 |
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import Set, Tuple
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
__all__ = ["safe_permute"]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# these datatypes are missing implementations required for standard permutation
|
| 24 |
+
_EXPERIMENTAL_DTYPES: Set[Tuple[torch.dtype, torch.device]] = set()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def safe_permute(value: torch.Tensor, perm: torch.Tensor, dim: int = 0) -> torch.Tensor:
|
| 28 |
+
"""
|
| 29 |
+
Perform out-of-place permutation without using torch.Tensor.index_put_,
|
| 30 |
+
whose implementation is missing for datatypes such as `torch.float8_e4m3fn`
|
| 31 |
+
|
| 32 |
+
:param value: tensor to permute
|
| 33 |
+
:param perm: permutation map
|
| 34 |
+
:param dim: dimension along which to apply permutation
|
| 35 |
+
:return: permuted value
|
| 36 |
+
"""
|
| 37 |
+
dtype_tuple = (value.dtype, value.device)
|
| 38 |
+
|
| 39 |
+
if dtype_tuple in _EXPERIMENTAL_DTYPES:
|
| 40 |
+
return _fallback_permute(value, perm, dim)
|
| 41 |
+
|
| 42 |
+
try:
|
| 43 |
+
return value[tuple([slice(None)] * dim + [perm])]
|
| 44 |
+
except RuntimeError:
|
| 45 |
+
# Mark dtype as experimental if advanced indexing fails
|
| 46 |
+
_EXPERIMENTAL_DTYPES.add(dtype_tuple)
|
| 47 |
+
return _fallback_permute(value, perm, dim)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _fallback_permute(
|
| 51 |
+
value: torch.Tensor, perm: torch.Tensor, dim: int
|
| 52 |
+
) -> torch.Tensor:
|
| 53 |
+
"""
|
| 54 |
+
Fallback permutation method for experimental dtypes.
|
| 55 |
+
|
| 56 |
+
:param value: tensor to permute
|
| 57 |
+
:param perm: permutation map
|
| 58 |
+
:param dim: dimension along which to apply permutation
|
| 59 |
+
:return: permuted value
|
| 60 |
+
"""
|
| 61 |
+
value_ret = value.clone() # cannot use zeros_like b/c of missing impl.
|
| 62 |
+
orig_slices = [slice(None)] * (dim + 1)
|
| 63 |
+
perm_slices = [slice(None)] * (dim + 1)
|
| 64 |
+
|
| 65 |
+
for index, perm_index in enumerate(perm):
|
| 66 |
+
orig_slices[dim] = index
|
| 67 |
+
perm_slices[dim] = perm_index
|
| 68 |
+
value_ret[tuple(orig_slices)] = value[tuple(perm_slices)]
|
| 69 |
+
|
| 70 |
+
return value_ret
|
.venv/lib/python3.11/site-packages/compressed_tensors/utils/safetensors_load.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing,
|
| 10 |
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
import os
|
| 17 |
+
import re
|
| 18 |
+
import struct
|
| 19 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
from safetensors import safe_open
|
| 22 |
+
from torch import Tensor
|
| 23 |
+
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, cached_file
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
__all__ = [
|
| 27 |
+
"get_safetensors_folder",
|
| 28 |
+
"get_safetensors_header",
|
| 29 |
+
"match_param_name",
|
| 30 |
+
"merge_names",
|
| 31 |
+
"get_weight_mappings",
|
| 32 |
+
"get_nested_weight_mappings",
|
| 33 |
+
"get_nested_mappings_from_state_dict",
|
| 34 |
+
"get_quantization_state_dict",
|
| 35 |
+
"is_quantization_param",
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
WeightMappingType = Dict[str, str]
|
| 39 |
+
NestedWeightMappingType = Dict[str, WeightMappingType]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def get_safetensors_folder(
    pretrained_model_name_or_path: str, cache_dir: Optional[str] = None
) -> str:
    """
    Given a Hugging Face stub or a local path, return the folder containing the
    safetensors weight files

    :param pretrained_model_name_or_path: local path to model or HF stub
    :param cache_dir: optional cache dir to search through, if none is specified the
        model will be searched for in the default TRANSFORMERS_CACHE
    :return: local folder containing model data
    """
    # a local folder takes precedence over anything cached
    if os.path.exists(pretrained_model_name_or_path):
        return os.path.abspath(pretrained_model_name_or_path)

    # probe the HF cache for a single weights file, then for a shard index
    candidates = [
        cached_file(
            pretrained_model_name_or_path,
            filename,
            cache_dir=cache_dir,
            _raise_exceptions_for_missing_entries=False,
        )
        for filename in (SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME)
    ]
    for found in candidates:
        if found is not None:
            return os.path.dirname(found)

    # model weights could not be found locally or cached from HF Hub
    raise ValueError(
        "Could not locate safetensors weight or index file from "
        f"{pretrained_model_name_or_path}."
    )
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_safetensors_header(safetensors_path: str) -> Dict[str, str]:
    """
    Extracts the metadata from a safetensors file as JSON

    :param safetensors_path: path to a safetensors file
    :return: dictionary of metadata extracted from the safetensors file
    """
    with open(safetensors_path, "rb") as fp:
        # safetensors layout: 8-byte little-endian header length, then the
        # JSON header itself
        (header_size,) = struct.unpack("<Q", fp.read(8))
        return json.loads(fp.read(header_size))
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def match_param_name(full_name: str, param_name: str) -> Optional[str]:
    """
    Helper function extracting the uncompressed parameterized layer name from a
    compressed name. Assumes the compressed name was merged using merge_names.

    :param full_name: full name of parameter in compressed model
    :param param_name: compression parameter name
    :return: uncompressed name of the uncompressed parameterized layer, or None
        if full_name does not end with ".param_name"
    """
    # escape param_name so names containing regex metacharacters (e.g. ".")
    # are matched literally instead of being interpreted as patterns
    match = re.match(r"^(.*)\." + re.escape(param_name) + r"$", full_name)
    if match is None:
        return None
    return match.group(1)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def merge_names(parent_name: str, child_name: str) -> str:
    """
    Helper function for merging an uncompressed parameterized layer name with a
    compression parameter. Names merged with this function can then be parsed by
    match_param_name.

    :param parent_name: uncompressed parameterized layer name
    :param child_name: compression parameter name
    :return: merged compressed name
    """
    return f"{parent_name}.{child_name}"
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def get_weight_mappings(path_to_model_or_tensors: str) -> Dict[str, str]:
    """
    Takes a path to a state dict saved in safetensors format and returns a mapping
    from parameterized layer name to file location.

    {
        layer.weight.bitmask: file_location,
        layer.weight.row_offsets: file_location,
        layer.weight.shape: file_location,
        layer.weight.compressed: file_location
    }

    This generalizes to cases where the model is split into multiple safetensors files

    :param path_to_model_or_tensors: path to directory that contains
        safetensors (must contain either a single file or multiple files with an index),
        or a path to a single safetensors file
    :return: mapping of parameterized layer name to file location
    """
    if os.path.isfile(path_to_model_or_tensors):
        # direct path to a single safetensors file: every key lives in that file
        mapping = get_safetensors_header(path_to_model_or_tensors)
        mapping.pop("__metadata__", None)
        return {key: path_to_model_or_tensors for key in mapping}

    # directory case: prefer a single weights file, fall back to a shard index
    single_path = os.path.join(path_to_model_or_tensors, SAFE_WEIGHTS_NAME)
    index_path = os.path.join(path_to_model_or_tensors, SAFE_WEIGHTS_INDEX_NAME)

    if os.path.exists(single_path):
        mapping = get_safetensors_header(single_path)
        mapping.pop("__metadata__", None)
        mapping = {key: SAFE_WEIGHTS_NAME for key in mapping}
    elif os.path.exists(index_path):
        # multiple safetensors files; the index maps each key to its shard
        with open(index_path, "r", encoding="utf-8") as fp:
            mapping = json.load(fp)["weight_map"]
    else:
        raise ValueError(
            "Could not find a safetensors weight "
            f"or index file at {path_to_model_or_tensors}"
        )

    # convert weight locations to full paths
    return {
        key: os.path.join(path_to_model_or_tensors, value)
        for key, value in mapping.items()
    }
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def get_nested_weight_mappings(
    model_path: str, params_to_nest: List[str], return_unmatched_params: bool = False
) -> Union[NestedWeightMappingType, Tuple[NestedWeightMappingType, WeightMappingType]]:
    """
    Takes a path to a state dict saved in safetensors format and returns a nested
    mapping from uncompressed parameterized layer names to the file locations of
    each layer's compression parameters.

    Example of the nested mapping:
        layer: {
            bitmask: file_location,
            row_offsets: file_location,
            shape: file_location,
            compressed: file_location
        }

    Parameters that match none of params_to_nest are returned in a second flat
    dictionary only when return_unmatched_params is True; this is needed when
    compressors are stacked (e.g., quantization compression followed by sparse
    compression). This generalizes to models split across multiple safetensors
    files.

    :param model_path: Path to the safetensors state dict, must contain either a
        single safetensors file or multiple files with an index.
    :param params_to_nest: List of parameter names to nest.
    :param return_unmatched_params: If True, also return a dictionary mapping the
        remaining (unmatched) parameter names to their file locations.
    :return: the nested mapping alone, or a tuple of (nested mapping, unmatched
        parameter mapping) when return_unmatched_params is True.
    """
    weight_mappings = get_weight_mappings(model_path)
    nested = {}
    unmatched = {}

    for key, file_location in weight_mappings.items():
        matched = False
        for param_name in params_to_nest:
            layer_name = match_param_name(key, param_name)
            if layer_name:
                nested.setdefault(layer_name, {})[param_name] = file_location
                matched = True
        if return_unmatched_params and not matched:
            unmatched[key] = file_location

    if return_unmatched_params:
        return nested, unmatched
    return nested
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def get_nested_mappings_from_state_dict(
    state_dict, params_to_nest
) -> NestedWeightMappingType:
    """
    Takes a state dict and returns a nested mapping from uncompressed
    parameterized layer names to the value of
    each layer's compression parameters.

    Example of the nested mapping:
        layer: {
            weight_scale: ...,
            weight: ...,
            zero_point: ...,
        }

    :param state_dict: state dict of the model
    :param params_to_nest: List of parameter names to nest.
    :return: Nested mapping of parameterized layer names to the value of
        each layer's compression parameters.
    """
    nested = {}
    for key, value in state_dict.items():
        for param_name in params_to_nest:
            layer_name = match_param_name(key, param_name)
            if layer_name:
                nested.setdefault(layer_name, {})[param_name] = value
    return nested
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def get_quantization_state_dict(model_path: str) -> Dict[str, Tensor]:
    """
    Load only the quantization parameters (scales, zero points, g_idx) from a
    safetensors model checkpoint.

    :param model_path: path to the model's safetensors file or directory
    :return: state dict containing only quantization parameters, on CPU
    """
    state_dict = {}
    for weight_name, safe_path in get_weight_mappings(model_path).items():
        if not is_quantization_param(weight_name):
            continue
        with safe_open(safe_path, framework="pt", device="cpu") as handle:
            state_dict[weight_name] = handle.get_tensor(weight_name)

    return state_dict
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def is_quantization_param(name: str) -> bool:
    """
    Checks if a parameter name is associated with a quantization parameter

    :param name: parameter name to check
    :return: True if parameter name is a quantization parameter, else False
    """
    # str.endswith accepts a tuple of suffixes, replacing the chained checks
    return name.endswith(("_scale", "zero_point", "g_idx"))
|
.venv/lib/python3.11/site-packages/compressed_tensors/utils/semi_structured_conversions.py
ADDED
|
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es).
|
| 3 |
+
# Pulled from nm-vllm/vllm/model_executor/layers/quantization/utils/format_24.py
|
| 4 |
+
#
|
| 5 |
+
# flake8: noqa
|
| 6 |
+
# isort: skip_file
|
| 7 |
+
|
| 8 |
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
| 9 |
+
#
|
| 10 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 11 |
+
# you may not use this file except in compliance with the License.
|
| 12 |
+
# You may obtain a copy of the License at
|
| 13 |
+
#
|
| 14 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 15 |
+
#
|
| 16 |
+
# Unless required by applicable law or agreed to in writing,
|
| 17 |
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
| 18 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 19 |
+
# See the License for the specific language governing permissions and
|
| 20 |
+
# limitations under the License.
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
__all__ = [
|
| 26 |
+
"sparse_semi_structured_from_dense_cutlass",
|
| 27 |
+
"sparse_semi_structured_to_dense_cutlass",
|
| 28 |
+
"mask_creator",
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# This is PyTorch implementation of main part of reorder_meta()
|
| 33 |
+
# function, from tools/util/include/cutlass/util/host_reorder.h file
|
| 34 |
+
# of CUTLASS source tree. Furthermore, CUTLASS template for sparse
|
| 35 |
+
# GEMM decides upon layout of this matrix, and at the moment for the
|
| 36 |
+
# sparse GEMM executed on tensor cores, this is layout described by
|
| 37 |
+
# ColumnMajorInterleaved<2> data structure, in
|
| 38 |
+
# include/cutlass/layout/matrix.h of CUTLASS source tree. The
|
| 39 |
+
# reordering of meta matrix into meta_reordered matrix calculated
|
| 40 |
+
# according to these segments of CUTLASS code is re-implemented here.
|
| 41 |
+
# Note that this calculation produces offsets for scattering metadata
|
| 42 |
+
# matrix elements into reordered metadata matrix elements (or,
|
| 43 |
+
# equivalently, for gathering reordered metadata matrix element back
|
| 44 |
+
# into metadata matrix elements).
|
| 45 |
+
def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device):
|
| 46 |
+
dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
|
| 47 |
+
dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)
|
| 48 |
+
|
| 49 |
+
# Reorder the rows, then swizzle the 2x2 blocks.
|
| 50 |
+
group_x = 64
|
| 51 |
+
group_y = 32 if meta_dtype.itemsize == 2 else 16
|
| 52 |
+
|
| 53 |
+
dst_rows = (
|
| 54 |
+
dst_rows // group_x * group_x
|
| 55 |
+
+ (dst_rows % 2) * 2
|
| 56 |
+
+ (dst_rows % 8) // 4
|
| 57 |
+
+ ((dst_rows % group_y) % 4) // 2 * 32
|
| 58 |
+
+ ((dst_rows % group_x) // 8) * 4
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
|
| 62 |
+
bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
|
| 63 |
+
dst_rows += topright - bottomleft
|
| 64 |
+
dst_cols -= topright - bottomleft
|
| 65 |
+
|
| 66 |
+
# Assumed that meta tensor is to be stored in CUTLASS
|
| 67 |
+
# InterleavedColumnMajor layout, and reverse engineered
|
| 68 |
+
# corresponding code to store values into this tensor.
|
| 69 |
+
interleave = 2
|
| 70 |
+
cols_maj = dst_cols // interleave
|
| 71 |
+
cols_min = dst_cols % interleave
|
| 72 |
+
return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# This function converts dense matrix into sparse semi-structured
|
| 76 |
+
# representation, producing "compressed" matrix, in the layout used by
|
| 77 |
+
# CUTLASS backend, and corresponding metadata matrix.
|
| 78 |
+
def sparse_semi_structured_from_dense_cutlass(dense):
    """Compress a 2D semi-structured-sparse matrix into CUTLASS format.

    Returns a tuple ``(sparse, meta_reordered)`` where ``sparse`` is an
    ``(m, k // 2)`` tensor holding the kept values and ``meta_reordered``
    is the metadata matrix, already permuted into the interleaved layout
    the CUTLASS backend expects (see
    ``_calculate_meta_reordering_scatter_offsets``).

    For non-float dtypes the input is processed in groups of four elements
    (2:4 sparsity); for ``torch.float`` in groups of two (1:2 sparsity),
    with each pair's mask duplicated so the 4-wide encoding below applies.

    :raises RuntimeError: if ``dense`` is not 2D, has an unsupported dtype,
        or its row/column counts don't satisfy the divisibility constraints.
    """
    if dense.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor"  # noqa: E501
        )

    m, k = dense.shape
    device = dense.device

    # Metadata element width depends on the value dtype: 4 2-bit indices
    # per int16 meta element, 8 per int32.
    meta_dtype = torch.int8
    if dense.dtype == torch.int8:
        meta_dtype = torch.int32
    elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:
        meta_dtype = torch.int16
    else:
        raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
    if quadbits_per_meta_elem not in (4, 8):
        raise RuntimeError("Invalid number of elements per meta element calculated")

    if meta_dtype == torch.int32:
        if m % 16 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 16"
            )
    else:
        if m % 32 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 32"
            )
    if k % (4 * quadbits_per_meta_elem) != 0:
        raise RuntimeError(
            f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}"  # noqa: E501
        )

    # Build per-position non-zero masks m0..m3 over each 4-element group.
    if dense.dtype != torch.float:
        ksparse = 4
        dense_4 = dense.view(-1, k // ksparse, ksparse)
        m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
    else:
        # Float values are handled pairwise; aliasing m0/m2 and m1/m3 makes
        # each pair's mask play both "halves" of a virtual quadruple.
        ksparse = 2
        dense_2 = dense.view(-1, k // ksparse, ksparse)
        m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
    meta_ncols = k // (ksparse * quadbits_per_meta_elem)

    # Encoding quadruples of True/False values as follows:
    #     [True,  True,  False, False] -> 0b0100
    #     [True,  False, True,  False] -> 0b1000
    #     [False, True,  True,  False] -> 0b1001
    #     [True,  False, False, True ] -> 0b1100
    #     [False, True,  False, True ] -> 0b1101
    #     [False, False, True,  True ] -> 0b1110
    # Thus, lower two bits in the encoding are index of the True value
    # at the lowest index in the quadruple, and the higher two bits in
    # the encoding are index of the other True value in the quadruple.
    # In case there are less than two True values, than False value or
    # values at some index or indices are considered True for the
    # encoding.  In case there are more than two True values, then the
    # excess True value(s) at some indices are considered False for
    # the encoding.  The exact encodings used for these cases are as
    # follows:
    #     [False, False, False, False] -> 0b1110
    #     [False, False, False, True ] -> 0b1110
    #     [False, False, True,  False] -> 0b1110
    #     [False, True,  False, False] -> 0b1001
    #     [False, True,  True,  True ] -> 0b1101
    #     [True,  False, False, False] -> 0b1000
    #     [True,  False, True,  True ] -> 0b1100
    #     [True,  True,  False, True ] -> 0b0100
    #     [True,  True,  True,  False] -> 0b0100
    #     [True,  True,  True,  True ] -> 0b0100
    # These particular encodings are chosen, with the help of Espresso
    # logic minimizer software, for the purpose of minimization of
    # corresponding Boolean functions, that translate non-zero flags
    # into encoding bits.  Note also possible choices for the first
    # and last of these encodings were limited only to (0b0100,
    # 0b1110), in order to produce valid encodings for 1:2 sparsity
    # case.

    expr0 = m0 & m1
    expr1 = ~m0 & m1
    expr2 = ~m0 & ~m1
    bit0 = expr1
    bit1 = expr2
    bit2 = expr0 | expr2 | m3
    bit3 = expr1 | ~m1
    # idxs0/idxs1: 2-bit indices of the two values kept from each quadruple.
    idxs0 = bit0 | (bit1.to(torch.int64) << 1)
    idxs1 = bit2 | (bit3.to(torch.int64) << 1)

    # Gather the kept values into the compressed (m, k // 2) tensor.
    if dense.dtype != torch.float:
        sparse0 = dense_4.gather(
            -1, idxs0.unsqueeze(-1)
        )  # type: ignore[possibly-undefined]
        sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
        sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
    else:
        # For float, idxs0 indexes virtual half-positions; // 2 maps it back
        # to the real pair element.
        sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(
            m, k // 2
        )  # type: ignore[possibly-undefined]

    # Pack the 2-bit index pairs into 4-bit nibbles, then into meta elements.
    meta_4 = idxs0 | (idxs1 << 2)
    meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)

    if quadbits_per_meta_elem == 4:
        meta = (
            meta_n[:, :, 0]
            | (meta_n[:, :, 1] << 4)
            | (meta_n[:, :, 2] << 8)
            | (meta_n[:, :, 3] << 12)
        )
    elif quadbits_per_meta_elem == 8:
        meta = (
            meta_n[:, :, 0]
            | (meta_n[:, :, 1] << 4)
            | (meta_n[:, :, 2] << 8)
            | (meta_n[:, :, 3] << 12)
            | (meta_n[:, :, 4] << 16)
            | (meta_n[:, :, 5] << 20)
            | (meta_n[:, :, 6] << 24)
            | (meta_n[:, :, 7] << 28)
        )

    # Reorder meta tensor elements.
    meta_reordered = meta.new_empty(
        (m * meta_ncols,)
    )  # type: ignore[possibly-undefined]
    meta_offsets = _calculate_meta_reordering_scatter_offsets(
        m, meta_ncols, meta_dtype, device
    )
    meta_reordered.scatter_(0, meta_offsets, meta.view(-1))

    return (sparse, meta_reordered.view(m, meta_ncols))
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
# This function performs reverse of the function above - it
|
| 213 |
+
# reconstructs dense matrix from a pair of "compressed" matrix, given
|
| 214 |
+
# in the layout used by CUTLASS backend, and accompanying metadata
|
| 215 |
+
# matrix.
|
| 216 |
+
def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
    """Reverse of ``sparse_semi_structured_from_dense_cutlass``.

    Reconstructs the dense ``(m, 2 * k)`` matrix from a compressed
    ``(m, k)`` value tensor and its CUTLASS-reordered metadata matrix.
    Pruned positions come back as zeros.

    :raises RuntimeError: if either tensor is not 2D, the tensors live on
        different devices, the metadata dtype is unsupported, or the shapes
        are mutually inconsistent.
    """
    if sparse.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor"  # noqa: E501
        )

    m, k = sparse.shape
    device = sparse.device

    if meta_reordered.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor"  # noqa: E501
        )
    if meta_reordered.device != device:
        raise RuntimeError(
            f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device"  # noqa: E501
        )

    meta_dtype = meta_reordered.dtype
    if meta_dtype not in (torch.int16, torch.int32):
        raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
    # 2-bit indices per meta element: 4 for int16, 8 for int32.
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4

    ksparse = 4 if sparse.dtype != torch.float else 2

    meta_nrows, meta_ncols = meta_reordered.shape
    if meta_nrows != m:
        # NOTE(review): "spase" typo below is preserved from upstream message.
        raise RuntimeError(
            f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}"  # noqa: E501
        )
    if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
        raise RuntimeError(
            f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, "  # noqa: E501
            "expected according to the number of columns of meta matrix"
        )

    # Undo meta tensor elements reordering.
    meta_offsets = _calculate_meta_reordering_scatter_offsets(
        m, meta_ncols, meta_dtype, device
    )
    meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols)

    # Unpack sparse tensor back to original dense tensor, using
    # information provided by meta tensor.  Note that torch.float
    # datatype is handled pretty much the same as
    # torch.half/torch.bfloat16, as metadata for a pair of torch.float
    # value is encoded as if underlying 8 bytes contain four
    # torch.half/torch.bfloat16 values, where either first two or last
    # two are zeros.
    meta_2 = torch.empty(
        (m, meta_ncols, 2 * quadbits_per_meta_elem),
        dtype=meta_dtype,
        device=device,
    )
    # Expand each packed meta element back into its 2-bit index fields.
    if quadbits_per_meta_elem == 4:
        meta_2[:, :, 0] = meta & 0b11
        meta_2[:, :, 1] = (meta >> 2) & 0b11
        meta_2[:, :, 2] = (meta >> 4) & 0b11
        meta_2[:, :, 3] = (meta >> 6) & 0b11
        meta_2[:, :, 4] = (meta >> 8) & 0b11
        meta_2[:, :, 5] = (meta >> 10) & 0b11
        meta_2[:, :, 6] = (meta >> 12) & 0b11
        meta_2[:, :, 7] = (meta >> 14) & 0b11
    elif quadbits_per_meta_elem == 8:
        meta_2[:, :, 0] = meta & 0b11
        meta_2[:, :, 1] = (meta >> 2) & 0b11
        meta_2[:, :, 2] = (meta >> 4) & 0b11
        meta_2[:, :, 3] = (meta >> 6) & 0b11
        meta_2[:, :, 4] = (meta >> 8) & 0b11
        meta_2[:, :, 5] = (meta >> 10) & 0b11
        meta_2[:, :, 6] = (meta >> 12) & 0b11
        meta_2[:, :, 7] = (meta >> 14) & 0b11
        meta_2[:, :, 8] = (meta >> 16) & 0b11
        meta_2[:, :, 9] = (meta >> 18) & 0b11
        meta_2[:, :, 10] = (meta >> 20) & 0b11
        meta_2[:, :, 11] = (meta >> 22) & 0b11
        meta_2[:, :, 12] = (meta >> 24) & 0b11
        meta_2[:, :, 13] = (meta >> 26) & 0b11
        meta_2[:, :, 14] = (meta >> 28) & 0b11
        meta_2[:, :, 15] = (meta >> 30) & 0b11

    # Flat destination index of every kept value: group base (stride 4)
    # plus the 2-bit within-group index decoded above.
    dense_offsets = meta_2.view(-1) + (
        torch.arange(0, 2 * m * k // ksparse, device=device) * 4
    ).view(-1, 1).repeat(1, 2).view(-1)

    dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
    if sparse.dtype != torch.float:
        # dense.scatter_(0, dense_offsets, sparse.view(-1))
        dense.scatter_(0, dense_offsets, sparse.reshape(-1))
    else:
        # Scatter floats through a half-typed view so offsets (computed in
        # half-sized units, per the note above) address the right bytes.
        dense.view(torch.half).scatter_(
            0, dense_offsets, sparse.view(torch.half).view(-1)
        )

    return dense.view(m, 2 * k)
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def mask_creator(tensor, N=2, M=4):
    """
    Create an N:M sparsity mask for the given tensor.

    Masks will be created using the N:M ratio, where for every block of
    M weights, M - N will be pruned based on ranked weight value (the
    smallest-magnitude entries in each block get mask value 0, the N
    largest get 1).  The mask will correspond to the given tensor.

    :param tensor: the weight tensor to build a mask for; its element
        count must be divisible by M
    :param N: The number of weights in a group to keep (default 2)
    :param M: The size of a weight group (default 4)
    :return: a ones/zeros mask tensor with the same shape as ``tensor``
    :raises ValueError: if ``tensor.numel()`` is not divisible by ``M``
    """
    if tensor.numel() % M != 0:
        raise ValueError(
            f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups"
        )

    num_groups = tensor.numel() // M

    # N:M sparsity for linear layers.  Rank by magnitude within each group;
    # detach so mask construction never participates in autograd.
    tensor_temp = tensor.detach().abs().reshape(num_groups, M)
    # Indices of the M - N smallest-magnitude entries per group (to prune).
    index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)]

    # Start from all ones and zero out the pruned positions.
    w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device)
    mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape)

    return mask
|
.venv/lib/python3.11/site-packages/dotenv/__init__.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Optional
|
| 2 |
+
|
| 3 |
+
from .main import (dotenv_values, find_dotenv, get_key, load_dotenv, set_key,
|
| 4 |
+
unset_key)
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def load_ipython_extension(ipython: Any) -> None:
    """Register the dotenv extension on the given IPython instance.

    Delegates to :func:`dotenv.ipython.load_ipython_extension`; the import
    is deferred so IPython is only required when the extension is loaded.
    """
    from .ipython import load_ipython_extension as _load
    _load(ipython)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def get_cli_string(
    path: Optional[str] = None,
    action: Optional[str] = None,
    key: Optional[str] = None,
    value: Optional[str] = None,
    quote: Optional[str] = None,
):
    """Build a ``dotenv`` shell command string from the given arguments.

    Useful for converting arguments passed to a fabric task
    to be passed to a `local` or `run` command.
    """
    parts = ['dotenv']
    if quote:
        parts.append(f'-q {quote}')
    if path:
        parts.append(f'-f {path}')
    if action:
        parts.append(action)
    if key:
        parts.append(key)
    if value:
        # Wrap values containing spaces so they stay a single shell word.
        parts.append(f'"{value}"' if ' ' in value else value)

    return ' '.join(parts).strip()
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
__all__ = ['get_cli_string',
|
| 43 |
+
'load_dotenv',
|
| 44 |
+
'dotenv_values',
|
| 45 |
+
'get_key',
|
| 46 |
+
'set_key',
|
| 47 |
+
'unset_key',
|
| 48 |
+
'find_dotenv',
|
| 49 |
+
'load_ipython_extension']
|
.venv/lib/python3.11/site-packages/dotenv/__main__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Entry point for cli, enables execution with `python -m dotenv`"""
|
| 2 |
+
|
| 3 |
+
from .cli import cli
|
| 4 |
+
|
| 5 |
+
if __name__ == "__main__":
|
| 6 |
+
cli()
|
.venv/lib/python3.11/site-packages/dotenv/__pycache__/__main__.cpython-311.pyc
ADDED
|
Binary file (386 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/dotenv/__pycache__/cli.cpython-311.pyc
ADDED
|
Binary file (10.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/dotenv/__pycache__/ipython.cpython-311.pyc
ADDED
|
Binary file (2.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/dotenv/__pycache__/main.cpython-311.pyc
ADDED
|
Binary file (18.1 kB). View file
|
|
|