koichi12 commited on
Commit
f1e8896
·
verified ·
1 Parent(s): 4ac3d46

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/compressed_tensors/__pycache__/__init__.cpython-311.pyc +0 -0
  2. .venv/lib/python3.11/site-packages/compressed_tensors/__pycache__/base.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/compressed_tensors/__pycache__/version.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/compressed_tensors/compressors/__pycache__/__init__.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/compressed_tensors/compressors/__pycache__/base.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/compressed_tensors/compressors/__pycache__/helpers.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/compressed_tensors/compressors/model_compressors/__init__.py +17 -0
  8. .venv/lib/python3.11/site-packages/compressed_tensors/compressors/model_compressors/__pycache__/__init__.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/compressed_tensors/compressors/model_compressors/__pycache__/model_compressor.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/compressed_tensors/compressors/model_compressors/model_compressor.py +466 -0
  11. .venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_quantized_compressors/__pycache__/__init__.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_quantized_compressors/__pycache__/marlin_24.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/compressed_tensors/config/__init__.py +19 -0
  14. .venv/lib/python3.11/site-packages/compressed_tensors/config/__pycache__/__init__.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/compressed_tensors/config/__pycache__/base.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/compressed_tensors/config/__pycache__/dense.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/compressed_tensors/config/__pycache__/sparse_24_bitmask.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/compressed_tensors/config/__pycache__/sparse_bitmask.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/compressed_tensors/config/base.py +111 -0
  20. .venv/lib/python3.11/site-packages/compressed_tensors/config/dense.py +36 -0
  21. .venv/lib/python3.11/site-packages/compressed_tensors/config/sparse_24_bitmask.py +40 -0
  22. .venv/lib/python3.11/site-packages/compressed_tensors/config/sparse_bitmask.py +36 -0
  23. .venv/lib/python3.11/site-packages/compressed_tensors/linear/__init__.py +13 -0
  24. .venv/lib/python3.11/site-packages/compressed_tensors/linear/__pycache__/__init__.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/compressed_tensors/linear/__pycache__/compressed_linear.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/compressed_tensors/linear/compressed_linear.py +89 -0
  27. .venv/lib/python3.11/site-packages/compressed_tensors/registry/__init__.py +17 -0
  28. .venv/lib/python3.11/site-packages/compressed_tensors/registry/__pycache__/__init__.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/compressed_tensors/registry/__pycache__/registry.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/compressed_tensors/registry/registry.py +360 -0
  31. .venv/lib/python3.11/site-packages/compressed_tensors/utils/__init__.py +21 -0
  32. .venv/lib/python3.11/site-packages/compressed_tensors/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/compressed_tensors/utils/__pycache__/helpers.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/compressed_tensors/utils/__pycache__/offload.cpython-311.pyc +0 -0
  35. .venv/lib/python3.11/site-packages/compressed_tensors/utils/__pycache__/permutations_24.cpython-311.pyc +0 -0
  36. .venv/lib/python3.11/site-packages/compressed_tensors/utils/__pycache__/permute.cpython-311.pyc +0 -0
  37. .venv/lib/python3.11/site-packages/compressed_tensors/utils/__pycache__/safetensors_load.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/compressed_tensors/utils/__pycache__/semi_structured_conversions.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/compressed_tensors/utils/helpers.py +326 -0
  40. .venv/lib/python3.11/site-packages/compressed_tensors/utils/offload.py +404 -0
  41. .venv/lib/python3.11/site-packages/compressed_tensors/utils/permutations_24.py +65 -0
  42. .venv/lib/python3.11/site-packages/compressed_tensors/utils/permute.py +70 -0
  43. .venv/lib/python3.11/site-packages/compressed_tensors/utils/safetensors_load.py +306 -0
  44. .venv/lib/python3.11/site-packages/compressed_tensors/utils/semi_structured_conversions.py +342 -0
  45. .venv/lib/python3.11/site-packages/dotenv/__init__.py +49 -0
  46. .venv/lib/python3.11/site-packages/dotenv/__main__.py +6 -0
  47. .venv/lib/python3.11/site-packages/dotenv/__pycache__/__main__.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/dotenv/__pycache__/cli.cpython-311.pyc +0 -0
  49. .venv/lib/python3.11/site-packages/dotenv/__pycache__/ipython.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/dotenv/__pycache__/main.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/compressed_tensors/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (447 Bytes). View file
 
.venv/lib/python3.11/site-packages/compressed_tensors/__pycache__/base.cpython-311.pyc ADDED
Binary file (487 Bytes). View file
 
.venv/lib/python3.11/site-packages/compressed_tensors/__pycache__/version.cpython-311.pyc ADDED
Binary file (1.26 kB). View file
 
.venv/lib/python3.11/site-packages/compressed_tensors/compressors/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (435 Bytes). View file
 
.venv/lib/python3.11/site-packages/compressed_tensors/compressors/__pycache__/base.cpython-311.pyc ADDED
Binary file (7.69 kB). View file
 
.venv/lib/python3.11/site-packages/compressed_tensors/compressors/__pycache__/helpers.cpython-311.pyc ADDED
Binary file (6.34 kB). View file
 
.venv/lib/python3.11/site-packages/compressed_tensors/compressors/model_compressors/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # flake8: noqa
15
+
16
+
17
+ from .model_compressor import *
.venv/lib/python3.11/site-packages/compressed_tensors/compressors/model_compressors/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (264 Bytes). View file
 
.venv/lib/python3.11/site-packages/compressed_tensors/compressors/model_compressors/__pycache__/model_compressor.cpython-311.pyc ADDED
Binary file (19.5 kB). View file
 
.venv/lib/python3.11/site-packages/compressed_tensors/compressors/model_compressors/model_compressor.py ADDED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import logging
17
+ import operator
18
+ import os
19
+ import re
20
+ from contextlib import contextmanager
21
+ from copy import deepcopy
22
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Set, TypeVar, Union
23
+
24
+ import compressed_tensors
25
+ import torch
26
+ import transformers
27
+ from compressed_tensors.base import (
28
+ COMPRESSION_VERSION_NAME,
29
+ QUANTIZATION_CONFIG_NAME,
30
+ QUANTIZATION_METHOD_NAME,
31
+ SPARSITY_CONFIG_NAME,
32
+ )
33
+ from compressed_tensors.compressors.base import BaseCompressor
34
+ from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
35
+ from compressed_tensors.quantization import (
36
+ DEFAULT_QUANTIZATION_METHOD,
37
+ QuantizationConfig,
38
+ QuantizationStatus,
39
+ apply_quantization_config,
40
+ load_pretrained_quantization,
41
+ )
42
+ from compressed_tensors.quantization.lifecycle import expand_sparse_target_names
43
+ from compressed_tensors.quantization.quant_args import QuantizationArgs
44
+ from compressed_tensors.quantization.utils import (
45
+ is_module_quantized,
46
+ iter_named_leaf_modules,
47
+ )
48
+ from compressed_tensors.utils import get_safetensors_folder, update_parameter_data
49
+ from compressed_tensors.utils.helpers import (
50
+ fix_fsdp_module_name,
51
+ is_compressed_tensors_config,
52
+ )
53
+ from torch import Tensor
54
+ from torch.nn import Module
55
+ from tqdm import tqdm
56
+ from transformers import AutoConfig
57
+ from transformers.file_utils import CONFIG_NAME
58
+
59
+
60
__all__ = ["ModelCompressor", "map_modules_to_quant_args"]

# Module-level logger used for non-fatal diagnostics (e.g. a missing config file
# in update_config).
_LOGGER: logging.Logger = logging.getLogger(__name__)


if TYPE_CHECKING:
    # dummy type if not available from transformers
    # NOTE(review): bound only during static type checking; at runtime this name
    # appears exclusively inside string annotations, so it is never evaluated.
    CompressedTensorsConfig = TypeVar("CompressedTensorsConfig")
68
+
69
+
70
class ModelCompressor:
    """
    Handles compression and decompression of a model with a sparsity config and/or
    quantization config.

    Compression LifeCycle
        - compressor = ModelCompressor.from_pretrained_model(model)
        - compressed_state_dict = compressor.compress(model, state_dict)
            - compressor.quantization_compressor.compress(model, state_dict)
            - compressor.sparsity_compressor.compress(model, state_dict)
        - model.save_pretrained(output_dir, state_dict=compressed_state_dict)
        - compressor.update_config(output_dir)

    Decompression LifeCycle
        - compressor = ModelCompressor.from_pretrained(comp_model_path)
        - model = AutoModel.from_pretrained(comp_model_path)
        - compressor.decompress(comp_model_path, model)
            - compressor.sparsity_compressor.decompress(comp_model_path, model)
            - compressor.quantization_compressor.decompress(comp_model_path, model)

    :param sparsity_config: config specifying sparsity compression parameters
    :param quantization_config: config specifying quantization compression parameters
    """

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        **kwargs,
    ) -> Optional["ModelCompressor"]:
        """
        Given a path to a model config, extract the sparsity and/or quantization
        configs and load a ModelCompressor

        :param pretrained_model_name_or_path: path to model config on disk or HF hub
        :return: compressor for the configs, or None if model is not compressed
        """
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        # compression metadata lives under the "quantization_config" key of the
        # HF model config (QUANTIZATION_CONFIG_NAME)
        compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
        return cls.from_compression_config(compression_config)

    @classmethod
    def from_compression_config(
        cls,
        compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"],
    ):
        """
        :param compression_config:
            A compression or quantization config

            The type is one of the following:
            1. A Dict found under either "quantization_config" or "compression_config"
                keys in the config.json
            2. A CompressedTensorsConfig found under key "quantization_config" in HF
                model config
        :return: compressor for the configs, or None if model is not compressed
        """
        if compression_config is None:
            return None

        sparsity_config = cls.parse_sparsity_config(compression_config)
        quantization_config = cls.parse_quantization_config(compression_config)
        if sparsity_config is None and quantization_config is None:
            return None

        if sparsity_config is not None:
            # the "format" entry selects which registered subclass to instantiate
            format = sparsity_config.get("format")
            sparsity_config = SparsityCompressionConfig.load_from_registry(
                format, **sparsity_config
            )
        if quantization_config is not None:
            quantization_config = QuantizationConfig.model_validate(quantization_config)

        return cls(
            sparsity_config=sparsity_config, quantization_config=quantization_config
        )

    @classmethod
    def from_pretrained_model(
        cls,
        model: Module,
        sparsity_config: Union[SparsityCompressionConfig, str, None] = None,
        quantization_format: Optional[str] = None,
    ) -> Optional["ModelCompressor"]:
        """
        Given a pytorch model and optional sparsity and/or quantization configs,
        load the appropriate compressors

        :param model: pytorch model to target for compression
        :param sparsity_config: a filled in sparsity config or string corresponding
            to a sparsity compression algorithm
        :param quantization_format: string corresponding to a quantization compression
            algorithm
        :return: compressor for the configs, or None if model is not compressed
        """
        quantization_config = QuantizationConfig.from_pretrained(
            model, format=quantization_format
        )

        if isinstance(sparsity_config, str):  # we passed in a sparsity format
            sparsity_config = SparsityCompressionConfig.load_from_registry(
                sparsity_config
            )

        if sparsity_config is None and quantization_config is None:
            return None

        return cls(
            sparsity_config=sparsity_config, quantization_config=quantization_config
        )

    @staticmethod
    def parse_sparsity_config(
        compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
    ) -> Union[Dict[str, Any], None]:
        """
        Parse sparsity config from quantization/compression config. Sparsity
        config is nested inside q/c config

        :param compression_config: quantization/compression config
        :return: sparsity config
        """
        if compression_config is None:
            return None

        if is_compressed_tensors_config(compression_config):
            # transformers CompressedTensorsConfig object: pull the nested
            # pydantic model and convert it back to a plain dict
            s_config = compression_config.sparsity_config
            return s_config.model_dump() if s_config is not None else None

        return compression_config.get(SPARSITY_CONFIG_NAME, None)

    @staticmethod
    def parse_quantization_config(
        compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
    ) -> Union[Dict[str, Any], None]:
        """
        Parse quantization config from quantization/compression config. The
        quantization are all the fields that are not the sparsity config or
        metadata fields

        :param compression_config: quantization/compression config
        :return: quantization config without sparsity config or metadata fields
        """
        if compression_config is None:
            return None

        if is_compressed_tensors_config(compression_config):
            q_config = compression_config.quantization_config
            return q_config.model_dump() if q_config is not None else None

        # work on a copy so the caller's dict is not mutated by the pops below
        quantization_config = deepcopy(compression_config)
        quantization_config.pop(SPARSITY_CONFIG_NAME, None)

        # some fields are required, even if a qconfig is not present
        # pop them off and if nothing remains, then there is no qconfig
        quant_method = quantization_config.pop(QUANTIZATION_METHOD_NAME, None)
        _ = quantization_config.pop(COMPRESSION_VERSION_NAME, None)

        if len(quantization_config) == 0:
            return None

        # replace popped off values
        # note that version is discarded for now
        if quant_method is not None:
            quantization_config[QUANTIZATION_METHOD_NAME] = quant_method

        return quantization_config

    def __init__(
        self,
        sparsity_config: Optional[SparsityCompressionConfig] = None,
        quantization_config: Optional[QuantizationConfig] = None,
    ):
        # configs are kept alongside the compressors they instantiate so that
        # update_config/decompress can serialize them later
        self.sparsity_config = sparsity_config
        self.quantization_config = quantization_config
        self.sparsity_compressor = None
        self.quantization_compressor = None

        if sparsity_config is not None:
            self.sparsity_compressor = BaseCompressor.load_from_registry(
                sparsity_config.format, config=sparsity_config
            )
        if quantization_config is not None:
            self.quantization_compressor = BaseCompressor.load_from_registry(
                quantization_config.format, config=quantization_config
            )

    def compress(
        self, model: Module, state_dict: Optional[Dict[str, Tensor]] = None
    ) -> Dict[str, Tensor]:
        """
        Compresses a dense state dict or model with sparsity and/or quantization

        :param model: uncompressed model to compress
        :param state_dict: optional uncompressed state_dict to insert into model
        :return: compressed state dict
        """
        if state_dict is None:
            state_dict = model.state_dict()

        compressed_state_dict = state_dict

        quantized_modules_to_args: Dict[
            str, QuantizationArgs
        ] = map_modules_to_quant_args(model)

        # quantization compression runs first; sparsity compression (below)
        # then operates on the already-quantized state dict
        if self.quantization_compressor is not None:
            compressed_state_dict = self.quantization_compressor.compress(
                state_dict, names_to_scheme=quantized_modules_to_args
            )
            if self.quantization_config.format != CompressionFormat.dense.value:
                self.quantization_config.quantization_status = (
                    QuantizationStatus.COMPRESSED
                )

        if self.sparsity_compressor is not None:
            sparse_compression_targets: Set[str] = expand_sparse_target_names(
                model=model,
                targets=self.sparsity_config.targets,
                ignore=self.sparsity_config.ignore,
            )
            compressed_state_dict = self.sparsity_compressor.compress(
                compressed_state_dict,
                compression_targets=sparse_compression_targets,
            )

        # HACK: Override the dtype_byte_size function in transformers to
        # support float8 types. Fix is posted upstream
        # https://github.com/huggingface/transformers/pull/30488
        transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size

        return compressed_state_dict

    def decompress(self, model_path: str, model: Module):
        """
        Overwrites the weights in model with weights decompressed from model_path

        :param model_path: path to compressed weights
        :param model: pytorch model to load decompressed weights into
        """
        model_path = get_safetensors_folder(model_path)
        sparse_decompressed = False

        if (
            self.sparsity_compressor is not None
            and self.sparsity_config.format != CompressionFormat.dense.value
        ):
            # Sparse decompression is applied on the model_path
            dense_gen = self.sparsity_compressor.decompress(model_path)
            self._replace_weights(dense_gen, model)
            setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
            sparse_decompressed = True

        if self.quantization_compressor is not None:
            # Temporarily set quantization status to FROZEN to prevent
            # quantization during apply_quantization_config. This ensures
            # that the dtypes of the weights are not unintentionally updated.
            # The status is restored after quantization params are loaded.
            with override_quantization_status(
                self.quantization_config, QuantizationStatus.FROZEN
            ):
                names_to_scheme = apply_quantization_config(
                    model, self.quantization_config
                )
                load_pretrained_quantization(model, model_path)

            # if sparsity was already decompressed above, the dense weights now
            # live in the model, so decompress from its state dict instead of disk
            model_path_or_state_dict = (
                model.state_dict() if sparse_decompressed else model_path
            )

            dense_gen = self.quantization_compressor.decompress(
                model_path_or_state_dict, names_to_scheme=names_to_scheme
            )
            self._replace_weights(dense_gen, model)

            def freeze_quantization_status(module):
                module.quantization_status = QuantizationStatus.FROZEN

            model.apply(freeze_quantization_status)
            setattr(model, QUANTIZATION_CONFIG_NAME, self.quantization_config)

    def update_config(self, save_directory: str):
        """
        Update the model config located at save_directory with compression configs
        for sparsity and/or quantization

        :param save_directory: path to a folder containing a HF model config
        """
        if self.quantization_config is None and self.sparsity_config is None:
            return

        config_file_path = os.path.join(save_directory, CONFIG_NAME)
        if not os.path.exists(config_file_path):
            _LOGGER.warning(
                f"Could not find a valid model config file in "
                f"{save_directory}. Compression config will not be saved."
            )
            return

        with open(config_file_path, "r") as config_file:
            config_data = json.load(config_file)

        # required metadata whenever a quantization or sparsity config is present
        # overwrite previous config and version if already existing
        config_data[QUANTIZATION_CONFIG_NAME] = {}
        config_data[QUANTIZATION_CONFIG_NAME][
            COMPRESSION_VERSION_NAME
        ] = compressed_tensors.__version__
        if self.quantization_config is not None:
            self.quantization_config.quant_method = DEFAULT_QUANTIZATION_METHOD
        else:
            config_data[QUANTIZATION_CONFIG_NAME][
                QUANTIZATION_METHOD_NAME
            ] = DEFAULT_QUANTIZATION_METHOD

        # quantization and sparsity configs
        # NOTE(review): when a quantization config is present, the assignment
        # below replaces the dict written above, apparently dropping the
        # COMPRESSION_VERSION_NAME entry unless QuantizationConfig.model_dump()
        # re-emits it — confirm against QuantizationConfig's fields.
        if self.quantization_config is not None:
            quant_config_data = self.quantization_config.model_dump()
            config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data
        if self.sparsity_config is not None:
            sparsity_config_data = self.sparsity_config.model_dump()
            config_data[QUANTIZATION_CONFIG_NAME][
                SPARSITY_CONFIG_NAME
            ] = sparsity_config_data

        with open(config_file_path, "w") as config_file:
            json.dump(config_data, config_file, indent=2, sort_keys=True)

    def _replace_weights(self, dense_weight_generator, model: Module):
        """
        Replace the weights of the model with the
        provided dense weights.

        This method iterates over the dense_weight_generator and
        updates the corresponding weights in the model. If a parameter
        name does not exist in the model, it will be skipped.

        :param dense_weight_generator (generator): A generator that yields
            tuples of (name, data), where 'name' is the parameter name and
            'data' is the updated param data
        :param model: The model whose weights are to be updated.
        """
        for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
            # split "prefix.param" into the owning module path and the leaf name
            split_name = name.split(".")
            prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
            module = operator.attrgetter(prefix)(model)
            if hasattr(module, param_name):
                update_parameter_data(module, data, param_name)
419
+
420
def map_modules_to_quant_args(model: Module) -> Dict[str, QuantizationArgs]:
    """
    Given a pytorch model, map out the submodule name (usually linear layers)
    to the QuantizationArgs

    :param model: pytorch model
    """
    mapping: Dict[str, QuantizationArgs] = {}
    for module_name, module in iter_named_leaf_modules(model):
        # skip modules with no quantization scheme attached
        if not is_module_quantized(module):
            continue
        weight_args = module.quantization_scheme.weights
        # only weight-quantized modules participate in compression
        if weight_args is None:
            continue
        mapping[fix_fsdp_module_name(module_name)] = weight_args

    return mapping
435
+
436
+
437
# HACK: Override the dtype_byte_size function in transformers to support float8 types
# Fix is posted upstream https://github.com/huggingface/transformers/pull/30488
def new_dtype_byte_size(dtype):
    """Return the storage size, in bytes, of one element of ``dtype``.

    Mirrors ``transformers.modeling_utils.dtype_byte_size`` but also handles
    float8 dtypes, whose names (e.g. ``torch.float8_e4m3fn``) embed the bit
    width before an underscore.

    :param dtype: a torch dtype
    :raises ValueError: if no bit width can be parsed from ``str(dtype)``
    """
    if dtype == torch.bool:
        # bools are accounted for as bit-packed (1/8 byte each)
        return 1 / 8
    match = re.search(r"[^\d](\d+)_?", str(dtype))
    if match is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    return int(match.group(1)) // 8
447
+
448
+
449
@contextmanager
def override_quantization_status(
    config: QuantizationConfig, status: QuantizationStatus
):
    """
    Temporarily force ``config.quantization_status`` to ``status``.

    On exit — whether the body completes or raises — the status that was in
    place when the context was entered is restored.

    :param config: the quantization config to override
    :param status: the status to temporarily set
    """
    saved_status = config.quantization_status
    config.quantization_status = status
    try:
        yield
    finally:
        # always restore, even if the body raised
        config.quantization_status = saved_status
.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_quantized_compressors/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (300 Bytes). View file
 
.venv/lib/python3.11/site-packages/compressed_tensors/compressors/sparse_quantized_compressors/__pycache__/marlin_24.cpython-311.pyc ADDED
Binary file (11.7 kB). View file
 
.venv/lib/python3.11/site-packages/compressed_tensors/config/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # flake8: noqa
16
+ from .base import *
17
+ from .dense import *
18
+ from .sparse_24_bitmask import *
19
+ from .sparse_bitmask import *
.venv/lib/python3.11/site-packages/compressed_tensors/config/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (331 Bytes). View file
 
.venv/lib/python3.11/site-packages/compressed_tensors/config/__pycache__/base.cpython-311.pyc ADDED
Binary file (4.34 kB). View file
 
.venv/lib/python3.11/site-packages/compressed_tensors/config/__pycache__/dense.cpython-311.pyc ADDED
Binary file (1.36 kB). View file
 
.venv/lib/python3.11/site-packages/compressed_tensors/config/__pycache__/sparse_24_bitmask.cpython-311.pyc ADDED
Binary file (1.48 kB). View file
 
.venv/lib/python3.11/site-packages/compressed_tensors/config/__pycache__/sparse_bitmask.cpython-311.pyc ADDED
Binary file (1.35 kB). View file
 
.venv/lib/python3.11/site-packages/compressed_tensors/config/base.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from enum import Enum, unique
16
+ from typing import List, Optional
17
+
18
+ from compressed_tensors.registry import RegistryMixin
19
+ from pydantic import BaseModel
20
+
21
+
22
+ __all__ = ["SparsityCompressionConfig", "CompressionFormat", "SparsityStructure"]
23
+
24
+
25
@unique
class CompressionFormat(Enum):
    # Identifiers for the serialized (on-disk) representation of model weights.
    # The string values are what appear in the "format" field of compression
    # configs and are used as registry keys for compressor lookup.
    dense = "dense"  # uncompressed passthrough
    sparse_bitmask = "sparse-bitmask"
    sparse_24_bitmask = "sparse-24-bitmask"
    int_quantized = "int-quantized"
    float_quantized = "float-quantized"
    naive_quantized = "naive-quantized"
    pack_quantized = "pack-quantized"
    marlin_24 = "marlin-24"
35
+
36
+
37
@unique
class SparsityStructure(Enum):
    """
    An enumeration to represent different sparsity structures.

    Lookup is case-insensitive, and ``None`` maps to ``UNSTRUCTURED``.
    Any other unrecognized value — including non-string values — raises
    ``ValueError``.

    Attributes
    ----------
    TWO_FOUR : str
        Represents a 2:4 sparsity structure.
    ZERO_ZERO : str
        Represents a 0:0 sparsity structure.
    UNSTRUCTURED : str
        Represents an unstructured sparsity structure.

    Examples
    --------
    >>> SparsityStructure('2:4')
    <SparsityStructure.TWO_FOUR: '2:4'>

    >>> SparsityStructure('unstructured')
    <SparsityStructure.UNSTRUCTURED: 'unstructured'>

    >>> SparsityStructure('2:4') == SparsityStructure.TWO_FOUR
    True

    >>> SparsityStructure('UNSTRUCTURED') == SparsityStructure.UNSTRUCTURED
    True

    >>> SparsityStructure(None) == SparsityStructure.UNSTRUCTURED
    True

    >>> SparsityStructure('invalid')
    Traceback (most recent call last):
        ...
    ValueError: invalid is not a valid SparsityStructure
    """

    TWO_FOUR = "2:4"
    UNSTRUCTURED = "unstructured"
    ZERO_ZERO = "0:0"

    def __new__(cls, value):
        # normalize member values to lowercase so case-insensitive lookup
        # in _missing_ can compare directly
        obj = object.__new__(cls)
        obj._value_ = value.lower() if value is not None else value
        return obj

    @classmethod
    def _missing_(cls, value):
        """Handle ``None`` and case-insensitive string lookups.

        :raises ValueError: for any value that is not ``None`` and does not
            case-insensitively match a member value
        """
        if value is None:
            return cls.UNSTRUCTURED
        # FIX: guard with isinstance so non-string values (e.g. ints) raise
        # the documented ValueError instead of AttributeError on .lower()
        if isinstance(value, str):
            for member in cls:
                if member.value == value.lower():
                    return member
        raise ValueError(f"{value} is not a valid {cls.__name__}")
92
+
93
+
94
class SparsityCompressionConfig(RegistryMixin, BaseModel):
    """
    Base data class for storing sparsity compression parameters

    :param format: name of compression format
    :param targets: List of layer names or layer types that aren't sparse and should
        be ignored during compression. By default, assume all layers are targeted
    :param ignore: List of layer names (unique) to ignore from targets. Defaults to None
    :param global_sparsity: average sparsity of the entire model
    :param sparsity_structure: structure of the sparsity, such as
        "unstructured", "2:4", "8:16" etc
    """

    # pydantic model fields; "format" doubles as the registry key used by
    # load_from_registry to pick a concrete config subclass
    format: str
    targets: Optional[List[str]] = None
    ignore: Optional[List[str]] = None
    global_sparsity: Optional[float] = 0.0
    sparsity_structure: Optional[str] = "unstructured"
.venv/lib/python3.11/site-packages/compressed_tensors/config/dense.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional
16
+
17
+ from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
18
+
19
+
20
+ __all__ = ["DenseSparsityConfig"]
21
+
22
+
23
@SparsityCompressionConfig.register(name=CompressionFormat.dense.value)
class DenseSparsityConfig(SparsityCompressionConfig):
    """
    Identity configuration for storing a sparse model in
    an uncompressed dense format (i.e. no compression is applied)

    :param global_sparsity: average sparsity of the entire model
    :param sparsity_structure: structure of the sparsity, such as
        "unstructured", "2:4", "8:16" etc
    """

    # format is pinned to "dense" so registry round-trips resolve this class
    format: str = CompressionFormat.dense.value
    global_sparsity: Optional[float] = 0.0
    sparsity_structure: Optional[str] = "unstructured"
.venv/lib/python3.11/site-packages/compressed_tensors/config/sparse_24_bitmask.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional
16
+
17
+ from compressed_tensors.config import (
18
+ CompressionFormat,
19
+ SparsityCompressionConfig,
20
+ SparsityStructure,
21
+ )
22
+
23
+
24
+ __all__ = ["Sparse24BitMaskConfig"]
25
+
26
+
27
@SparsityCompressionConfig.register(name=CompressionFormat.sparse_24_bitmask.value)
class Sparse24BitMaskConfig(SparsityCompressionConfig):
    """
    Configuration for storing a 2:4 sparse model using
    bitmask compression

    :param global_sparsity: average sparsity of the entire model
    :param sparsity_structure: structure of the sparsity, should always be
        "2:4" for this compression format
    """

    # format is pinned so registry round-trips resolve this class
    format: str = CompressionFormat.sparse_24_bitmask.value
    global_sparsity: Optional[float] = 0.0
    sparsity_structure: Optional[str] = SparsityStructure.TWO_FOUR.value
.venv/lib/python3.11/site-packages/compressed_tensors/config/sparse_bitmask.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional
16
+
17
+ from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
18
+
19
+
20
+ __all__ = ["BitmaskConfig"]
21
+
22
+
23
@SparsityCompressionConfig.register(name=CompressionFormat.sparse_bitmask.value)
class BitmaskConfig(SparsityCompressionConfig):
    """
    Configuration for storing a sparse model using
    bitmask compression

    :param global_sparsity: average sparsity of the entire model
    :param sparsity_structure: structure of the sparsity, such as
        "unstructured", "2:4", "8:16" etc
    """

    # format is pinned so registry round-trips resolve this class
    format: str = CompressionFormat.sparse_bitmask.value
    global_sparsity: Optional[float] = 0.0
    sparsity_structure: Optional[str] = "unstructured"
.venv/lib/python3.11/site-packages/compressed_tensors/linear/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
.venv/lib/python3.11/site-packages/compressed_tensors/linear/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (198 Bytes). View file
 
.venv/lib/python3.11/site-packages/compressed_tensors/linear/__pycache__/compressed_linear.cpython-311.pyc ADDED
Binary file (3.63 kB). View file
 
.venv/lib/python3.11/site-packages/compressed_tensors/linear/compressed_linear.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, Tuple
16
+
17
+ import torch
18
+ from compressed_tensors.compressors.base import BaseCompressor
19
+ from compressed_tensors.quantization import (
20
+ QuantizationScheme,
21
+ QuantizationStatus,
22
+ initialize_module_for_quantization,
23
+ )
24
+ from torch import Tensor
25
+ from torch.nn import Parameter
26
+ from torch.nn.functional import linear
27
+ from torch.nn.modules import Linear
28
+
29
+
30
class CompressedLinear(Linear):
    """
    Wrapper module for running a compressed forward pass of a quantized Linear
    module. The wrapped layer is decompressed on each forward call.

    :param module: dense linear module to replace
    :param quantization_scheme: quantization config for the module to wrap
    :param quantization_format: compression format module is stored as
    """

    @classmethod
    @torch.no_grad()
    def from_linear(
        cls,
        module: Linear,
        quantization_scheme: QuantizationScheme,
        quantization_format: str,
    ):
        """
        Convert a dense Linear module into a CompressedLinear in place.

        The conversion reassigns ``__class__`` rather than constructing a new
        object, so existing references to the module (e.g. from the parent
        model) keep pointing at the converted instance.
        """
        module.__class__ = CompressedLinear
        module.compressor = BaseCompressor.load_from_registry(quantization_format)
        device = next(module.parameters()).device

        # this will initialize all the scales and zero points
        initialize_module_for_quantization(
            module, quantization_scheme, force_zero_point=False
        )

        # get the shape and dtype of compressed parameters
        compression_params: Dict[str, Tuple] = module.compressor.compression_param_info(
            module.weight.shape, quantization_scheme.weights
        )

        # no need for this once quantization is initialized, will be replaced
        # with the compressed parameter
        delattr(module, "weight")

        # populate compressed weights and quantization parameters
        # (torch.empty leaves values uninitialized; they are expected to be
        # loaded from a compressed checkpoint afterwards)
        for name, (shape, dtype) in compression_params.items():
            param = Parameter(
                torch.empty(shape, device=device, dtype=dtype), requires_grad=False
            )
            module.register_parameter(name, param)

        # mark module as compressed
        module.quantization_status = QuantizationStatus.COMPRESSED

        # handles case where forward is wrapped in new_forward by accelerate hooks;
        # rebind the hook's saved forward to the new CompressedLinear.forward
        if hasattr(module, "_old_forward"):
            module._old_forward = CompressedLinear.forward.__get__(
                module, CompressedLinear
            )

        return module

    def forward(self, input: Tensor) -> Tensor:
        """
        Decompresses the weight, then runs the wrapped forward pass
        """
        uncompressed_weight = self.compressor.decompress_module(self)
        return linear(input, uncompressed_weight, self.bias)
.venv/lib/python3.11/site-packages/compressed_tensors/registry/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+
3
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing,
12
+ # software distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from .registry import *
.venv/lib/python3.11/site-packages/compressed_tensors/registry/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (235 Bytes). View file
 
.venv/lib/python3.11/site-packages/compressed_tensors/registry/__pycache__/registry.cpython-311.pyc ADDED
Binary file (14.4 kB). View file
 
.venv/lib/python3.11/site-packages/compressed_tensors/registry/registry.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Universal registry to support registration and loading of child classes and plugins
17
+ of neuralmagic utilities
18
+ """
19
+
20
+ import importlib
21
+ from collections import defaultdict
22
+ from typing import Any, Dict, List, Optional, Type, Union
23
+
24
+
25
+ __all__ = [
26
+ "RegistryMixin",
27
+ "register",
28
+ "get_from_registry",
29
+ "registered_names",
30
+ "registered_aliases",
31
+ "standardize_lookup_name",
32
+ ]
33
+
34
+
35
+ _ALIAS_REGISTRY: Dict[Type, Dict[str, str]] = defaultdict(dict)
36
+ _REGISTRY: Dict[Type, Dict[str, Any]] = defaultdict(dict)
37
+
38
+
39
def standardize_lookup_name(name: str) -> str:
    """
    Normalize ``name`` for lookup in the registry.

    Underscores and spaces become hyphens and the result is lowercased, so
    "Foo_bar baz", "foo bar-baz" and "foo-bar-baz" all resolve identically.

    example:
    ```
    standardize_lookup_name("Foo_bar baz") == "foo-bar-baz"
    ```

    :param name: name to standardize
    :return: standardized name
    """
    for separator in ("_", " "):
        name = name.replace(separator, "-")
    return name.lower()
54
+
55
+
56
def standardize_alias_name(
    name: Union[None, str, List[str]]
) -> Union[None, str, List[str]]:
    """Standardize a single alias or a list of aliases; None passes through."""
    if name is None:
        return None
    if isinstance(name, str):
        return standardize_lookup_name(name)
    # otherwise a list of aliases: standardize each entry
    return [standardize_lookup_name(entry) for entry in name]
65
+
66
+
67
class RegistryMixin:
    """
    Universal registry to support registration and loading of child classes and
    plugins of neuralmagic utilities.

    Classes that need a registry or plugin mechanism inherit from this mixin,
    use :meth:`register` to add implementations, and
    :meth:`load_from_registry` / :meth:`get_value_from_registry` to retrieve
    them.

    Set the static attribute ``registry_requires_subclass = True`` on a class
    if only its own subclasses should be allowed in its registry.

    example
    ```python
    class Dataset(RegistryMixin):
        pass


    # register with default name
    @Dataset.register()
    class ImageNetDataset(Dataset):
        pass

    # load as "ImageNetDataset"
    imagenet = Dataset.load("ImageNetDataset")

    # register with custom name
    @Dataset.register(name="cifar-dataset")
    class Cifar(Dataset):
        pass

    Note: the name will be standardized for lookup in the registry.
    For example, if a class is registered as "cifar_dataset" or
    "cifar dataset", it will be stored as "cifar-dataset". The user
    will be able to load the class with any of the three name variants.

    # register with multiple aliases
    @Dataset.register(alias=["cifar-10-dataset", "cifar_100_dataset"])
    class Cifar(Dataset):
        pass

    # load as "cifar-dataset"
    cifar = Dataset.load_from_registry("cifar-dataset")

    # load from custom file that implements a dataset
    mnist = Dataset.load_from_registry("/path/to/mnnist_dataset.py:MnistDataset")
    ```
    """

    # when True, registration/retrieval validates that the value is a
    # subclass of the class it is registered to
    registry_requires_subclass: bool = False

    @classmethod
    def register(
        cls, name: Optional[str] = None, alias: Union[List[str], str, None] = None
    ):
        """
        Decorator that registers the wrapped value (class or function) under
        the registry of the class ``.register`` is called on.

        :param name: name to register the wrapped value as,
            defaults to value.__name__
        :param alias: alias or list of aliases to also register the wrapped
            value as, defaults to None
        :return: decorator that registers and returns its argument unchanged
        """

        def wrapper(value: Any):
            cls.register_value(value, name=name, alias=alias)
            return value

        return wrapper

    @classmethod
    def register_value(
        cls, value: Any, name: str, alias: Union[str, List[str], None] = None
    ):
        """
        Registers ``value`` under the registry of this class.

        :param value: value to register
        :param name: name to register the value as, defaults to value.__name__
        :param alias: alias or list of aliases to also register the value as,
            defaults to None
        """
        register(
            parent_class=cls,
            value=value,
            name=name,
            alias=alias,
            require_subclass=cls.registry_requires_subclass,
        )

    @classmethod
    def load_from_registry(cls, name: str, **constructor_kwargs) -> object:
        """
        Look up ``name`` in the registry and construct the registered value.

        :param name: name of registered class to load
        :param constructor_kwargs: arguments forwarded to the constructor
            retrieved from the registry
        :return: object constructed from the registered value under the given
            name; raises if the name is not found in the registry
        """
        constructor = cls.get_value_from_registry(name=name)
        return constructor(**constructor_kwargs)

    @classmethod
    def get_value_from_registry(cls, name: str):
        """
        :param name: name to retrieve from the registry
        :return: value registered under the given name; raises if not found
        """
        return get_from_registry(
            parent_class=cls,
            name=name,
            require_subclass=cls.registry_requires_subclass,
        )

    @classmethod
    def registered_names(cls) -> List[str]:
        """
        :return: list of all names registered to this class
        """
        return registered_names(cls)

    @classmethod
    def registered_aliases(cls) -> List[str]:
        """
        :return: list of all aliases registered to this class
        """
        return registered_aliases(cls)
200
+
201
+
202
def register(
    parent_class: Type,
    value: Any,
    name: Optional[str] = None,
    alias: Union[List[str], str, None] = None,
    require_subclass: bool = False,
):
    """
    Register ``value`` in the registry of ``parent_class`` under ``name``
    (standardized) plus any aliases.

    :param parent_class: class to register the name under
    :param value: the value to register
    :param name: name to register the wrapped value as, defaults to
        value.__name__
    :param alias: alias or list of aliases to register the wrapped value as,
        defaults to None
    :param require_subclass: require that value is a subclass of the class this
        method is called from
    """
    if name is None:
        # fall back to the value's own name
        name = value.__name__

    name = standardize_lookup_name(name)
    alias = standardize_alias_name(alias)
    register_alias(name=name, alias=alias, parent_class=parent_class)

    if require_subclass:
        _validate_subclass(parent_class, value)

    if name not in _REGISTRY[parent_class]:
        _REGISTRY[parent_class][name] = value
        return

    # name already present: re-registering the identical value is a no-op,
    # but two different values may not share one name
    registered_value = _REGISTRY[parent_class][name]
    if registered_value is not value:
        raise RuntimeError(
            f"Attempting to register name {name} as {value} "
            f"however {name} has already been registered as {registered_value}"
        )
240
+
241
+
242
def get_from_registry(
    parent_class: Type, name: str, require_subclass: bool = False
) -> Any:
    """
    :param parent_class: class that the name is registered under
    :param name: name to retrieve from the registry of the class; may also be
        a "path/to/module.py:AttributeName" string to import a value from an
        arbitrary file instead of the registry
    :param require_subclass: require that value is a subclass of the class this
        method is called from
    :return: value from retrieved the registry for the given name, raises
        error if not found
    """
    # NOTE(review): standardization lowercases and hyphenates the full string
    # BEFORE the ":" split below, so file-based lookups also get their module
    # path and attribute name rewritten — confirm this is intended for
    # case-sensitive filesystems and attribute names.
    name = standardize_lookup_name(name)

    if ":" in name:
        # user specifying specific module to load and value to import
        module_path, value_name = name.split(":")
        retrieved_value = _import_and_get_value_from_module(module_path, value_name)
    else:
        # look up name in alias registry (maps alias -> canonical name)
        name = _ALIAS_REGISTRY[parent_class].get(name, name)
        # look up name in registry
        retrieved_value = _REGISTRY[parent_class].get(name)
        if retrieved_value is None:
            raise KeyError(
                f"Unable to find {name} registered under type {parent_class}.\n"
                f"Registered values for {parent_class}: "
                f"{registered_names(parent_class)}\n"
                f"Registered aliases for {parent_class}: "
                f"{registered_aliases(parent_class)}"
            )

    if require_subclass:
        _validate_subclass(parent_class, retrieved_value)

    return retrieved_value
277
+
278
+
279
def registered_names(parent_class: Type) -> List[str]:
    """
    :param parent_class: class whose registry to inspect
    :return: all names registered to the given class
    """
    return [registered for registered in _REGISTRY[parent_class]]
285
+
286
+
287
def registered_aliases(parent_class: Type) -> List[str]:
    """
    :param parent_class: class whose registry to inspect
    :return: all aliases registered to the given class (canonical names are
        excluded, even though they are stored in the alias map as well)
    """
    names = set(registered_names(parent_class))
    aliases_and_names = set(_ALIAS_REGISTRY[parent_class].keys())
    return list(aliases_and_names - names)
297
+
298
+
299
def register_alias(
    name: str, parent_class: Type, alias: Union[str, List[str], None] = None
):
    """
    Updates the mapping from the alias(es) to the given name.
    If the alias is None, the name is used as the alias.

    :param name: name that the alias refers to
    :param parent_class: class that the name is registered under
    :param alias: single alias or list of aliases that
        refer to the name, defaults to None
    """
    if alias is not None:
        alias = alias if isinstance(alias, list) else [alias]
    else:
        alias = []

    if name in alias:
        raise KeyError(
            f"Attempting to register alias {name}, "
            f"that is identical to the standardized name: {name}."
        )
    # the canonical name is stored as an alias of itself so lookups resolve
    # uniformly through the alias map
    alias.append(name)

    for alias_name in alias:
        if alias_name in _ALIAS_REGISTRY[parent_class]:
            # BUG FIX: previously indexed _ALIAS_REGISTRY[alias_name] (keyed by
            # the alias string instead of the parent class), which produced a
            # wrong message and, since _ALIAS_REGISTRY is a defaultdict,
            # silently inserted a bogus entry keyed by the alias string
            raise KeyError(
                f"Attempting to register alias {alias_name} as {name} "
                f"however {alias_name} has already been registered as "
                f"{_ALIAS_REGISTRY[parent_class][alias_name]}"
            )
        _ALIAS_REGISTRY[parent_class][alias_name] = name
332
+
333
+
334
+ def _import_and_get_value_from_module(module_path: str, value_name: str) -> Any:
335
+ # import the given module path and try to get the value_name if it is included
336
+ # in the module
337
+
338
+ # load module
339
+ spec = importlib.util.spec_from_file_location(
340
+ f"plugin_module_for_{value_name}", module_path
341
+ )
342
+ module = importlib.util.module_from_spec(spec)
343
+ spec.loader.exec_module(module)
344
+
345
+ # get value from module
346
+ value = getattr(module, value_name, None)
347
+
348
+ if not value:
349
+ raise RuntimeError(
350
+ f"Unable to find attribute {value_name} in module {module_path}"
351
+ )
352
+ return value
353
+
354
+
355
+ def _validate_subclass(parent_class: Type, child_class: Type):
356
+ if not issubclass(child_class, parent_class):
357
+ raise ValueError(
358
+ f"class {child_class} is not a subclass of the class it is "
359
+ f"registered for: {parent_class}."
360
+ )
.venv/lib/python3.11/site-packages/compressed_tensors/utils/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # flake8: noqa
15
+
16
+ from .helpers import *
17
+ from .offload import *
18
+ from .permutations_24 import *
19
+ from .permute import *
20
+ from .safetensors_load import *
21
+ from .semi_structured_conversions import *
.venv/lib/python3.11/site-packages/compressed_tensors/utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (413 Bytes). View file
 
.venv/lib/python3.11/site-packages/compressed_tensors/utils/__pycache__/helpers.cpython-311.pyc ADDED
Binary file (13.8 kB). View file
 
.venv/lib/python3.11/site-packages/compressed_tensors/utils/__pycache__/offload.cpython-311.pyc ADDED
Binary file (17.5 kB). View file
 
.venv/lib/python3.11/site-packages/compressed_tensors/utils/__pycache__/permutations_24.cpython-311.pyc ADDED
Binary file (3.01 kB). View file
 
.venv/lib/python3.11/site-packages/compressed_tensors/utils/__pycache__/permute.cpython-311.pyc ADDED
Binary file (2.58 kB). View file
 
.venv/lib/python3.11/site-packages/compressed_tensors/utils/__pycache__/safetensors_load.cpython-311.pyc ADDED
Binary file (12.5 kB). View file
 
.venv/lib/python3.11/site-packages/compressed_tensors/utils/__pycache__/semi_structured_conversions.cpython-311.pyc ADDED
Binary file (13.1 kB). View file
 
.venv/lib/python3.11/site-packages/compressed_tensors/utils/helpers.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import warnings
16
+ from functools import wraps
17
+ from typing import Any, Callable, Dict, List, Optional
18
+
19
+ import numpy
20
+ import torch
21
+ from transformers import AutoConfig
22
+
23
+
24
+ __all__ = [
25
+ "infer_compressor_from_model_config",
26
+ "fix_fsdp_module_name",
27
+ "tensor_follows_mask_structure",
28
+ "replace_module",
29
+ "is_compressed_tensors_config",
30
+ "getattr_chain",
31
+ "deprecated",
32
+ "Aliasable",
33
+ "combine_shards",
34
+ "shard_tensor",
35
+ "pack_bitmasks",
36
+ "unpack_bitmasks",
37
+ ]
38
+
39
+ FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
40
+
41
+
42
def infer_compressor_from_model_config(
    pretrained_model_name_or_path: str,
) -> Optional["ModelCompressor"]:  # noqa: F821
    """
    Given a path to a model config, extract a sparsity config if it exists and return
    the associated ModelCompressor

    :param pretrained_model_name_or_path: path to model config on disk or HF hub
    :return: matching compressor if config contains a sparsity config,
        otherwise None
    """
    # imported lazily to avoid a circular import at module load time
    from compressed_tensors.compressors import ModelCompressor
    from compressed_tensors.config import CompressionConfig

    config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
    sparsity_config = ModelCompressor.parse_sparsity_config(config)
    if sparsity_config is None:
        return None

    # "format" selects which registered config and compressor classes to build
    # NOTE(review): this path relies on `CompressionConfig` existing in
    # compressed_tensors.config — confirm the name is still exported there
    format = sparsity_config.get("format")
    sparsity_config = CompressionConfig.load_from_registry(format, **sparsity_config)
    compressor = ModelCompressor.load_from_registry(format, config=sparsity_config)
    return compressor
64
+
65
+
66
+ # TODO: There is already the same function in
67
+ # SparseML, should be moved to a shared location
68
+ # in the future
69
def fix_fsdp_module_name(name: str) -> str:
    """
    Remove FSDP wrapper prefixes from a module name.

    Handles the wrapper token appearing mid-name
    ("a._fsdp_wrapped_module.b" -> "a.b") as well as at the end
    ("a._fsdp_wrapped_module" -> "a").

    :param name: name to strip
    :return: stripped name
    """
    # strip "<token>." occurrences first, then any remaining ".<token>"
    stripped = name.replace(FSDP_WRAPPER_NAME + ".", "")
    return stripped.replace("." + FSDP_WRAPPER_NAME, "")
80
+
81
+
82
def tensor_follows_mask_structure(tensor, mask: str = "2:4") -> bool:
    """
    Check that every chunk of ``m`` consecutive elements contains at least
    ``n`` zeros, where ``mask`` is the string "n:m".

    Note, some weights can incidentally be zero, so we check for
    at least n zeros in each chunk of size m rather than exactly n.

    :param tensor: tensor to check
    :param mask: mask structure to check for, in the format "n:m"
    :return: True if the tensor follows the mask structure
    :raises ValueError: if the tensor does not follow the mask structure
        (callers rely on the exception rather than a False return)
    """
    n, m = tuple(map(int, mask.split(":")))
    # Reshape the tensor into chunks of size m
    tensor = tensor.view(-1, m)

    # Count the number of zeros in each chunk
    zero_counts = (tensor == 0).sum(dim=1)

    # Check if the number of zeros in each chunk is at least n; >= is needed
    # since some weights can incidentally be zero
    if not torch.all(zero_counts >= n).item():
        # previously raised a bare ValueError() with no message
        raise ValueError(
            f"Tensor does not follow the {mask} sparsity mask structure"
        )

    return True
105
+
106
+
107
def replace_module(model: torch.nn.Module, name: str, new_module: torch.nn.Module):
    """
    Replace the submodule at dotted path ``name`` inside ``model`` with
    ``new_module``.

    :param model: root module to modify
    :param name: dotted path of the submodule to replace (e.g. "layers.0.mlp")
    :param new_module: module to install at that path
    """
    if "." in name:
        # resolve the parent module, then assign under the final path segment
        parent_name, child_name = name.rsplit(".", 1)
        parent = model.get_submodule(parent_name)
    else:
        # top-level attribute of the root module itself
        parent, child_name = model, name
    setattr(parent, child_name, new_module)
117
+
118
+
119
def is_compressed_tensors_config(compression_config: Any) -> bool:
    """
    Returns True if CompressedTensorsConfig is available from transformers and
    compression_config is an instance of CompressedTensorsConfig

    See: https://github.com/huggingface/transformers/pull/31704
    """
    try:
        from transformers.utils.quantization_config import CompressedTensorsConfig
    except ImportError:
        # transformers absent or too old: the config type cannot match
        return False
    return isinstance(compression_config, CompressedTensorsConfig)
132
+
133
+
134
def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any:
    """
    Chain multiple getattr calls, separated by `.`

    :param obj: base object whose attributes are being retrieved
    :param chain_str: attribute names separated by `.`
    :param default: optional default (positional or keyword) returned when any
        attribute in the chain is missing; without it an AttributeError is
        raised
    """
    # figure out whether a default was supplied, positionally or by keyword
    if len(args) >= 1:
        has_default, default = True, args[0]
    elif "default" in kwargs:
        has_default, default = True, kwargs["default"]
    else:
        has_default, default = False, None

    current = obj
    for attr_name in chain_str.split("."):
        if not hasattr(current, attr_name):
            if not has_default:
                raise AttributeError(f"{current} object has no attribute {attr_name}")
            return default
        current = getattr(current, attr_name)

    return current
163
+
164
+
165
def deprecated(future_name: Optional[str] = None, message: Optional[str] = None):
    """
    Decorator to mark functions as deprecated

    :param future_name: name of the function to use instead; mentioned in the
        default warning message
    :param message: deprecation message, replaces default deprecation message
    """

    def decorator(func: Callable[[Any], Any]):
        # BUG FIX: the previous implementation mutated the enclosing `message`
        # via `nonlocal`, so reusing one `deprecated(...)` decorator instance
        # on a second function re-emitted the first function's message.
        # Build the text per decorated function instead.
        warning_text = message
        if warning_text is None:
            warning_text = (
                f"{func.__name__} is deprecated and will be removed in a future release"
            )
            if future_name is not None:
                warning_text += f". Please use {future_name} instead."

        @wraps(func)
        def wrapped(*args, **kwargs):
            warnings.warn(warning_text, DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)

        return wrapped

    return decorator
191
+
192
+
193
class Aliasable:
    """
    A mixin for enums to allow aliasing of enum members

    Subclasses provide :meth:`get_aliases`, a mapping from alias value to
    canonical value; equality and hashing then treat aliased members as equal.

    Example:
        >>> class MyClass(Aliasable, int, Enum):
        >>>     ...
    """

    @staticmethod
    def get_aliases() -> Dict[str, str]:
        # must be overridden by the subclass
        raise NotImplementedError()

    def __eq__(self, other):
        aliases = self.get_aliases()
        if isinstance(other, self.__class__):
            # members are equal when values match directly or map to the
            # same canonical value through the alias table
            return self.value == other.value or (
                aliases.get(self.value, self.value)
                == aliases.get(other.value, other.value)
            )
        # comparing against a raw value: canonicalize both sides
        self_value = aliases.get(self.value, self.value)
        other_value = aliases.get(other, other)
        return self_value == other_value

    def __hash__(self):
        # BUG FIX: original read ``self.aliases``, which is never defined on
        # this class (only ``get_aliases()`` exists), so hashing any member
        # raised AttributeError. Hash the canonical value so that aliased
        # members (which compare equal) also hash equal.
        canonical_value = self.get_aliases().get(self.value, self.value)
        return hash(canonical_value)
222
+
223
+
224
def shard_tensor(
    tensor: torch.Tensor, shard_sizes: List[int], dim: int = 0
) -> List[torch.Tensor]:
    """
    Shards a tensor into a list of tensors along a given dimension.

    raises: ValueError: If the sum of shard_sizes does not match the
        size of the tensor along the given dimension.

    :param tensor: The input tensor to shard.
    :param shard_sizes : List of sizes for each shard along the specified dimension.
    :param dim : The dimension along which to shard the tensor.
    :returns: A list of tensors sharded along the specified dimension.
    """
    if sum(shard_sizes) != tensor.size(dim):
        raise ValueError(
            "Sum of shard_sizes must equal the size of the tensor "
            "along the specified dimension."
        )

    # compute the starting offset of each shard, then slice with narrow
    # (narrow returns views; no data is copied)
    offsets = [0]
    for size in shard_sizes:
        offsets.append(offsets[-1] + size)

    return [
        tensor.narrow(dim, start, size)
        for start, size in zip(offsets, shard_sizes)
    ]
254
+
255
+
256
def combine_shards(shards, dim=0):
    """
    Combine decompressed shards along a given dimension using `narrow`.

    :param shards: List of decompressed shard tensors.
    :param dim: Dimension to combine along (default: 0).
    :return: Combined decompressed tensor.
    """
    if not shards:
        raise ValueError("The list of shards is empty.")

    # all shards must agree on dtype before they can share a buffer
    if len({shard.dtype for shard in shards}) > 1:
        raise ValueError("All shards must have the same dtype.")

    # output shape matches the first shard except along the combine dim
    out_shape = list(shards[0].shape)
    out_shape[dim] = sum(shard.shape[dim] for shard in shards)

    result = torch.zeros(out_shape, dtype=shards[0].dtype, device=shards[0].device)

    # copy each shard into its slice of the output
    offset = 0
    for shard in shards:
        width = shard.shape[dim]
        result.narrow(dim, offset, width).copy_(shard)
        offset += width

    return result
287
+
288
+
289
def pack_bitmasks(bytemasks: torch.Tensor) -> torch.Tensor:
    """
    Converts a bytemask tensor to a bitmask tensor to reduce memory. Shape RxC will be
    compressed to R x ceil(C/8)

    :param bytemasks: mask tensor where each byte corresponds to a weight
    :return: mask tensor where each bit corresponds to a weight
    """
    # numpy.packbits with little bitorder puts element i into bit i of each byte
    packed = numpy.packbits(bytemasks.numpy(), axis=-1, bitorder="little")
    return torch.from_numpy(packed)
301
+
302
+
303
def unpack_bitmasks(
    packed_bitmasks: torch.Tensor, original_shape: List[int]
) -> torch.Tensor:
    """
    Converts a bitmask tensor back to a bytemask tensor for use during decompression

    :param packed_bitmasks: mask tensor where each bit corresponds to a weight
    :param original_shape: dense shape to decompress to
    :return: boolean mask of weights in the original dense shape
    """
    # expand bits back to bytes; count trims the padding added by packbits
    bits = numpy.unpackbits(
        packed_bitmasks.cpu().numpy(),
        axis=-1,
        count=original_shape[-1],
        bitorder="little",
    )

    # restore the dense shape as a boolean tensor
    return torch.from_numpy(bits.reshape(original_shape).astype(bool))
.venv/lib/python3.11/site-packages/compressed_tensors/utils/offload.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Utilities associated with offloading functionality provided by `accelerate`.
16
+
17
+ | ----------------------------------------------------------------------------------------------------- | # noqa: E501
18
+ | Operation | Without offloading support | With offloading support | # noqa: E501
19
+ | --------- | -------------------------------------- | ------------------------------------------------ | # noqa: E501
20
+ | Add | module.register_parameter(name, param) | register_offload_parameter(module, name, param) | # noqa: E501
21
+ | Check | N/A | has_offloaded_params(module) | # noqa: E501
22
+ | Onload | N/A | with align_module_device(module) | # noqa: E501
23
+ | Update | module.name.data.copy_(new_data) | update_offload_parameter(module, name, new_data) | # noqa: E501
24
+ | Delete | del module.name | delete_offload_parameter(module, name) | # noqa: E501
25
+ | ----------------------------------------------------------------------------------------------------- | # noqa: E501
26
+ """
27
+
28
+ import contextlib
29
+ from functools import wraps
30
+ from typing import Any, Callable, Dict, Literal, Optional, Union
31
+
32
+ import torch
33
+
34
+
35
+ try:
36
+ from accelerate.hooks import (
37
+ AlignDevicesHook,
38
+ add_hook_to_module,
39
+ remove_hook_from_module,
40
+ )
41
+ from accelerate.utils import (
42
+ OffloadedWeightsLoader,
43
+ PrefixedDataset,
44
+ set_module_tensor_to_device,
45
+ )
46
+
47
+ _has_accelerate = True
48
+ except ImportError:
49
+ _has_accelerate = False
50
+ AlignDevicesHook = None
51
+ add_hook_to_module = None
52
+ remove_hook_from_module = None
53
+ OffloadedWeightsLoader = None
54
+ PrefixedDataset = None
55
+ set_module_tensor_to_device = None
56
+
57
+
58
+ __all__ = [
59
+ "is_module_offloaded",
60
+ "get_execution_device",
61
+ "get_offloaded_device",
62
+ "update_prefix_dict",
63
+ "update_parameter_data",
64
+ "register_offload_parameter",
65
+ "update_offload_parameter",
66
+ "delete_offload_parameter",
67
+ "has_offloaded_params",
68
+ "disable_hf_hook",
69
+ "align_module_device",
70
+ ]
71
+
72
+
73
def check_accelerate(fallback: Any):
    """
    Decorator factory which guards a function behind the `accelerate` import.

    :param fallback: value returned by the decorated function when `accelerate`
        is not installed
    """

    def decorator(func: Callable[[Any], Any]):
        # accelerate is available: leave the function untouched
        if _has_accelerate:
            return func

        # otherwise replace it with a stub that returns the fallback value
        @wraps(func)
        def fallback_fn(*args, **kwargs):
            return fallback

        return fallback_fn

    return decorator
86
+
87
+
88
+ """ Candidates for Depreciation """
89
+
90
+
91
+ @check_accelerate(fallback=False)
92
+ def is_module_offloaded(module: torch.nn.Module) -> bool:
93
+ return has_offloaded_params(module)
94
+
95
+
96
+ def get_execution_device(module: torch.nn.Module) -> torch.device:
97
+ """
98
+ :param module: module to check
99
+ :return: device module is loaded onto during forward pass
100
+ """
101
+ if has_offloaded_params(module):
102
+ return module._hf_hook.execution_device
103
+ device = next(module.parameters()).device
104
+
105
+ # offload only gets set for leaf modules, fallback to checking for device type
106
+ if device.type == "meta":
107
+ return module._hf_hook.execution_device
108
+
109
+ return device
110
+
111
+
112
+ def get_offloaded_device(module: torch.nn.Module) -> torch.device:
113
+ """
114
+ :param module: module to check
115
+ :return: device module is offloaded to onto after forward pass
116
+ """
117
+ if has_offloaded_params(module):
118
+ first_key = list(module._hf_hook.weights_map.keys())[0]
119
+ prefix_dataset = module._hf_hook.weights_map.dataset
120
+ return prefix_dataset[first_key].device
121
+ return next(module.parameters()).device
122
+
123
+
124
+ @check_accelerate(fallback=None)
125
+ def update_prefix_dict(module: torch.nn.Module, key: str, data: torch.Tensor):
126
+ """
127
+ Updates the offloaded state dict for a given module. Parameter named key is replaced
128
+ by data. This is neccesary because parameter updates for offloaded modules do not
129
+ persist automatically between loads. This function only affects the offloaded
130
+ state dict and not the current state of the loaded module.
131
+
132
+ :param module: module containing the parameter to update
133
+ :param key: name of parameter to update
134
+ :param data: tensor to update parameter with in the offloaded state dict
135
+ """
136
+ if not has_offloaded_params(module):
137
+ raise ValueError("Prefix dict is only applicable to offloaded modules")
138
+
139
+ weights_map = module._hf_hook.weights_map
140
+ offload_to_weights_map(weights_map, key, data)
141
+
142
+
143
+ def update_parameter_data(
144
+ module: torch.nn.Module, new_param_data: torch.Tensor, param_name: str
145
+ ):
146
+ """
147
+ Update the data of an existing parameter and its offload dict. Supports both
148
+ parameters of offloaded modules and non-offloaded modules
149
+
150
+ :param module: module containing the parameter to update
151
+ :param new_param_data: tensor to update parameter with
152
+ :param param_name: name of module parameter to update
153
+ """
154
+ update_offload_parameter(module, param_name, new_param_data)
155
+
156
+
157
+ """ Candidates for Upstreaming """
158
+
159
+
160
+ def register_offload_parameter(
161
+ module: torch.nn.Module,
162
+ name: str,
163
+ parameter: torch.nn.Parameter,
164
+ offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
165
+ ):
166
+ """
167
+ Register a parameter to the given module which may be offloaded
168
+
169
+ :param module: maybe offloaded module
170
+ :param name: name of newly registered parameter
171
+ :param parameter: parameter being registered
172
+ :param offload_device: device on which weight will be offloaded to. If None is
173
+ provided, then infer device from parameters on module
174
+ """
175
+ has_onload = any(p.device != torch.device("meta") for p in module.parameters())
176
+ module.register_parameter(name, parameter)
177
+
178
+ if has_offloaded_params(module):
179
+ weights_map = module._hf_hook.weights_map
180
+ offload_to_weights_map(weights_map, name, parameter.data, offload_device)
181
+ if not has_onload:
182
+ set_module_tensor_to_device(module, name, "meta")
183
+
184
+
185
def update_offload_parameter(
    module: torch.nn.Module,
    name: str,
    data: Optional[torch.Tensor],
    offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
):
    """
    Update the data of an existing parameter and its offload dict. Supports both
    parameters of offloaded modules and non-offloaded modules

    :param module: module containing the parameter to update
    :param name: name of module parameter to update
    :param data: tensor to update parameter with; if None, the update is a no-op
    :param offload_device: device on which weight will be offloaded to. If None is
        provided, then infer device from parameters on module
    """
    param = getattr(module, name)

    # BUG FIX: `data` is annotated Optional but the original unconditionally
    # called `data.to(...)`, crashing on None; treat None as "nothing to update"
    if data is None:
        return

    data = data.to(param.dtype)

    # copy data into onloaded parameter if applicable
    # (compare device *type*; the original compared a torch.device directly
    # against the string "meta")
    if param.device.type != "meta":
        param.data.copy_(data)

    # update offload dict so the new data persists across on/offload cycles
    if has_offloaded_params(module):
        weights_map = module._hf_hook.weights_map
        offload_to_weights_map(weights_map, name, data, offload_device)
212
+
213
+
214
+ def delete_offload_parameter(module: torch.nn.Module, name: str):
215
+ """
216
+ Delete a parameter from a module which may be offloaded
217
+
218
+ :param module: maybe offloaded module
219
+ :param name: name of parameter being deleted
220
+ """
221
+ delattr(module, name)
222
+
223
+ if has_offloaded_params(module):
224
+ weights_map = module._hf_hook.weights_map
225
+ delete_from_weights_map(weights_map, name)
226
+
227
+
228
+ @check_accelerate(fallback=contextlib.nullcontext())
229
+ @contextlib.contextmanager
230
+ def disable_hf_hook(module: torch.nn.Module):
231
+ hooks = {}
232
+
233
+ def collect_hooks(module):
234
+ nonlocal hooks
235
+ if hasattr(module, "_hf_hook"):
236
+ hooks[module] = module._hf_hook
237
+ remove_hook_from_module(module)
238
+
239
+ module.apply(collect_hooks)
240
+
241
+ yield
242
+
243
+ for submodule, hook in hooks.items():
244
+ add_hook_to_module(submodule, hook)
245
+
246
+
247
+ @check_accelerate(fallback=None)
248
+ def offload_to_weights_map(
249
+ weights_map: Union[PrefixedDataset, Dict, OffloadedWeightsLoader],
250
+ key: str,
251
+ value: torch.Tensor,
252
+ offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
253
+ ):
254
+ """
255
+ Helper function which implements offloaded item assignment for PrefixedDataset,
256
+ OffloadedWeightsLoader, and Dict types.
257
+
258
+ :param weights_map: weight map to be updated with offload information
259
+ :param key: key used to identify weight location
260
+ :param value: weight being offloaded
261
+ :param offload_device: device on which weight will be offloaded to. If None is
262
+ provided, then infer device from parameters in weights_map
263
+ """
264
+ if isinstance(weights_map, PrefixedDataset):
265
+ if offload_device == "disk":
266
+ raise ValueError(f"Cannot offload to disk with type {type(weights_map)}")
267
+
268
+ dataset = weights_map.dataset
269
+ key = f"{weights_map.prefix}{key}"
270
+ offload_to_weights_map(dataset, key, value, offload_device)
271
+
272
+ elif isinstance(weights_map, OffloadedWeightsLoader):
273
+ if key not in weights_map.all_keys:
274
+ weights_map.all_keys.append(key)
275
+
276
+ if len(weights_map.index) <= 0 and offload_device != "disk":
277
+ offload_to_weights_map(weights_map.state_dict, key, value, offload_device)
278
+
279
+ else:
280
+ raise NotImplementedError(
281
+ "Updating weights_map with disk offloading is not implemented yet"
282
+ )
283
+
284
+ elif isinstance(weights_map, dict):
285
+ if offload_device == "disk":
286
+ raise ValueError(f"Cannot offload to disk with type {type(weights_map)}")
287
+
288
+ # infer offload device
289
+ if offload_device is None:
290
+ if key in weights_map:
291
+ offload_device = weights_map[key].device
292
+ else:
293
+ tens = next(iter(weights_map.values()), None)
294
+ if tens is None:
295
+ raise ValueError(
296
+ "Cannot infer offload device from empty weights_map"
297
+ )
298
+ offload_device = tens.device
299
+
300
+ weights_map[key] = value.to(device=offload_device)
301
+
302
+ else:
303
+ raise NotImplementedError(
304
+ "Updating offload data not implemented for weights_map of type "
305
+ f"{type(weights_map)}"
306
+ )
307
+
308
+
309
+ @check_accelerate(fallback=None)
310
+ def delete_from_weights_map(
311
+ weights_map: Union[PrefixedDataset, Dict, OffloadedWeightsLoader],
312
+ key: str,
313
+ ):
314
+ if isinstance(weights_map, PrefixedDataset):
315
+ dataset = weights_map.dataset
316
+ key = f"{weights_map.prefix}{key}"
317
+ delete_from_weights_map(dataset, key)
318
+
319
+ elif isinstance(weights_map, OffloadedWeightsLoader):
320
+ if len(weights_map.index) <= 0:
321
+ delete_from_weights_map(weights_map.state_dict, key)
322
+
323
+ else:
324
+ raise NotImplementedError(
325
+ "Delete from weights_map with disk offloading is not implemented yet"
326
+ )
327
+
328
+ elif isinstance(weights_map, dict):
329
+ del weights_map[key]
330
+
331
+ else:
332
+ raise NotImplementedError(
333
+ "Updating offload data not implemented for weights_map of type "
334
+ f"{type(weights_map)}"
335
+ )
336
+
337
+
338
+ """ Upstreamed Functions """
339
+
340
+
341
+ # introduced in accelerate v1.1.0
342
+ @check_accelerate(fallback=False)
343
+ def has_offloaded_params(module: torch.nn.Module) -> bool:
344
+ """
345
+ Checks if a module has offloaded parameters by checking if the given module has a
346
+ AlignDevicesHook attached with offloading enabled
347
+
348
+ Args:
349
+ module (`torch.nn.Module`): The module to check for an offload hook.
350
+
351
+ Returns:
352
+ bool: `True` if the module has an offload hook and offloading is enabled,
353
+ `False` otherwise.
354
+ """
355
+ return (
356
+ hasattr(module, "_hf_hook")
357
+ and isinstance(module._hf_hook, AlignDevicesHook)
358
+ and module._hf_hook.offload
359
+ )
360
+
361
+
362
+ # introduced in accelerate v1.1.0
363
+ @check_accelerate(fallback=contextlib.nullcontext())
364
+ @contextlib.contextmanager
365
+ def align_module_device(
366
+ module: torch.nn.Module, execution_device: Optional[torch.device] = None
367
+ ):
368
+ """
369
+ Context manager that moves a module's parameters to the specified execution device.
370
+
371
+ Args:
372
+ module (`torch.nn.Module`):
373
+ Module with parameters to align.
374
+ execution_device (`torch.device`, *optional*):
375
+ If provided, overrides the module's execution device within the context.
376
+ Otherwise, use hook execution device or pass
377
+ """
378
+ if has_offloaded_params(module):
379
+ if execution_device is not None:
380
+ original_device = module._hf_hook.execution_device
381
+ module._hf_hook.execution_device = execution_device
382
+
383
+ try:
384
+ module._hf_hook.pre_forward(module)
385
+ yield
386
+ finally:
387
+ module._hf_hook.post_forward(module, None)
388
+ if execution_device is not None:
389
+ module._hf_hook.execution_device = original_device
390
+
391
+ elif execution_device is not None:
392
+ devices = {
393
+ name: param.device for name, param in module.named_parameters(recurse=False)
394
+ }
395
+ try:
396
+ for name in devices:
397
+ set_module_tensor_to_device(module, name, execution_device)
398
+ yield
399
+ finally:
400
+ for name, device in devices.items():
401
+ set_module_tensor_to_device(module, name, device)
402
+
403
+ else:
404
+ yield
.venv/lib/python3.11/site-packages/compressed_tensors/utils/permutations_24.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import numpy
17
+ import torch
18
+
19
+
20
+ __all__ = ["get_permutations_24"]
21
+
22
+
23
+ # Precompute permutations for Marlin24 weight and scale shuffling
24
+ # Originally implemented in nm-vllm/vllm/model_executor/layers/quantization/utils/marlin_24_perms.py # noqa: E501
25
+ #
26
+ # Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight
27
+ # data so that it is compatible with the tensor-core format that is described here:
28
+ # https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
29
+ #
30
+ # As a result of this reordering, the vector loads inside the kernel will get the data
31
+ # as it is needed for tensor-core (without the need to use ldmatrix instructions)
32
def get_permutations_24(num_bits):
    """
    Build the Marlin24 reorder permutations for the given weight bit-width.

    :param num_bits: quantization bit-width, must be 4 or 8
    :return: tuple of (weight permutation tensor, grouped scale permutation,
        per-channel scale permutation)
    """
    # per-thread tile offsets, replicated 4 times per thread
    weight_perm = []
    for tid in range(32):
        base = []
        col = tid // 4
        col_o = col // 2
        for block in (0, 1):
            for row in (
                2 * (tid % 4),
                2 * (tid % 4) + 1,
                2 * (tid % 4 + 4),
                2 * (tid % 4 + 4) + 1,
            ):
                base.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block)
        for rep in range(4):
            weight_perm.extend(p + rep for p in base)
    perm = numpy.array(weight_perm)

    # interleave pattern depends on how many values pack into 32 bits
    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))

    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)

    # permutation for group-wise scales
    scale_perm = []
    for i in range(8):
        scale_perm.extend(i * 8 + j for j in (0, 4, 1, 5, 2, 6, 3, 7))

    # permutation for per-channel (single-group) scales is the identity
    scale_perm_single = []
    for i in range(8):
        scale_perm_single.extend(8 * i + j for j in range(8))

    return perm, scale_perm, scale_perm_single
.venv/lib/python3.11/site-packages/compressed_tensors/utils/permute.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Set, Tuple
16
+
17
+ import torch
18
+
19
+
20
+ __all__ = ["safe_permute"]
21
+
22
+
23
+ # these datatypes are missing implementations required for standard permutation
24
+ _EXPERIMENTAL_DTYPES: Set[Tuple[torch.dtype, torch.device]] = set()
25
+
26
+
27
+ def safe_permute(value: torch.Tensor, perm: torch.Tensor, dim: int = 0) -> torch.Tensor:
28
+ """
29
+ Perform out-of-place permutation without using torch.Tensor.index_put_,
30
+ whose implementation is missing for datatypes such as `torch.float8_e4m3fn`
31
+
32
+ :param value: tensor to permute
33
+ :param perm: permutation map
34
+ :param dim: dimension along which to apply permutation
35
+ :return: permuted value
36
+ """
37
+ dtype_tuple = (value.dtype, value.device)
38
+
39
+ if dtype_tuple in _EXPERIMENTAL_DTYPES:
40
+ return _fallback_permute(value, perm, dim)
41
+
42
+ try:
43
+ return value[tuple([slice(None)] * dim + [perm])]
44
+ except RuntimeError:
45
+ # Mark dtype as experimental if advanced indexing fails
46
+ _EXPERIMENTAL_DTYPES.add(dtype_tuple)
47
+ return _fallback_permute(value, perm, dim)
48
+
49
+
50
+ def _fallback_permute(
51
+ value: torch.Tensor, perm: torch.Tensor, dim: int
52
+ ) -> torch.Tensor:
53
+ """
54
+ Fallback permutation method for experimental dtypes.
55
+
56
+ :param value: tensor to permute
57
+ :param perm: permutation map
58
+ :param dim: dimension along which to apply permutation
59
+ :return: permuted value
60
+ """
61
+ value_ret = value.clone() # cannot use zeros_like b/c of missing impl.
62
+ orig_slices = [slice(None)] * (dim + 1)
63
+ perm_slices = [slice(None)] * (dim + 1)
64
+
65
+ for index, perm_index in enumerate(perm):
66
+ orig_slices[dim] = index
67
+ perm_slices[dim] = perm_index
68
+ value_ret[tuple(orig_slices)] = value[tuple(perm_slices)]
69
+
70
+ return value_ret
.venv/lib/python3.11/site-packages/compressed_tensors/utils/safetensors_load.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import os
17
+ import re
18
+ import struct
19
+ from typing import Dict, List, Optional, Tuple, Union
20
+
21
+ from safetensors import safe_open
22
+ from torch import Tensor
23
+ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, cached_file
24
+
25
+
26
+ __all__ = [
27
+ "get_safetensors_folder",
28
+ "get_safetensors_header",
29
+ "match_param_name",
30
+ "merge_names",
31
+ "get_weight_mappings",
32
+ "get_nested_weight_mappings",
33
+ "get_nested_mappings_from_state_dict",
34
+ "get_quantization_state_dict",
35
+ "is_quantization_param",
36
+ ]
37
+
38
+ WeightMappingType = Dict[str, str]
39
+ NestedWeightMappingType = Dict[str, WeightMappingType]
40
+
41
+
42
+ def get_safetensors_folder(
43
+ pretrained_model_name_or_path: str, cache_dir: Optional[str] = None
44
+ ) -> str:
45
+ """
46
+ Given a Hugging Face stub or a local path, return the folder containing the
47
+ safetensors weight files
48
+
49
+ :param pretrained_model_name_or_path: local path to model or HF stub
50
+ :param cache_dir: optional cache dir to search through, if none is specified the
51
+ model will be searched for in the default TRANSFORMERS_CACHE
52
+ :return: local folder containing model data
53
+ """
54
+ if os.path.exists(pretrained_model_name_or_path):
55
+ # argument is a path to a local folder
56
+ return os.path.abspath(pretrained_model_name_or_path)
57
+
58
+ safetensors_path = cached_file(
59
+ pretrained_model_name_or_path,
60
+ SAFE_WEIGHTS_NAME,
61
+ cache_dir=cache_dir,
62
+ _raise_exceptions_for_missing_entries=False,
63
+ )
64
+ index_path = cached_file(
65
+ pretrained_model_name_or_path,
66
+ SAFE_WEIGHTS_INDEX_NAME,
67
+ cache_dir=cache_dir,
68
+ _raise_exceptions_for_missing_entries=False,
69
+ )
70
+ if safetensors_path is not None:
71
+ # found a single cached safetensors file
72
+ return os.path.split(safetensors_path)[0]
73
+ if index_path is not None:
74
+ # found a cached safetensors weight index file
75
+ return os.path.split(index_path)[0]
76
+
77
+ # model weights could not be found locally or cached from HF Hub
78
+ raise ValueError(
79
+ "Could not locate safetensors weight or index file from "
80
+ f"{pretrained_model_name_or_path}."
81
+ )
82
+
83
+
84
+ def get_safetensors_header(safetensors_path: str) -> Dict[str, str]:
85
+ """
86
+ Extracts the metadata from a safetensors file as JSON
87
+
88
+ :param safetensors_path: path to a safetensors file
89
+ :return: dictionary of metadata extracted from the safetensors file
90
+ """
91
+ with open(safetensors_path, "rb") as f:
92
+ length_of_header = struct.unpack("<Q", f.read(8))[0]
93
+ header_data = f.read(length_of_header)
94
+ header = json.loads(header_data)
95
+
96
+ return header
97
+
98
+
99
def match_param_name(full_name: str, param_name: str) -> Optional[str]:
    """
    Helper function extracting the uncompressed parameterized layer name from a
    compressed name. Assumes the compressed name was merged using merge_names.

    :param full_name: full name of parameter in compressed model
    :param param_name: compression parameter name
    :return: uncompressed name of the uncompressed parameterized layer, or None
        if full_name does not end with ``.param_name``
    """
    match = re.match(r"^(.*)\." + param_name + r"$", full_name)
    return match.group(1) if match else None
113
+
114
+
115
def merge_names(parent_name: str, child_name: str) -> str:
    """
    Helper function for merging an uncompressed parameterized layer name with a
    compression parameter. Names merged with this function can then be parsed by
    match_param_name.

    :param parent_name: uncompressed parameterized layer name
    :param child_name: compression parameter name
    :return: merged compressed name
    """
    return f"{parent_name}.{child_name}"
126
+
127
+
128
def get_weight_mappings(path_to_model_or_tensors: str) -> Dict[str, str]:
    """
    Takes a path to a state dict saved in safetensors format and returns a mapping
    from parameterized layer name to file location.

    {
        layer.weight.bitmask: file_location,
        layer.weight.row_offsets: file_location,
        layer.weight.shape: file_location,
        layer.weight.compressed: file_location
    }

    This generalizes to cases where the model is split into multiple safetensors files

    :param path_to_model_or_tensors: path to directory that contains
        safetensors (must contain either a single file or multiple files with an index),
        or a path to a single safetensors file
    :return: mapping of parameterized layer name to file location
    """
    if os.path.isfile(path_to_model_or_tensors):
        # we have a single safetensors file to read; every parameter maps to it
        header = get_safetensors_header(path_to_model_or_tensors)
        for key in header.keys():
            header[key] = path_to_model_or_tensors
        header.pop("__metadata__", None)
        # the values are already complete paths; running them through the
        # os.path.join loop below would join the file path with itself and
        # corrupt relative paths, so return early
        return header

    # we have a directory with one safetensors file or a sharded index
    safetensors_path = os.path.join(path_to_model_or_tensors, SAFE_WEIGHTS_NAME)
    index_path = os.path.join(path_to_model_or_tensors, SAFE_WEIGHTS_INDEX_NAME)
    if os.path.exists(safetensors_path):
        # we have a single safetensors file to read
        header = get_safetensors_header(safetensors_path)
        for key in header.keys():
            header[key] = SAFE_WEIGHTS_NAME
        header.pop("__metadata__", None)
    elif os.path.exists(index_path):
        # we have multiple safetensors files, read locations from the index
        with open(index_path, "r", encoding="utf-8") as f:
            index = json.load(f)
        header = index["weight_map"]
    else:
        raise ValueError(
            "Could not find a safetensors weight "
            f"or index file at {path_to_model_or_tensors}"
        )

    # convert file names (relative to the model directory) to full paths
    for key, value in header.items():
        header[key] = os.path.join(path_to_model_or_tensors, value)

    return header
180
+
181
+
182
def get_nested_weight_mappings(
    model_path: str, params_to_nest: List[str], return_unmatched_params: bool = False
) -> Union[NestedWeightMappingType, Tuple[NestedWeightMappingType, WeightMappingType]]:
    """
    Build a nested mapping from uncompressed parameterized layer names to the
    file locations of each layer's compression parameters, given a path to a
    safetensors state dict (single file or sharded with an index).

    Example of the nested mapping:
        layer: {
            bitmask: file_location,
            row_offsets: file_location,
            shape: file_location,
            compressed: file_location
        }

    Parameters that do not match any entry in params_to_nest are returned in a
    second dictionary when return_unmatched_params is True. This is needed when
    compressors are stacked (e.g., quantization compression followed by sparse
    compression).

    Example of the unmatched params mapping:
        {
            layer.weight_scale: file_location,
            layer.input_scale: file_location
        }

    :param model_path: Path to the safetensors state dict, must contain either a
        single safetensors file or multiple files with an index.
    :param params_to_nest: List of parameter names to nest.
    :param return_unmatched_params: If True, also return a dictionary of the
        parameters that were not matched to the params_to_nest.
    :return: the nested mapping, or a (nested mapping, unmatched params) tuple
        when return_unmatched_params is True.
    """
    flat_mappings = get_weight_mappings(model_path)
    nested = {}
    leftovers = {}

    for full_name, location in flat_mappings.items():
        found_match = False
        for param_name in params_to_nest:
            layer_name = match_param_name(full_name, param_name)
            if layer_name:
                nested.setdefault(layer_name, {})[param_name] = location
                found_match = True
        if return_unmatched_params and not found_match:
            leftovers[full_name] = location

    if return_unmatched_params:
        return nested, leftovers
    return nested
247
+
248
+
249
def get_nested_mappings_from_state_dict(
    state_dict, params_to_nest
) -> NestedWeightMappingType:
    """
    Build a nested mapping from uncompressed parameterized layer names to the
    values of each layer's compression parameters, given a state dict.

    Example of the nested mapping:
        layer: {
            weight_scale: ...,
            weight: ...,
            zero_point: ...,
        }

    :param state_dict: state dict of the model
    :param params_to_nest: List of parameter names to nest.
    :return: Nested mapping of parameterized layer names to the value of
        each layer's compression parameters.
    """
    nested = {}
    for full_name, value in state_dict.items():
        for param_name in params_to_nest:
            layer_name = match_param_name(full_name, param_name)
            if layer_name:
                nested.setdefault(layer_name, {})[param_name] = value
    return nested
278
+
279
+
280
def get_quantization_state_dict(model_path: str) -> Dict[str, Tensor]:
    """
    Load only the quantization parameters (scales, zero points, g_idx) from a
    safetensors state dict on disk.

    :param model_path: path to the safetensors state dict (single file or
        sharded with an index)
    :return: mapping of quantization parameter names to their tensors
    """
    state_dict = {}
    for weight_name, safe_path in get_weight_mappings(model_path).items():
        # skip everything that is not a quantization parameter
        if is_quantization_param(weight_name):
            with safe_open(safe_path, framework="pt", device="cpu") as f:
                state_dict[weight_name] = f.get_tensor(weight_name)
    return state_dict
290
+
291
+
292
def is_quantization_param(name: str) -> bool:
    """
    Check whether a parameter name refers to a quantization parameter.

    :param name: parameter name to check
    :return: True if parameter name is a quantization parameter, else False
    """
    # str.endswith accepts a tuple of suffixes, checked in a single call
    return name.endswith(("_scale", "zero_point", "g_idx"))
.venv/lib/python3.11/site-packages/compressed_tensors/utils/semi_structured_conversions.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es).
3
+ # Pulled from nm-vllm/vllm/model_executor/layers/quantization/utils/format_24.py
4
+ #
5
+ # flake8: noqa
6
+ # isort: skip_file
7
+
8
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing,
17
+ # software distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ import torch
23
+
24
+
25
# Public API: CUTLASS 2:4 semi-structured sparse conversions and mask creation
__all__ = [
    "sparse_semi_structured_from_dense_cutlass",
    "sparse_semi_structured_to_dense_cutlass",
    "mask_creator",
]
30
+
31
+
32
+ # This is PyTorch implementation of main part of reorder_meta()
33
+ # function, from tools/util/include/cutlass/util/host_reorder.h file
34
+ # of CUTLASS source tree. Furthermore, CUTLASS template for sparse
35
+ # GEMM decides upon layout of this matrix, and at the moment for the
36
+ # sparse GEMM executed on tensor cores, this is layout described by
37
+ # ColumnMajorInterleaved<2> data structure, in
38
+ # include/cutlass/layout/matrix.h of CUTLASS source tree. The
39
+ # reordering of meta matrix into meta_reordered matrix calculated
40
+ # according to these segments of CUTLASS code is re-implemented here.
41
+ # Note that this calculation produces offsets for scattering metadata
42
+ # matrix elements into reordered metadata matrix elements (or,
43
+ # equivalently, for gathering reordered metadata matrix element back
44
+ # into metadata matrix elements).
45
+ def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device):
46
+ dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
47
+ dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)
48
+
49
+ # Reorder the rows, then swizzle the 2x2 blocks.
50
+ group_x = 64
51
+ group_y = 32 if meta_dtype.itemsize == 2 else 16
52
+
53
+ dst_rows = (
54
+ dst_rows // group_x * group_x
55
+ + (dst_rows % 2) * 2
56
+ + (dst_rows % 8) // 4
57
+ + ((dst_rows % group_y) % 4) // 2 * 32
58
+ + ((dst_rows % group_x) // 8) * 4
59
+ )
60
+
61
+ topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
62
+ bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
63
+ dst_rows += topright - bottomleft
64
+ dst_cols -= topright - bottomleft
65
+
66
+ # Assumed that meta tensor is to be stored in CUTLASS
67
+ # InterleavedColumnMajor layout, and reverse engineered
68
+ # corresponding code to store values into this tensor.
69
+ interleave = 2
70
+ cols_maj = dst_cols // interleave
71
+ cols_min = dst_cols % interleave
72
+ return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1)
73
+
74
+
75
+ # This function converts dense matrix into sparse semi-structured
76
+ # representation, producing "compressed" matrix, in the layout used by
77
+ # CUTLASS backend, and corresponding metadata matrix.
78
def sparse_semi_structured_from_dense_cutlass(dense):
    """
    Compress a dense 2-D tensor into CUTLASS 2:4 semi-structured form.

    :param dense: 2-D tensor; supported dtypes are int8, half, bfloat16,
        float, and int32. Row/column counts must satisfy the divisibility
        checks below.
    :return: tuple (sparse, meta_reordered) where sparse has shape
        (m, k // 2) holding the kept values and meta_reordered has shape
        (m, meta_ncols) holding the reordered metadata encoding which
        positions were kept
    """
    if dense.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor"  # noqa: E501
        )

    m, k = dense.shape
    device = dense.device

    # Metadata element width is determined by the dense dtype:
    # int8 data -> int32 meta, 16-bit-or-wider data -> int16 meta.
    meta_dtype = torch.int8
    if dense.dtype == torch.int8:
        meta_dtype = torch.int32
    elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:
        meta_dtype = torch.int16
    else:
        raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
    if quadbits_per_meta_elem not in (4, 8):
        raise RuntimeError("Invalid number of elements per meta element calculated")

    if meta_dtype == torch.int32:
        if m % 16 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 16"
            )
    else:
        if m % 32 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 32"
            )
    if k % (4 * quadbits_per_meta_elem) != 0:
        raise RuntimeError(
            f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}"  # noqa: E501
        )

    # torch.float pairs are treated as two half-width lanes, so groups of 2
    # elements are used instead of groups of 4.
    if dense.dtype != torch.float:
        ksparse = 4
        dense_4 = dense.view(-1, k // ksparse, ksparse)
        m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
    else:
        ksparse = 2
        dense_2 = dense.view(-1, k // ksparse, ksparse)
        m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
    meta_ncols = k // (ksparse * quadbits_per_meta_elem)

    # Encoding quadruples of True/False values as follows:
    # [True, True, False, False] -> 0b0100
    # [True, False, True, False] -> 0b1000
    # [False, True, True, False] -> 0b1001
    # [True, False, False, True ] -> 0b1100
    # [False, True, False, True ] -> 0b1101
    # [False, False, True, True ] -> 0b1110
    # Thus, lower two bits in the encoding are index of the True value
    # at the lowest index in the quadruple, and the higher two bits in
    # the encoding are index of the other True value in the quadruple.
    # In case there are less than two True values, than False value or
    # values at some index or indices are considered True for the
    # encoding. In case there are more than two True values, then the
    # excess True value(s) at some indices are considered False for
    # the encoding. The exact encodings used for these cases are as
    # follows:
    # [False, False, False, False] -> 0b1110
    # [False, False, False, True ] -> 0b1110
    # [False, False, True, False] -> 0b1110
    # [False, True, False, False] -> 0b1001
    # [False, True, True, True ] -> 0b1101
    # [True, False, False, False] -> 0b1000
    # [True, False, True, True ] -> 0b1100
    # [True, True, False, True ] -> 0b0100
    # [True, True, True, False] -> 0b0100
    # [True, True, True, True ] -> 0b0100
    # These particular encodings are chosen, with the help of Espresso
    # logic minimizer software, for the purpose of minimization of
    # corresponding Boolean functions, that translate non-zero flags
    # into encoding bits. Note also possible choices for the first
    # and last of these encodings were limited only to (0b0100,
    # 0b1110), in order to produce valid encodings for 1:2 sparsity
    # case.

    expr0 = m0 & m1
    expr1 = ~m0 & m1
    expr2 = ~m0 & ~m1
    bit0 = expr1
    bit1 = expr2
    bit2 = expr0 | expr2 | m3
    bit3 = expr1 | ~m1
    # idxs0/idxs1 are the in-group indices of the two values being kept
    idxs0 = bit0 | (bit1.to(torch.int64) << 1)
    idxs1 = bit2 | (bit3.to(torch.int64) << 1)

    # Gather the two kept values from each group into the compressed matrix.
    if dense.dtype != torch.float:
        sparse0 = dense_4.gather(
            -1, idxs0.unsqueeze(-1)
        )  # type: ignore[possibly-undefined]
        sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
        sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
    else:
        sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(
            m, k // 2
        )  # type: ignore[possibly-undefined]

    # Pack the per-group 4-bit codes into full metadata elements.
    meta_4 = idxs0 | (idxs1 << 2)
    meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)

    if quadbits_per_meta_elem == 4:
        meta = (
            meta_n[:, :, 0]
            | (meta_n[:, :, 1] << 4)
            | (meta_n[:, :, 2] << 8)
            | (meta_n[:, :, 3] << 12)
        )
    elif quadbits_per_meta_elem == 8:
        meta = (
            meta_n[:, :, 0]
            | (meta_n[:, :, 1] << 4)
            | (meta_n[:, :, 2] << 8)
            | (meta_n[:, :, 3] << 12)
            | (meta_n[:, :, 4] << 16)
            | (meta_n[:, :, 5] << 20)
            | (meta_n[:, :, 6] << 24)
            | (meta_n[:, :, 7] << 28)
        )

    # Reorder meta tensor elements.
    meta_reordered = meta.new_empty(
        (m * meta_ncols,)
    )  # type: ignore[possibly-undefined]
    meta_offsets = _calculate_meta_reordering_scatter_offsets(
        m, meta_ncols, meta_dtype, device
    )
    meta_reordered.scatter_(0, meta_offsets, meta.view(-1))

    return (sparse, meta_reordered.view(m, meta_ncols))
210
+
211
+
212
+ # This function performs reverse of the function above - it
213
+ # reconstructs dense matrix from a pair of "compressed" matrix, given
214
+ # in the layout used by CUTLASS backend, and accompanying metadata
215
+ # matrix.
216
def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
    """
    Reconstruct a dense tensor from CUTLASS 2:4 semi-structured form.

    Inverse of sparse_semi_structured_from_dense_cutlass: positions not
    selected by the metadata are filled with zeros.

    :param sparse: 2-D compressed values tensor of shape (m, k)
    :param meta_reordered: 2-D reordered metadata tensor of shape
        (m, meta_ncols), int16 or int32, on the same device as ``sparse``
    :return: dense tensor of shape (m, 2 * k)
    """
    if sparse.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor"  # noqa: E501
        )

    m, k = sparse.shape
    device = sparse.device

    if meta_reordered.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor"  # noqa: E501
        )
    if meta_reordered.device != device:
        raise RuntimeError(
            f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device"  # noqa: E501
        )

    meta_dtype = meta_reordered.dtype
    if meta_dtype not in (torch.int16, torch.int32):
        raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4

    # torch.float data is handled as pairs of half-width lanes (group of 2).
    ksparse = 4 if sparse.dtype != torch.float else 2

    meta_nrows, meta_ncols = meta_reordered.shape
    if meta_nrows != m:
        raise RuntimeError(
            f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}"  # noqa: E501
        )
    if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
        raise RuntimeError(
            f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, "  # noqa: E501
            "expected according to the number of columns of meta matrix"
        )

    # Undo meta tensor elements reordering.
    meta_offsets = _calculate_meta_reordering_scatter_offsets(
        m, meta_ncols, meta_dtype, device
    )
    meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols)

    # Unpack sparse tensor back to original dense tensor, using
    # information provided by meta tensor. Note that torch.float
    # datatype is handled pretty much the same as
    # torch.half/torch.bfloat16, as metadata for a pair of torch.float
    # value is encoded as if underlying 8 bytes contain four
    # torch.half/torch.bfloat16 values, where either first two or last
    # two are zeros.
    meta_2 = torch.empty(
        (m, meta_ncols, 2 * quadbits_per_meta_elem),
        dtype=meta_dtype,
        device=device,
    )
    # Split each packed metadata element back into its 2-bit position codes.
    if quadbits_per_meta_elem == 4:
        meta_2[:, :, 0] = meta & 0b11
        meta_2[:, :, 1] = (meta >> 2) & 0b11
        meta_2[:, :, 2] = (meta >> 4) & 0b11
        meta_2[:, :, 3] = (meta >> 6) & 0b11
        meta_2[:, :, 4] = (meta >> 8) & 0b11
        meta_2[:, :, 5] = (meta >> 10) & 0b11
        meta_2[:, :, 6] = (meta >> 12) & 0b11
        meta_2[:, :, 7] = (meta >> 14) & 0b11
    elif quadbits_per_meta_elem == 8:
        meta_2[:, :, 0] = meta & 0b11
        meta_2[:, :, 1] = (meta >> 2) & 0b11
        meta_2[:, :, 2] = (meta >> 4) & 0b11
        meta_2[:, :, 3] = (meta >> 6) & 0b11
        meta_2[:, :, 4] = (meta >> 8) & 0b11
        meta_2[:, :, 5] = (meta >> 10) & 0b11
        meta_2[:, :, 6] = (meta >> 12) & 0b11
        meta_2[:, :, 7] = (meta >> 14) & 0b11
        meta_2[:, :, 8] = (meta >> 16) & 0b11
        meta_2[:, :, 9] = (meta >> 18) & 0b11
        meta_2[:, :, 10] = (meta >> 20) & 0b11
        meta_2[:, :, 11] = (meta >> 22) & 0b11
        meta_2[:, :, 12] = (meta >> 24) & 0b11
        meta_2[:, :, 13] = (meta >> 26) & 0b11
        meta_2[:, :, 14] = (meta >> 28) & 0b11
        meta_2[:, :, 15] = (meta >> 30) & 0b11

    # Translate per-group position codes into flat indices into the dense
    # output: each group of 4 output slots receives two scattered values.
    dense_offsets = meta_2.view(-1) + (
        torch.arange(0, 2 * m * k // ksparse, device=device) * 4
    ).view(-1, 1).repeat(1, 2).view(-1)

    dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
    if sparse.dtype != torch.float:
        # dense.scatter_(0, dense_offsets, sparse.view(-1))
        dense.scatter_(0, dense_offsets, sparse.reshape(-1))
    else:
        # scatter through a half-width view so the offsets (computed in
        # half-sized lanes) address the float storage correctly
        dense.view(torch.half).scatter_(
            0, dense_offsets, sparse.view(torch.half).view(-1)
        )

    return dense.view(m, 2 * k)
311
+
312
+
313
def mask_creator(tensor):
    """
    Create a 2:4 (N:M) sparsity mask for the given tensor.

    The tensor is partitioned into contiguous groups of M = 4 weights; in
    each group the M - N = 2 weights with the smallest absolute value are
    masked out (0) and the N = 2 largest-magnitude weights are kept (1).

    :param tensor: tensor to build the mask for; its number of elements
        must be divisible by 4
    :return: float tensor of ones and zeros with the same shape as ``tensor``
    :raises ValueError: if tensor.numel() is not divisible by 4
    """
    N = 2  # number of weights kept per group
    M = 4  # group size

    if tensor.numel() % M != 0:
        raise ValueError(
            f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups"
        )

    num_groups = tensor.numel() // M

    # Rank weights within each group by magnitude; the M - N smallest
    # indices in each group are the ones to prune
    tensor_temp = tensor.detach().abs().reshape(num_groups, M)
    index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)]

    # Start from all-ones and zero out the pruned positions
    w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device)
    mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape)

    return mask
.venv/lib/python3.11/site-packages/dotenv/__init__.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Optional
2
+
3
+ from .main import (dotenv_values, find_dotenv, get_key, load_dotenv, set_key,
4
+ unset_key)
5
+
6
+
7
def load_ipython_extension(ipython: Any) -> None:
    """Entry point invoked by IPython's ``%load_ext dotenv``; delegates to
    the implementation in ``dotenv.ipython``."""
    # Imported lazily so IPython is only required when the extension is used
    from .ipython import load_ipython_extension
    load_ipython_extension(ipython)
10
+
11
+
12
def get_cli_string(
    path: Optional[str] = None,
    action: Optional[str] = None,
    key: Optional[str] = None,
    value: Optional[str] = None,
    quote: Optional[str] = None,
):
    """Return a string suitable for running as a shell script.

    Useful for converting arguments passed to a fabric task into a string
    that can be passed to a `local` or `run` command.
    """
    parts = ['dotenv']
    if quote:
        parts.append(f'-q {quote}')
    if path:
        parts.append(f'-f {path}')
    if action:
        parts.append(action)
    if key:
        parts.append(key)
    if value:
        # Quote values containing spaces so they survive shell word splitting
        parts.append(f'"{value}"' if ' ' in value else value)

    return ' '.join(parts).strip()
40
+
41
+
42
# Public names exported by `from dotenv import *`
__all__ = ['get_cli_string',
           'load_dotenv',
           'dotenv_values',
           'get_key',
           'set_key',
           'unset_key',
           'find_dotenv',
           'load_ipython_extension']
.venv/lib/python3.11/site-packages/dotenv/__main__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """Entry point for cli, enables execution with `python -m dotenv`"""
2
+
3
+ from .cli import cli
4
+
5
# Run the dotenv CLI when executed as a module (`python -m dotenv`)
if __name__ == "__main__":
    cli()
.venv/lib/python3.11/site-packages/dotenv/__pycache__/__main__.cpython-311.pyc ADDED
Binary file (386 Bytes). View file
 
.venv/lib/python3.11/site-packages/dotenv/__pycache__/cli.cpython-311.pyc ADDED
Binary file (10.9 kB). View file
 
.venv/lib/python3.11/site-packages/dotenv/__pycache__/ipython.cpython-311.pyc ADDED
Binary file (2.3 kB). View file
 
.venv/lib/python3.11/site-packages/dotenv/__pycache__/main.cpython-311.pyc ADDED
Binary file (18.1 kB). View file