Brunobkr commited on
Commit
7392e75
·
verified ·
1 Parent(s): 2694729

Upload 4 files

Browse files
Files changed (4) hide show
  1. __init__.py +9 -0
  2. gguf_reader.py +371 -0
  3. gguf_writer.py +1276 -0
  4. quants.py +1443 -0
__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from .constants import *
2
+ from .lazy import *
3
+ from .gguf_reader import *
4
+ from .gguf_writer import *
5
+ from .tensor_mapping import *
6
+ from .vocab import *
7
+ from .utility import *
8
+ from .metadata import *
9
+ from gguf.quants import HelicoidalZetaCore # Importação necessária!
gguf_reader.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #BRUNO BECKER / OFFELLIA 2026
2
+ #brunoconta1980@gmail.com
3
+ #brunoconta1980@hotmail.com
4
+ # X @Brunoxuser
5
+
6
+ #
7
+ # GGUF file reading/modification support. For API usage information,
8
+ # please see the files scripts/ for some fairly simple examples.
9
+ #
10
+ from __future__ import annotations
11
+
12
+ import logging
13
+ import os
14
+ import sys
15
+ from collections import OrderedDict
16
+ from typing import Any, Literal, NamedTuple, TypeVar, Union
17
+
18
+ import numpy as np
19
+ import numpy.typing as npt
20
+
21
+ from .quants import quant_shape_to_byte_shape
22
+
23
+ if __name__ == "__main__":
24
+ from pathlib import Path
25
+
26
+ # Allow running file in package as a script.
27
+ sys.path.insert(0, str(Path(__file__).parent.parent))
28
+
29
+ from gguf.constants import (
30
+ GGML_QUANT_SIZES,
31
+ GGUF_DEFAULT_ALIGNMENT,
32
+ GGUF_MAGIC,
33
+ GGUF_VERSION,
34
+ GGMLQuantizationType,
35
+ GGUFValueType,
36
+ GGUFEndian,
37
+ )
38
+
39
+ logger = logging.getLogger(__name__)
40
+
41
+ READER_SUPPORTED_VERSIONS = [2, GGUF_VERSION]
42
+
43
+
44
class ReaderField(NamedTuple):
    """One parsed field of a GGUF file: where it starts, its name and its raw parts."""

    # Offset to start of this field.
    offset: int

    # Name of the field (not necessarily from file data).
    name: str

    # Data parts. Some types have multiple components, such as strings
    # that consist of a length followed by the string data.
    parts: list[npt.NDArray[Any]] = []

    # Indexes into parts that we can call the actual data. For example
    # an array of strings will be populated with indexes to the actual
    # string data.
    data: list[int] = [-1]

    # Value types of the field; for arrays the last entry is the item type.
    types: list[GGUFValueType] = []

    def contents(self, index_or_slice: int | slice = slice(None)) -> Any:
        """Return the field value converted to plain Python objects.

        Strings are decoded as UTF-8, scalars are returned via ``tolist()``.
        For array fields ``index_or_slice`` selects a single item (``int``)
        or a list of items (``slice``); it is ignored for non-array fields.
        Returns ``None`` when the field carries no type information.
        """
        if not self.types:
            return None

        def decode(raw: Any) -> str:
            return str(raw.tobytes(), encoding='utf-8')

        kind = self.types[0]

        if kind != GGUFValueType.ARRAY:
            # Scalars and strings keep their payload in the last part.
            tail = self.parts[-1]
            if kind == GGUFValueType.STRING:
                return decode(tail)
            return tail.tolist()[0]

        want_single = isinstance(index_or_slice, int)
        selected = self.data[index_or_slice]

        if self.types[-1] == GGUFValueType.STRING:
            if want_single:
                return decode(self.parts[selected])  # type: ignore
            return [decode(self.parts[idx]) for idx in selected]  # type: ignore

        # FIXME: When/if _get_field_parts() support multi-dimensional arrays, this must do so too
        if want_single:
            return self.parts[selected].tolist()[0]  # type: ignore
        return [value for idx in selected for value in self.parts[idx].tolist()]  # type: ignore
103
+
104
+
105
class ReaderTensor(NamedTuple):
    """Description of one tensor found in a GGUF file, plus its data array."""

    # Tensor name.
    name: str
    # Quantization / storage type of the tensor data.
    tensor_type: GGMLQuantizationType
    # Dimensions as stored in the tensor info record.
    shape: npt.NDArray[np.uint32]
    # Number of logical elements.
    n_elements: int
    # Size of the stored tensor data in bytes.
    n_bytes: int
    # Offset of the tensor data.
    data_offset: int
    # The tensor data itself.
    data: npt.NDArray[Any]
    # The tensor info field this tensor was built from.
    field: ReaderField
114
+
115
+
116
class GGUFReader:
    """Parser for GGUF files backed by a ``np.memmap``.

    Constructing a reader parses the header, the key/value metadata (stored in
    ``fields``) and the tensor descriptors (stored in ``tensors``). All arrays
    handed out are views into the memory map, not copies.
    """

    # I - same as host, S - swapped
    byte_order: Literal['I', 'S'] = 'I'
    alignment: int = GGUF_DEFAULT_ALIGNMENT
    data_offset: int

    # Note: Internal helper, API may change.
    gguf_scalar_to_np: dict[GGUFValueType, type[np.generic]] = {
        GGUFValueType.UINT8:   np.uint8,
        GGUFValueType.INT8:    np.int8,
        GGUFValueType.UINT16:  np.uint16,
        GGUFValueType.INT16:   np.int16,
        GGUFValueType.UINT32:  np.uint32,
        GGUFValueType.INT32:   np.int32,
        GGUFValueType.FLOAT32: np.float32,
        GGUFValueType.UINT64:  np.uint64,
        GGUFValueType.INT64:   np.int64,
        GGUFValueType.FLOAT64: np.float64,
        GGUFValueType.BOOL:    np.bool_,
    }

    # Note: Internal helper, API may change.
    # Tensor types whose data maps 1:1 onto a NumPy dtype; every other type
    # is exposed as raw bytes.
    _ggml_scalar_to_np: dict[GGMLQuantizationType, type[np.generic]] = {
        GGMLQuantizationType.F16: np.float16,
        GGMLQuantizationType.F32: np.float32,
        GGMLQuantizationType.F64: np.float64,
        GGMLQuantizationType.I8:  np.int8,
        GGMLQuantizationType.I16: np.int16,
        GGMLQuantizationType.I32: np.int32,
        GGMLQuantizationType.I64: np.int64,
    }

    def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = 'r'):
        """Memory-map ``path`` with ``mode`` and parse its header, metadata and tensor info.

        Raises:
            ValueError: on a bad magic number, an unsupported version, a
                ``general.alignment`` field that is not UINT32 or is zero, an
                unknown field type, or a duplicated tensor name.
        """
        self.data = np.memmap(path, mode = mode)
        offs = 0

        # Check for GGUF magic
        if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC:
            raise ValueError('GGUF magic invalid')
        offs += 4

        # Check GGUF version
        temp_version = self._get(offs, np.uint32)
        if (temp_version[0] & 65535) == 0:
            # If we get 0 here that means it's (probably) a GGUF file created for
            # the opposite byte order of the machine this script is running on.
            self.byte_order = 'S'
            temp_version = temp_version.view(temp_version.dtype.newbyteorder(self.byte_order))
        version = temp_version[0]
        if version not in READER_SUPPORTED_VERSIONS:
            raise ValueError(f'Sorry, file appears to be version {version} which we cannot handle')
        if sys.byteorder == "little":
            # Host is little endian
            host_endian = GGUFEndian.LITTLE
            swapped_endian = GGUFEndian.BIG
        else:
            # Sorry PDP or other weird systems that don't use BE or LE.
            host_endian = GGUFEndian.BIG
            swapped_endian = GGUFEndian.LITTLE
        self.endianess = swapped_endian if self.byte_order == "S" else host_endian
        self.fields: OrderedDict[str, ReaderField] = OrderedDict()
        self.tensors: list[ReaderTensor] = []
        offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp_version], [0], [GGUFValueType.UINT32]))

        # Check tensor count and kv count
        temp_counts = self._get(offs, np.uint64, 2)
        offs += self._push_field(ReaderField(offs, 'GGUF.tensor_count', [temp_counts[:1]], [0], [GGUFValueType.UINT64]))
        offs += self._push_field(ReaderField(offs, 'GGUF.kv_count', [temp_counts[1:]], [0], [GGUFValueType.UINT64]))
        tensor_count, kv_count = temp_counts
        offs = self._build_fields(offs, kv_count)

        # Build Tensor Info Fields
        offs, tensors_fields = self._build_tensor_info(offs, tensor_count)
        new_align = self.fields.get('general.alignment')
        if new_align is not None:
            if new_align.types != [GGUFValueType.UINT32]:
                raise ValueError('Bad type for general.alignment field')
            # Convert to a plain int so the offset arithmetic below stays in
            # Python integers instead of being promoted by NumPy scalar rules.
            alignment = int(new_align.parts[-1][0])
            if alignment == 0:
                raise ValueError('general.alignment must be non-zero')
            self.alignment = alignment
        padding = offs % self.alignment
        if padding != 0:
            offs += self.alignment - padding
        self.data_offset = offs
        self._build_tensors(offs, tensors_fields)

    _DT = TypeVar('_DT', bound = npt.DTypeLike)

    # Fetch a key/value metadata field by key.
    def get_field(self, key: str) -> Union[ReaderField, None]:
        """Return the field stored under ``key``, or ``None`` if absent."""
        return self.fields.get(key, None)

    # Fetch a tensor from the list by index.
    def get_tensor(self, idx: int) -> ReaderTensor:
        """Return the tensor at position ``idx`` of ``self.tensors``."""
        return self.tensors[idx]

    def _get(
        self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I', 'S', '<'] = None,
    ) -> npt.NDArray[Any]:
        """Return ``count`` items of ``dtype`` starting at byte ``offset`` as a view.

        The view uses ``self.byte_order`` unless ``override_order`` is given.
        """
        count = int(count)
        itemsize = int(np.empty([], dtype = dtype).itemsize)
        end_offs = offset + itemsize * count
        arr = self.data[offset:end_offs].view(dtype=dtype)[:count]
        return arr.view(arr.dtype.newbyteorder(self.byte_order if override_order is None else override_order))

    def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int:
        """Register ``field`` and return the byte size of its parts (0 if ``skip_sum``).

        A duplicate name is logged and stored under ``<name>_<offset>``.
        """
        if field.name in self.fields:
            # TODO: add option to generate error on duplicate keys
            # raise KeyError(f'Duplicate {field.name} already in list at offset {field.offset}')

            logger.warning(f'Duplicate key {field.name} at offset {field.offset}')
            self.fields[field.name + '_{}'.format(field.offset)] = field
        else:
            self.fields[field.name] = field
        return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts)

    def _get_str(self, offset: int) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]:
        """Read a string at ``offset``: a uint64 length followed by that many bytes."""
        slen = self._get(offset, np.uint64)
        return slen, self._get(offset + 8, np.uint8, slen[0])

    def _get_field_parts(
        self, orig_offs: int, raw_type: int,
    ) -> tuple[int, list[npt.NDArray[Any]], list[int], list[GGUFValueType]]:
        """Parse one value of type ``raw_type`` located at ``orig_offs``.

        Returns ``(size_in_bytes, parts, data_indexes, types)`` where
        ``data_indexes`` index into ``parts``.

        Raises:
            ValueError: if the type is unknown or unhandled.
        """
        offs = orig_offs
        types: list[GGUFValueType] = []
        gtype = GGUFValueType(raw_type)
        types.append(gtype)
        # Handle strings.
        if gtype == GGUFValueType.STRING:
            sparts: list[npt.NDArray[Any]] = list(self._get_str(offs))
            size = sum(int(part.nbytes) for part in sparts)
            return size, sparts, [1], types
        # Check if it's a simple scalar type.
        nptype = self.gguf_scalar_to_np.get(gtype)
        if nptype is not None:
            val = self._get(offs, nptype)
            return int(val.nbytes), [val], [0], types
        # Handle arrays.
        if gtype == GGUFValueType.ARRAY:
            # The item type is read as a uint32 ...
            raw_itype = self._get(offs, np.uint32)
            offs += int(raw_itype.nbytes)
            # ... followed by the array length as a uint64.
            alen = self._get(offs, np.uint64)
            offs += int(alen.nbytes)
            aparts: list[npt.NDArray[Any]] = [raw_itype, alen]
            data_idxs: list[int] = []
            # FIXME: Handle multi-dimensional arrays properly instead of flattening
            for idx in range(int(alen[0])):
                curr_size, curr_parts, curr_idxs, curr_types = self._get_field_parts(offs, raw_itype[0])
                if idx == 0:
                    types += curr_types
                idxs_offs = len(aparts)
                aparts += curr_parts
                data_idxs += [i + idxs_offs for i in curr_idxs]
                offs += curr_size
            return offs - orig_offs, aparts, data_idxs, types
        # We can't deal with this one.
        raise ValueError(f'Unknown/unhandled field type {gtype}')

    def _get_tensor_info_field(self, orig_offs: int) -> ReaderField:
        """Parse one tensor info record (name, dims, type, offset) at ``orig_offs``."""
        offs = orig_offs

        # Get Tensor Name
        name_len, name_data = self._get_str(offs)
        offs += int(name_len.nbytes + name_data.nbytes)

        # Get Tensor Dimensions Count
        n_dims = self._get(offs, np.uint32)
        offs += int(n_dims.nbytes)

        # Get Tensor Dimension Array
        dims = self._get(offs, np.uint64, n_dims[0])
        offs += int(dims.nbytes)

        # Get Tensor Encoding Scheme Type
        raw_dtype = self._get(offs, np.uint32)
        offs += int(raw_dtype.nbytes)

        # Get Tensor Offset
        offset_tensor = self._get(offs, np.uint64)
        offs += int(offset_tensor.nbytes)

        return ReaderField(
            orig_offs,
            str(bytes(name_data), encoding = 'utf-8'),
            [name_len, name_data, n_dims, dims, raw_dtype, offset_tensor],
            [1, 3, 4, 5],
        )

    def _build_fields(self, offs: int, count: int) -> int:
        """Parse ``count`` key/value fields starting at ``offs``; return the offset after them."""
        for _ in range(count):
            orig_offs = offs
            kv_klen, kv_kdata = self._get_str(offs)
            offs += int(kv_klen.nbytes + kv_kdata.nbytes)
            raw_kv_type = self._get(offs, np.uint32)
            offs += int(raw_kv_type.nbytes)
            parts: list[npt.NDArray[Any]] = [kv_klen, kv_kdata, raw_kv_type]
            idxs_offs = len(parts)
            field_size, field_parts, field_idxs, field_types = self._get_field_parts(offs, raw_kv_type[0])
            parts += field_parts
            self._push_field(ReaderField(
                orig_offs,
                str(bytes(kv_kdata), encoding = 'utf-8'),
                parts,
                [idx + idxs_offs for idx in field_idxs],
                field_types,
            ), skip_sum = True)
            offs += field_size
        return offs

    def _build_tensor_info(self, offs: int, count: int) -> tuple[int, list[ReaderField]]:
        """Parse ``count`` tensor info records; return the new offset and the parsed fields."""
        tensor_fields = []
        for _ in range(count):
            field = self._get_tensor_info_field(offs)
            offs += sum(int(part.nbytes) for part in field.parts)
            tensor_fields.append(field)
        return offs, tensor_fields

    def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None:
        """Create ``self.tensors`` from tensor info ``fields``; data offsets are relative to ``start_offs``.

        Raises:
            ValueError: if two tensors share the same name.
        """
        tensors = []
        tensor_names = set() # keep track of name to prevent duplicated tensors
        for field in fields:
            _name_len, name_data, _n_dims, dims, raw_dtype, offset_tensor = field.parts
            # check if there's any tensor having same name already in the list
            tensor_name = str(bytes(name_data), encoding = 'utf-8')
            if tensor_name in tensor_names:
                raise ValueError(f'Found duplicated tensor with name {tensor_name}')
            tensor_names.add(tensor_name)
            ggml_type = GGMLQuantizationType(raw_dtype[0])
            n_elems = int(np.prod(dims))
            np_dims = tuple(reversed(dims.tolist()))
            block_size, type_size = GGML_QUANT_SIZES[ggml_type]
            n_bytes = n_elems * type_size // block_size
            data_offs = int(start_offs + offset_tensor[0])
            item_type: npt.DTypeLike
            scalar_type = self._ggml_scalar_to_np.get(ggml_type)
            if scalar_type is not None:
                item_count = n_elems
                item_type = scalar_type
            else:
                # Any other type is exposed as raw bytes with a byte-based shape.
                item_count = n_bytes
                item_type = np.uint8
                np_dims = quant_shape_to_byte_shape(np_dims, ggml_type)
            tensors.append(ReaderTensor(
                name = tensor_name,
                tensor_type = ggml_type,
                shape = dims,
                n_elements = n_elems,
                n_bytes = n_bytes,
                data_offset = data_offs,
                data = self._get(data_offs, item_type, item_count).reshape(np_dims),
                field = field,
            ))
        self.tensors = tensors
gguf_writer.py ADDED
@@ -0,0 +1,1276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #BRUNO BECKER / OFFELLIA 2026
2
+ #brunoconta1980@gmail.com
3
+ #brunoconta1980@hotmail.com
4
+ # X @Brunoxuser
5
+
6
+ from __future__ import annotations
7
+
8
+ import logging
9
+ import os
10
+ import shutil
11
+ import struct
12
+ import sys
13
+ import tempfile
14
+ from dataclasses import dataclass
15
+ from enum import Enum, auto
16
+ from math import prod
17
+ from pathlib import Path
18
+ from io import BufferedWriter
19
+ from typing import IO, Any, Sequence, Mapping
20
+ from string import ascii_letters, digits
21
+
22
+ import numpy as np
23
+
24
+ from .constants import (
25
+ GGUF_DEFAULT_ALIGNMENT,
26
+ GGUF_MAGIC,
27
+ GGUF_VERSION,
28
+ GGMLQuantizationType,
29
+ GGUFEndian,
30
+ GGUFValueType,
31
+ Keys,
32
+ RopeScalingType,
33
+ PoolingType,
34
+ TokenType,
35
+ ExpertGatingFuncType,
36
+ )
37
+
38
+ from .quants import quant_shape_from_byte_shape
39
+
40
+ logger = logging.getLogger(__name__)
41
+
42
+
43
+ SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"
44
+
45
+
46
@dataclass
class TensorInfo:
    """Bookkeeping record for one tensor queued for writing."""

    # Tensor dimensions.
    shape: Sequence[int]
    # GGML storage type of the tensor.
    dtype: GGMLQuantizationType
    # Size of the tensor data in bytes.
    nbytes: int
    # The tensor data, when it is being held in memory.
    tensor: np.ndarray[Any, Any] | None = None
52
+
53
+
54
@dataclass
class GGUFValue:
    """A metadata value together with its GGUF value type."""

    # The value to store.
    value: Any
    # GGUF type of the value.
    type: GGUFValueType
    # Optional explicit item type, passed along when the value is packed.
    sub_type: GGUFValueType | None = None
59
+
60
+
61
class WriterState(Enum):
    """Stages a GGUFWriter moves through while producing its output."""

    NO_FILE = 1  # no output file opened yet
    EMPTY = 2    # file opened, nothing written
    HEADER = 3   # header written
    KV_DATA = 4  # key/value metadata written
    TI_DATA = 5  # tensor info written
    WEIGHTS = 6  # tensor data written
68
+
69
+
70
+ class GGUFWriter:
71
+ fout: list[BufferedWriter] | None
72
+ path: Path | None
73
+ temp_file: tempfile.SpooledTemporaryFile[bytes] | None
74
+ tensors: list[dict[str, TensorInfo]]
75
+ kv_data: list[dict[str, GGUFValue]]
76
+ state: WriterState
77
+ _simple_value_packing = {
78
+ GGUFValueType.UINT8: "B",
79
+ GGUFValueType.INT8: "b",
80
+ GGUFValueType.UINT16: "H",
81
+ GGUFValueType.INT16: "h",
82
+ GGUFValueType.UINT32: "I",
83
+ GGUFValueType.INT32: "i",
84
+ GGUFValueType.FLOAT32: "f",
85
+ GGUFValueType.UINT64: "Q",
86
+ GGUFValueType.INT64: "q",
87
+ GGUFValueType.FLOAT64: "d",
88
+ GGUFValueType.BOOL: "?",
89
+ }
90
+
91
def __init__(
    self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE,
    split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False
):
    """Set up an empty writer and register the architecture metadata.

    Args:
        path: Output path; stored as a ``Path``, or ``None`` when falsy.
        arch: Architecture name, stored on the instance.
        use_temp_file: Whether tensor data may be spooled to a temporary file.
        endianess: Byte order of the file to be written.
        split_max_tensors: Tensor-count limit per shard (0 disables the check).
        split_max_size: Byte-size limit per shard (0 disables the check).
        dry_run: Stored on the instance as ``self.dry_run``.
        small_first_shard: When true, an extra empty shard is appended up front.
    """
    # Output handles and state.
    self.fout = None
    self.temp_file = None
    self.state = WriterState.NO_FILE
    if path:
        self.path = Path(path)
    else:
        self.path = None

    # Plain configuration.
    self.arch = arch
    self.endianess = endianess
    self.data_alignment = GGUF_DEFAULT_ALIGNMENT
    self.use_temp_file = use_temp_file
    self.split_max_tensors = split_max_tensors
    self.split_max_size = split_max_size
    self.dry_run = dry_run
    self.small_first_shard = small_first_shard

    # One dict per shard / per file.
    self.tensors = [{}]
    self.kv_data = [{}]

    endian_label = "Big" if self.endianess == GGUFEndian.BIG else "Little"
    logger.info("gguf: This GGUF file is for {0} Endian only".format(endian_label))

    if self.small_first_shard:
        self.tensors.append({})

    self.add_architecture()
117
+
118
def get_total_parameter_count(self) -> tuple[int, int, int, int]:
    """Count parameters over all queued tensors.

    Returns:
        ``(total, shared, expert, expert_count)``. ``total`` is negated when
        LoRA tensors were seen, to signal it is likely not exact. All four
        values are 0 when a ``.lora_b`` tensor is not directly preceded by
        its matching ``.lora_a`` tensor.
    """
    total = 0
    shared = 0
    per_expert = 0
    expert_count_sum = 0
    expert_tensor_count = 0
    pending_lora_a: tuple[str, TensorInfo] | None = None

    for shard in self.tensors:
        for name, info in shard.items():
            dims = info.shape

            if name.endswith(".lora_a"):
                # Counted later, together with the matching ".lora_b".
                pending_lora_a = (name, info)
                continue

            if name.endswith(".lora_b"):
                if pending_lora_a is None or pending_lora_a[0] != name[:-1] + "a":
                    # Bail when the LoRA pair can't be found trivially
                    logger.warning("can't measure LoRA size correctly, tensor order is unusual")
                    return 0, 0, 0, 0
                dims = (*dims[:-1], pending_lora_a[1].shape[-1])

            n_params = prod(dims)

            if "_exps." in name:
                n_experts = dims[-2 if ".bias" in name else -3]
                per_expert += (n_params // n_experts)
                expert_count_sum += n_experts
                expert_tensor_count += 1
            else:
                shared += n_params

            total += n_params

    # Hopefully this should work even for variable-expert-count models
    if expert_tensor_count > 0:
        expert_count = expert_count_sum // expert_tensor_count
    else:
        expert_count = 0

    # Negate the total to signal it's likely not exact
    if pending_lora_a is not None:
        total = -total

    # NOTE: keep the output in the same order as accepted by 'size_label' in gguf-py/gguf/utility.py
    return total, shared, per_expert, expert_count
165
+
166
def format_shard_names(self, path: Path) -> list[Path]:
    """Return one output path per shard.

    With a single shard ``path`` is returned unchanged; otherwise each shard
    gets a sibling name built from ``SHARD_NAME_FORMAT`` and ``path.stem``.
    """
    n_shards = len(self.tensors)
    if n_shards == 1:
        return [path]

    shard_paths = []
    for shard_no in range(1, n_shards + 1):
        shard_name = SHARD_NAME_FORMAT.format(path.stem, shard_no, n_shards)
        shard_paths.append(path.with_name(shard_name))
    return shard_paths
170
+
171
def open_output_file(self, path: Path | None = None) -> None:
    """Open the output file(s) for writing, one per shard.

    Calling this again is a no-op as long as the files are open, nothing has
    been written yet and ``path`` is omitted or unchanged.

    Raises:
        ValueError: if the writer is past the ``NO_FILE`` state.
    """
    files_open_and_untouched = self.state is WriterState.EMPTY and self.fout is not None
    if files_open_and_untouched and (path is None or path == self.path):
        # allow calling this multiple times as long as the path is the same
        return

    if self.state is not WriterState.NO_FILE:
        raise ValueError(f'Expected output file to be not yet opened, got {self.state}')

    if path is not None:
        self.path = path

    if self.path is None:
        return

    shard_paths = self.print_plan()
    self.fout = [open(shard_path, "wb") for shard_path in shard_paths]
    self.state = WriterState.EMPTY
186
+
187
def print_plan(self) -> list[Path]:
    """Log the files that will be written and return their paths.

    For every shard the tensor count and total byte size are logged. When
    ``self.dry_run`` is set, the file names are printed and the process
    exits via ``SystemExit`` instead of returning.
    """
    logger.info("Writing the following files:")
    assert self.path is not None
    filenames = self.format_shard_names(self.path)
    assert len(filenames) == len(self.tensors)
    for name, tensors in zip(filenames, self.tensors):
        logger.info(f"{name}: n_tensors = {len(tensors)}, total_size = {GGUFWriter.format_n_bytes_to_str(sum(ti.nbytes for ti in tensors.values()))}")

    if self.dry_run:
        logger.info("Dry run, not writing files")
        for name in filenames:
            print(name)  # noqa: NP100
        # sys.exit() rather than the site-provided exit() helper, which is
        # meant for the interactive shell and is not always available.
        sys.exit()

    return filenames
202
+
203
def add_shard_kv_data(self) -> None:
    """Add split metadata (split number, split count, total tensor count) to every shard.

    Does nothing when there is only a single shard.
    """
    if len(self.tensors) == 1:
        return

    total_tensors = sum(len(shard) for shard in self.tensors)
    assert self.fout is not None
    total_splits = len(self.fout)

    # Make sure there is one key/value dict per output file.
    for _ in range(len(self.kv_data), total_splits):
        self.kv_data.append({})

    for split_no, shard_kv in enumerate(self.kv_data):
        shard_kv[Keys.Split.LLM_KV_SPLIT_NO] = GGUFValue(split_no, GGUFValueType.UINT16)
        shard_kv[Keys.Split.LLM_KV_SPLIT_COUNT] = GGUFValue(total_splits, GGUFValueType.UINT16)
        shard_kv[Keys.Split.LLM_KV_SPLIT_TENSORS_COUNT] = GGUFValue(total_tensors, GGUFValueType.INT32)
215
+
216
def write_header_to_file(self, path: Path | None = None) -> None:
    """Open the output file(s) and write the GGUF header to each of them.

    The header consists of the magic, the version, the tensor count and the
    key/value count of the respective shard.

    Raises:
        ValueError: if the output file is not in the ``EMPTY`` state after opening.
    """
    split_requested = self.split_max_tensors != 0 or self.split_max_size != 0
    if split_requested and len(self.tensors) == 1:
        logger.warning("Model fails split requirements, not splitting")

    self.open_output_file(path)

    if self.state is not WriterState.EMPTY:
        raise ValueError(f'Expected output file to be empty, got {self.state}')

    assert self.fout is not None
    assert len(self.fout) == len(self.tensors)
    assert len(self.kv_data) == 1

    self.add_shard_kv_data()

    for out, shard_tensors, shard_kv in zip(self.fout, self.tensors, self.kv_data):
        # The magic is always packed little-endian, without the endianness prefix.
        out.write(self._pack("<I", GGUF_MAGIC, skip_pack_prefix = True))
        for fmt, number in (("I", GGUF_VERSION), ("Q", len(shard_tensors)), ("Q", len(shard_kv))):
            out.write(self._pack(fmt, number))
        out.flush()

    self.state = WriterState.HEADER
238
+
239
def write_kv_data_to_file(self) -> None:
    """Serialize the key/value metadata of every shard and write it out.

    Raises:
        ValueError: if the header has not been written yet.
    """
    if self.state is not WriterState.HEADER:
        raise ValueError(f'Expected output file to contain the header, got {self.state}')
    assert self.fout is not None

    for out, shard_kv in zip(self.fout, self.kv_data):
        payload = bytearray()

        for key, entry in shard_kv.items():
            # Key as a bare string, then the value prefixed with its type.
            payload += self._pack_val(key, GGUFValueType.STRING, add_vtype=False)
            payload += self._pack_val(entry.value, entry.type, add_vtype=True, sub_type=entry.sub_type)

        out.write(payload)

    self.flush()
    self.state = WriterState.KV_DATA
255
+
256
def write_ti_data_to_file(self) -> None:
    """Serialize the tensor info records of every shard and write them out.

    Each record holds the tensor name, dimension count, dimensions (written
    in reversed order), type and the tensor's data offset. Offsets advance by
    the tensor size padded to ``self.data_alignment``.

    Raises:
        ValueError: if the key/value data has not been written yet.
    """
    if self.state is not WriterState.KV_DATA:
        raise ValueError(f'Expected output file to contain KV data, got {self.state}')
    assert self.fout is not None

    for out, shard_tensors in zip(self.fout, self.tensors):
        records = bytearray()
        next_offset = 0

        for name, info in shard_tensors.items():
            records += self._pack_val(name, GGUFValueType.STRING, add_vtype=False)
            rank = len(info.shape)
            records += self._pack("I", rank)
            for axis in range(rank - 1, -1, -1):
                records += self._pack("Q", info.shape[axis])
            records += self._pack("I", info.dtype)
            records += self._pack("Q", next_offset)
            next_offset += GGUFWriter.ggml_pad(info.nbytes, self.data_alignment)

        out.write(records)
        out.flush()

    self.state = WriterState.TI_DATA
278
+
279
def add_key_value(self, key: str, val: Any, vtype: GGUFValueType, sub_type: GGUFValueType | None = None) -> None:
    """Store ``val`` under ``key`` in the first key/value dict.

    A warning is logged when ``key`` already exists in any key/value dict.
    """
    for existing in self.kv_data:
        if key in existing:
            logger.warning(f'Duplicated key name {key!r}, overwriting it with new value {val!r} of type {vtype.name}')
            break

    first_kv = self.kv_data[0]
    first_kv[key] = GGUFValue(value=val, type=vtype, sub_type=sub_type)
284
+
285
def add_uint8(self, key: str, val: int) -> None:
    """Add ``val`` under ``key`` as a UINT8 value."""
    self.add_key_value(key, val, vtype=GGUFValueType.UINT8)
287
+
288
def add_int8(self, key: str, val: int) -> None:
    """Add ``val`` under ``key`` as an INT8 value."""
    self.add_key_value(key, val, vtype=GGUFValueType.INT8)
290
+
291
def add_uint16(self, key: str, val: int) -> None:
    """Add ``val`` under ``key`` as a UINT16 value."""
    self.add_key_value(key, val, vtype=GGUFValueType.UINT16)
293
+
294
def add_int16(self, key: str, val: int) -> None:
    """Add ``val`` under ``key`` as an INT16 value."""
    self.add_key_value(key, val, vtype=GGUFValueType.INT16)
296
+
297
def add_uint32(self, key: str, val: int) -> None:
    """Add ``val`` under ``key`` as a UINT32 value."""
    self.add_key_value(key, val, vtype=GGUFValueType.UINT32)
299
+
300
def add_int32(self, key: str, val: int) -> None:
    """Add ``val`` under ``key`` as an INT32 value."""
    self.add_key_value(key, val, vtype=GGUFValueType.INT32)
302
+
303
def add_float32(self, key: str, val: float) -> None:
    """Add ``val`` under ``key`` as a FLOAT32 value."""
    self.add_key_value(key, val, vtype=GGUFValueType.FLOAT32)
305
+
306
def add_uint64(self, key: str, val: int) -> None:
    """Add ``val`` under ``key`` as a UINT64 value."""
    self.add_key_value(key, val, vtype=GGUFValueType.UINT64)
308
+
309
def add_int64(self, key: str, val: int) -> None:
    """Add ``val`` under ``key`` as an INT64 value."""
    self.add_key_value(key, val, vtype=GGUFValueType.INT64)
311
+
312
def add_float64(self, key: str, val: float) -> None:
    """Add ``val`` under ``key`` as a FLOAT64 value."""
    self.add_key_value(key, val, vtype=GGUFValueType.FLOAT64)
314
+
315
def add_bool(self, key: str, val: bool) -> None:
    """Add ``val`` under ``key`` as a BOOL value."""
    self.add_key_value(key, val, vtype=GGUFValueType.BOOL)
317
+
318
def add_string(self, key: str, val: str) -> None:
    """Add ``val`` under ``key`` as a STRING value; empty values are skipped."""
    if val:
        self.add_key_value(key, val, vtype=GGUFValueType.STRING)
322
+
323
def add_array(self, key: str, val: Sequence[Any]) -> None:
    """Add ``val`` under ``key`` as an ARRAY value; empty sequences are skipped."""
    if len(val) != 0:
        self.add_key_value(key, val, vtype=GGUFValueType.ARRAY)
327
+
328
+ @staticmethod
329
+ def ggml_pad(x: int, n: int) -> int:
330
+ return ((x + n - 1) // n) * n
331
+
332
def add_tensor_info(
    self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype,
    tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None,
) -> None:
    """Register a tensor (without its data) in the current shard.

    The GGML type is derived from ``tensor_dtype`` unless ``raw_dtype`` is
    given. With ``raw_dtype`` and ``uint8`` data, ``tensor_shape`` is taken
    to be a byte shape and converted with ``quant_shape_from_byte_shape``.
    A new shard is started first when the current one is non-empty and has
    reached the tensor-count or size limit.

    Raises:
        ValueError: if an output file was already opened, the name is
            already registered, or the dtype is unsupported.
    """
    if self.state is not WriterState.NO_FILE:
        raise ValueError(f'Expected output file to be not yet opened, got {self.state}')

    for shard in self.tensors:
        if name in shard:
            raise ValueError(f'Duplicated tensor name {name!r}')

    if raw_dtype is not None:
        dtype = raw_dtype
        if tensor_dtype == np.uint8:
            tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype)
    else:
        # Compared with == (not looked up by hash) so np.dtype instances match.
        supported = (
            (np.float16, GGMLQuantizationType.F16),
            (np.float32, GGMLQuantizationType.F32),
            (np.float64, GGMLQuantizationType.F64),
            (np.int8, GGMLQuantizationType.I8),
            (np.int16, GGMLQuantizationType.I16),
            (np.int32, GGMLQuantizationType.I32),
            (np.int64, GGMLQuantizationType.I64),
        )
        for np_type, ggml_type in supported:
            if tensor_dtype == np_type:
                dtype = ggml_type
                break
        else:
            raise ValueError("Only F16, F32, F64, I8, I16, I32, I64 tensors are supported for now")

    # make sure there is at least one tensor before splitting
    current = self.tensors[-1]
    if len(current) > 0:
        # split when over tensor limit
        needs_split = self.split_max_tensors != 0 and len(current) >= self.split_max_tensors
        if not needs_split and self.split_max_size != 0:
            # split when over size limit
            shard_bytes = sum(ti.nbytes for ti in current.values())
            needs_split = shard_bytes + tensor_nbytes > self.split_max_size
        if needs_split:
            self.tensors.append({})

    self.tensors[-1][name] = TensorInfo(shape=tensor_shape, dtype=dtype, nbytes=tensor_nbytes)
376
+
377
def add_tensor(
    self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
    raw_dtype: GGMLQuantizationType | None = None, tensor_endianess: GGUFEndian | None = None
) -> None:
    """Register a tensor and keep its data in memory or in the temporary file.

    The tensor is byteswapped (as a copy) when its endianness differs from
    the writer's. ``raw_shape`` overrides ``tensor.shape`` when given.
    """
    # if tensor endianness is not passed, assume it's native to system
    if tensor_endianess is None:
        if sys.byteorder == 'big':
            tensor_endianess = GGUFEndian.BIG
        else:
            tensor_endianess = GGUFEndian.LITTLE

    if tensor_endianess != self.endianess:
        # Don't byteswap inplace since lazy copies cannot handle it
        tensor = tensor.byteswap(inplace=False)

    # Lazily create the spool file on first use.
    if self.use_temp_file and self.temp_file is None:
        spool = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256 * 1024 * 1024)
        spool.seek(0)
        self.temp_file = spool

    shape: Sequence[int] = tensor.shape if raw_shape is None else raw_shape
    self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype=raw_dtype)

    if self.temp_file is not None:
        tensor.tofile(self.temp_file)
        self.write_padding(self.temp_file, tensor.nbytes)
        return

    self.tensors[-1][name].tensor = tensor
402
+
403
+ def write_padding(self, fp: IO[bytes], n: int, align: int | None = None) -> None:
404
+ pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n
405
+ if pad != 0:
406
+ fp.write(bytes([0] * pad))
407
+
408
+ def write_tensor_data(self, tensor: np.ndarray[Any, Any], tensor_endianess: GGUFEndian | None = None) -> None:
409
+ if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS:
410
+ raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}')
411
+ assert self.fout is not None
412
+
413
+ # if tensor endianness is not passed, assume it's native to system
414
+ if tensor_endianess is None:
415
+ tensor_endianess = GGUFEndian.BIG if sys.byteorder == 'big' else GGUFEndian.LITTLE
416
+
417
+ if tensor_endianess != self.endianess:
418
+ # Don't byteswap inplace since lazy copies cannot handle it
419
+ tensor = tensor.byteswap(inplace=False)
420
+
421
+ file_id = -1
422
+ for i, tensors in enumerate(self.tensors):
423
+ if len(tensors) > 0:
424
+ file_id = i
425
+ break
426
+
427
+ fout = self.fout[file_id]
428
+
429
+ # pop the first tensor info
430
+ # TODO: cleaner way to get the first key
431
+ first_tensor_name = [name for name, _ in zip(self.tensors[file_id].keys(), range(1))][0]
432
+ ti = self.tensors[file_id].pop(first_tensor_name)
433
+ assert ti.nbytes == tensor.nbytes
434
+
435
+ self.write_padding(fout, fout.tell())
436
+ tensor.tofile(fout)
437
+ self.write_padding(fout, tensor.nbytes)
438
+
439
+ self.state = WriterState.WEIGHTS
440
+
441
    def write_tensors_to_file(self, *, progress: bool = False) -> None:
        """Write the tensor info section and then all tensor data to the output files.

        Tensors kept in memory are written per output file in insertion order,
        each followed by padding, and their array reference is dropped once
        written. When a temp file is in use, its contents are copied to an
        output file instead.

        Args:
            progress: When true, show ``tqdm`` progress bars (imported lazily).
        """
        self.write_ti_data_to_file()

        assert self.fout is not None

        # Pad each output file from its current position before tensor data starts.
        for fout in self.fout:
            self.write_padding(fout, fout.tell())

        if self.temp_file is None:
            shard_bar = None
            bar = None

            if progress:
                from tqdm import tqdm

                total_bytes = sum(ti.nbytes for t in self.tensors for ti in t.values())

                # A per-shard bar is only created when there is more than one output file.
                if len(self.fout) > 1:
                    shard_bar = tqdm(desc=f"Shard (0/{len(self.fout)})", total=None, unit="byte", unit_scale=True)
                bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)

            for i, (fout, tensors) in enumerate(zip(self.fout, self.tensors)):
                if shard_bar is not None:
                    shard_bar.set_description(f"Shard ({i + 1}/{len(self.fout)})")
                    total = sum(ti.nbytes for ti in tensors.values())
                    shard_bar.reset(total=(total if total > 0 else None))

                # relying on the fact that Python dicts preserve insertion order (since 3.7)
                for ti in tensors.values():
                    assert ti.tensor is not None  # can only iterate once over the tensors
                    assert ti.tensor.nbytes == ti.nbytes
                    ti.tensor.tofile(fout)
                    if shard_bar is not None:
                        shard_bar.update(ti.nbytes)
                    if bar is not None:
                        bar.update(ti.nbytes)
                    self.write_padding(fout, ti.nbytes)
                    # Drop the reference so the array can be freed after writing.
                    ti.tensor = None
        else:
            self.temp_file.seek(0)

            # NOTE(review): the spooled data goes to a single output file (index 0, or
            # index 1 when small_first_shard is set) -- confirm this is intended when
            # several output files exist.
            shutil.copyfileobj(self.temp_file, self.fout[0 if not self.small_first_shard else 1])
            self.flush()
            self.temp_file.close()

        self.state = WriterState.WEIGHTS
487
+
488
+ def flush(self) -> None:
489
+ assert self.fout is not None
490
+ for fout in self.fout:
491
+ fout.flush()
492
+
493
+ def close(self) -> None:
494
+ if self.fout is not None:
495
+ for fout in self.fout:
496
+ fout.close()
497
+ self.fout = None
498
+
499
+ def add_type(self, type_name: str) -> None:
500
+ self.add_string(Keys.General.TYPE, type_name)
501
+
502
+ def add_architecture(self) -> None:
503
+ self.add_string(Keys.General.ARCHITECTURE, self.arch)
504
+
505
+ def add_quantization_version(self, quantization_version: int) -> None:
506
+ self.add_uint32(Keys.General.QUANTIZATION_VERSION, quantization_version)
507
+
508
    def add_custom_alignment(self, alignment: int) -> None:
        """Set ``self.data_alignment`` to ``alignment`` and record it under ``Keys.General.ALIGNMENT``."""
        # The attribute is updated first, so it is already changed if the KV write raises.
        self.data_alignment = alignment
        self.add_uint32(Keys.General.ALIGNMENT, alignment)
511
+
512
+ def add_file_type(self, ftype: int) -> None:
513
+ self.add_uint32(Keys.General.FILE_TYPE, ftype)
514
+
515
+ def add_sampling_sequence(self, sequence: str) -> None:
516
+ self.add_string(Keys.General.SAMPLING_SEQUENCE, sequence)
517
+
518
    def add_sampling_top_k(self, top_k: int) -> None:
        """Store ``top_k`` under ``Keys.General.SAMPLING_TOP_K``."""
        # NOTE(review): `add_int41` is not defined in the visible part of this file, while
        # the integer setters around it call `add_uint32` -- confirm this writer method exists.
        self.add_int41(Keys.General.SAMPLING_TOP_K, top_k)
520
+
521
    def add_sampling_top_p(self, top_p: float) -> None:
        """Store ``top_p`` under ``Keys.General.SAMPLING_TOP_P``."""
        # NOTE(review): `add_float41` is not defined in the visible part of this file --
        # confirm this writer method exists and which value type it writes.
        self.add_float41(Keys.General.SAMPLING_TOP_P, top_p)
523
+
524
+ def add_sampling_min_p(self, min_p: float) -> None:
525
+ self.add_float41(Keys.General.SAMPLING_MIN_P, min_p)
526
+
527
+ def add_sampling_xtc_probability(self, xtc_probability: float) -> None:
528
+ self.add_float41(Keys.General.SAMPLING_XTC_PROBABILITY, xtc_probability)
529
+
530
+ def add_sampling_xtc_threshold(self, xtc_threshold: float) -> None:
531
+ self.add_float41(Keys.General.SAMPLING_XTC_THRESHOLD, xtc_threshold)
532
+
533
+ def add_sampling_temp(self, temp: float) -> None:
534
+ self.add_float41(Keys.General.SAMPLING_TEMP, temp)
535
+
536
+ def add_sampling_penalty_last_n(self, penalty_last_n: int) -> None:
537
+ self.add_int41(Keys.General.SAMPLING_PENALTY_LAST_N, penalty_last_n)
538
+
539
+ def add_sampling_penalty_repeat(self, penalty_repeat: float) -> None:
540
+ self.add_float41(Keys.General.SAMPLING_PENALTY_REPEAT, penalty_repeat)
541
+
542
+ def add_sampling_mirostat(self, mirostat: int) -> None:
543
+ self.add_int41(Keys.General.SAMPLING_MIROSTAT, mirostat)
544
+
545
+ def add_sampling_mirostat_tau(self, mirostat_tau: float) -> None:
546
+ self.add_float41(Keys.General.SAMPLING_MIROSTAT_TAU, mirostat_tau)
547
+
548
+ def add_sampling_mirostat_eta(self, mirostat_eta: float) -> None:
549
+ self.add_float41(Keys.General.SAMPLING_MIROSTAT_ETA, mirostat_eta)
550
+
551
+ def add_name(self, name: str) -> None:
552
+ self.add_string(Keys.General.NAME, name)
553
+
554
+ def add_author(self, author: str) -> None:
555
+ self.add_string(Keys.General.AUTHOR, author)
556
+
557
+ def add_version(self, version: str) -> None:
558
+ self.add_string(Keys.General.VERSION, version)
559
+
560
+ def add_organization(self, organization: str) -> None:
561
+ self.add_string(Keys.General.ORGANIZATION, organization)
562
+
563
+ def add_finetune(self, finetune: str) -> None:
564
+ self.add_string(Keys.General.FINETUNE, finetune)
565
+
566
+ def add_basename(self, basename: str) -> None:
567
+ self.add_string(Keys.General.BASENAME, basename)
568
+
569
+ def add_description(self, description: str) -> None:
570
+ self.add_string(Keys.General.DESCRIPTION, description)
571
+
572
+ def add_quantized_by(self, quantized: str) -> None:
573
+ self.add_string(Keys.General.QUANTIZED_BY, quantized)
574
+
575
+ def add_size_label(self, size_label: str) -> None:
576
+ self.add_string(Keys.General.SIZE_LABEL, size_label)
577
+
578
+ def add_license(self, license: str) -> None:
579
+ self.add_string(Keys.General.LICENSE, license)
580
+
581
+ def add_license_name(self, license: str) -> None:
582
+ self.add_string(Keys.General.LICENSE_NAME, license)
583
+
584
+ def add_license_link(self, license: str) -> None:
585
+ self.add_string(Keys.General.LICENSE_LINK, license)
586
+
587
+ def add_url(self, url: str) -> None:
588
+ self.add_string(Keys.General.URL, url)
589
+
590
+ def add_doi(self, doi: str) -> None:
591
+ self.add_string(Keys.General.DOI, doi)
592
+
593
+ def add_uuid(self, uuid: str) -> None:
594
+ self.add_string(Keys.General.UUID, uuid)
595
+
596
+ def add_repo_url(self, repo_url: str) -> None:
597
+ self.add_string(Keys.General.REPO_URL, repo_url)
598
+
599
+ def add_source_url(self, url: str) -> None:
600
+ self.add_string(Keys.General.SOURCE_URL, url)
601
+
602
+ def add_source_doi(self, doi: str) -> None:
603
+ self.add_string(Keys.General.SOURCE_DOI, doi)
604
+
605
+ def add_source_uuid(self, uuid: str) -> None:
606
+ self.add_string(Keys.General.SOURCE_UUID, uuid)
607
+
608
+ def add_source_repo_url(self, repo_url: str) -> None:
609
+ self.add_string(Keys.General.SOURCE_REPO_URL, repo_url)
610
+
611
+ def add_base_model_count(self, source_count: int) -> None:
612
+ self.add_uint32(Keys.General.BASE_MODEL_COUNT, source_count)
613
+
614
+ def add_base_model_name(self, source_id: int, name: str) -> None:
615
+ self.add_string(Keys.General.BASE_MODEL_NAME.format(id=source_id), name)
616
+
617
+ def add_base_model_author(self, source_id: int, author: str) -> None:
618
+ self.add_string(Keys.General.BASE_MODEL_AUTHOR.format(id=source_id), author)
619
+
620
+ def add_base_model_version(self, source_id: int, version: str) -> None:
621
+ self.add_string(Keys.General.BASE_MODEL_VERSION.format(id=source_id), version)
622
+
623
+ def add_base_model_organization(self, source_id: int, organization: str) -> None:
624
+ self.add_string(Keys.General.BASE_MODEL_ORGANIZATION.format(id=source_id), organization)
625
+
626
+ def add_base_model_description(self, source_id: int, description: str) -> None:
627
+ self.add_string(Keys.General.BASE_MODEL_DESCRIPTION.format(id=source_id), description)
628
+
629
+ def add_base_model_url(self, source_id: int, url: str) -> None:
630
+ self.add_string(Keys.General.BASE_MODEL_URL.format(id=source_id), url)
631
+
632
+ def add_base_model_doi(self, source_id: int, doi: str) -> None:
633
+ self.add_string(Keys.General.BASE_MODEL_DOI.format(id=source_id), doi)
634
+
635
+ def add_base_model_uuid(self, source_id: int, uuid: str) -> None:
636
+ self.add_string(Keys.General.BASE_MODEL_UUID.format(id=source_id), uuid)
637
+
638
+ def add_base_model_repo_url(self, source_id: int, repo_url: str) -> None:
639
+ self.add_string(Keys.General.BASE_MODEL_REPO_URL.format(id=source_id), repo_url)
640
+
641
+ def add_dataset_count(self, source_count: int) -> None:
642
+ self.add_uint32(Keys.General.DATASET_COUNT, source_count)
643
+
644
+ def add_dataset_name(self, source_id: int, name: str) -> None:
645
+ self.add_string(Keys.General.DATASET_NAME.format(id=source_id), name)
646
+
647
+ def add_dataset_author(self, source_id: int, author: str) -> None:
648
+ self.add_string(Keys.General.DATASET_AUTHOR.format(id=source_id), author)
649
+
650
+ def add_dataset_version(self, source_id: int, version: str) -> None:
651
+ self.add_string(Keys.General.DATASET_VERSION.format(id=source_id), version)
652
+
653
+ def add_dataset_organization(self, source_id: int, organization: str) -> None:
654
+ self.add_string(Keys.General.DATASET_ORGANIZATION.format(id=source_id), organization)
655
+
656
+ def add_dataset_description(self, source_id: int, description: str) -> None:
657
+ self.add_string(Keys.General.DATASET_DESCRIPTION.format(id=source_id), description)
658
+
659
+ def add_dataset_url(self, source_id: int, url: str) -> None:
660
+ self.add_string(Keys.General.DATASET_URL.format(id=source_id), url)
661
+
662
+ def add_dataset_doi(self, source_id: int, doi: str) -> None:
663
+ self.add_string(Keys.General.DATASET_DOI.format(id=source_id), doi)
664
+
665
+ def add_dataset_uuid(self, source_id: int, uuid: str) -> None:
666
+ self.add_string(Keys.General.DATASET_UUID.format(id=source_id), uuid)
667
+
668
+ def add_dataset_repo_url(self, source_id: int, repo_url: str) -> None:
669
+ self.add_string(Keys.General.DATASET_REPO_URL.format(id=source_id), repo_url)
670
+
671
+ def add_tags(self, tags: Sequence[str]) -> None:
672
+ self.add_array(Keys.General.TAGS, tags)
673
+
674
+ def add_languages(self, languages: Sequence[str]) -> None:
675
+ self.add_array(Keys.General.LANGUAGES, languages)
676
+
677
+ def add_tensor_data_layout(self, layout: str) -> None:
678
+ self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
679
+
680
+ def add_vocab_size(self, size: int) -> None:
681
+ self.add_uint32(Keys.LLM.VOCAB_SIZE.format(arch=self.arch), size)
682
+
683
+ def add_context_length(self, length: int) -> None:
684
+ self.add_uint32(Keys.LLM.CONTEXT_LENGTH.format(arch=self.arch), length)
685
+
686
+ def add_embedding_length(self, length: int) -> None:
687
+ self.add_uint32(Keys.LLM.EMBEDDING_LENGTH.format(arch=self.arch), length)
688
+
689
+ def add_embedding_length_out(self, length: int) -> None:
690
+ self.add_uint32(Keys.LLM.EMBEDDING_LENGTH_OUT.format(arch=self.arch), length)
691
+
692
+ def add_features_length(self, length: int) -> None:
693
+ self.add_uint32(Keys.LLM.FEATURES_LENGTH.format(arch=self.arch), length)
694
+
695
+ def add_posnet_embedding_length(self, length: int) -> None:
696
+ self.add_uint32(Keys.PosNet.EMBEDDING_LENGTH.format(arch=self.arch), length)
697
+
698
+ def add_posnet_block_count(self, length: int) -> None:
699
+ self.add_uint32(Keys.PosNet.BLOCK_COUNT.format(arch=self.arch), length)
700
+
701
+ def add_convnext_embedding_length(self, length: int) -> None:
702
+ self.add_uint32(Keys.ConvNext.EMBEDDING_LENGTH.format(arch=self.arch), length)
703
+
704
+ def add_convnext_block_count(self, length: int) -> None:
705
+ self.add_uint32(Keys.ConvNext.BLOCK_COUNT.format(arch=self.arch), length)
706
+
707
+ def add_shortconv_l_cache(self, length: int) -> None:
708
+ self.add_uint32(Keys.ShortConv.L_CACHE.format(arch=self.arch), length)
709
+
710
+ def add_block_count(self, length: int) -> None:
711
+ self.add_uint32(Keys.LLM.BLOCK_COUNT.format(arch=self.arch), length)
712
+
713
+ def add_leading_dense_block_count(self, length: int) -> None:
714
+ self.add_uint32(Keys.LLM.LEADING_DENSE_BLOCK_COUNT.format(arch=self.arch), length)
715
+
716
+ def add_feed_forward_length(self, length: int | Sequence[int]) -> None:
717
+ if isinstance(length, int):
718
+ self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)
719
+ else:
720
+ self.add_array(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)
721
+
722
+ def add_expert_feed_forward_length(self, length: int) -> None:
723
+ self.add_uint32(Keys.LLM.EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
724
+
725
+ def add_expert_shared_feed_forward_length(self, length: int) -> None:
726
+ self.add_uint32(Keys.LLM.EXPERT_SHARED_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
727
+
728
+ def add_expert_chunk_feed_forward_length(self, length: int) -> None:
729
+ self.add_uint32(Keys.LLM.EXPERT_CHUNK_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
730
+
731
+ def add_parallel_residual(self, use: bool) -> None:
732
+ self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
733
+
734
+ def add_decoder_start_token_id(self, id: int) -> None:
735
+ self.add_uint32(Keys.LLM.DECODER_START_TOKEN_ID.format(arch=self.arch), id)
736
+
737
+ def add_decoder_block_count(self, value: int) -> None:
738
+ self.add_uint32(Keys.LLM.DECODER_BLOCK_COUNT.format(arch=self.arch), value)
739
+
740
+ def add_embedding_length_per_layer_input(self, value: int) -> None:
741
+ self.add_uint32(Keys.LLM.EMBD_LENGTH_PER_LAYER_INP.format(arch=self.arch), value)
742
+
743
+ def add_altup_active_idx(self, val: int) -> None:
744
+ self.add_uint32(Keys.LLM.ALTUP_ACTIVE_IDX.format(arch=self.arch), val)
745
+
746
+ def add_altup_num_inputs(self, val: int) -> None:
747
+ self.add_uint32(Keys.LLM.ALTUP_NUM_INPUTS.format(arch=self.arch), val)
748
+
749
+ def add_activation_sparsity_scale(self, values: Sequence[float]) -> None:
750
+ self.add_array(Keys.LLM.ACTIVATION_SPARSITY_SCALE.format(arch=self.arch), values)
751
+
752
+ def add_head_count(self, count: int | Sequence[int]) -> None:
753
+ if isinstance(count, int):
754
+ self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
755
+ else:
756
+ self.add_array(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
757
+
758
+ def add_head_count_kv(self, count: int | Sequence[int]) -> None:
759
+ if isinstance(count, int):
760
+ self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)
761
+ else:
762
+ self.add_array(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)
763
+
764
+ def add_key_length(self, length: int) -> None:
765
+ self.add_uint32(Keys.Attention.KEY_LENGTH.format(arch=self.arch), length)
766
+
767
+ def add_value_length(self, length: int) -> None:
768
+ self.add_uint32(Keys.Attention.VALUE_LENGTH.format(arch=self.arch), length)
769
+
770
+ def add_key_length_mla(self, length: int) -> None:
771
+ self.add_uint32(Keys.Attention.KEY_LENGTH_MLA.format(arch=self.arch), length)
772
+
773
+ def add_value_length_mla(self, length: int) -> None:
774
+ self.add_uint32(Keys.Attention.VALUE_LENGTH_MLA.format(arch=self.arch), length)
775
+
776
+ def add_max_alibi_bias(self, bias: float) -> None:
777
+ self.add_float41(Keys.Attention.MAX_ALIBI_BIAS.format(arch=self.arch), bias)
778
+
779
+ def add_clamp_kqv(self, value: float) -> None:
780
+ self.add_float41(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value)
781
+
782
+ def add_shared_kv_layers(self, value: int) -> None:
783
+ self.add_uint32(Keys.Attention.SHARED_KV_LAYERS.format(arch=self.arch), value)
784
+
785
+ def add_sliding_window_pattern(self, value: int | Sequence[bool]) -> None:
786
+ key = Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch)
787
+ if isinstance(value, int):
788
+ self.add_uint32(key, value)
789
+ else:
790
+ self.add_array(key, value)
791
+
792
+ def add_dense_features_dims(self, dense:str, in_f:int, out_f:int) -> None:
793
+ self.add_uint32(Keys.LLM.DENSE_FEAT_IN_SIZE.format(arch=self.arch, dense=dense), in_f)
794
+ self.add_uint32(Keys.LLM.DENSE_FEAT_OUT_SIZE.format(arch=self.arch, dense=dense), out_f)
795
+
796
+ def add_logit_scale(self, value: float) -> None:
797
+ self.add_float41(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value)
798
+
799
+ def add_attn_logit_softcapping(self, value: float) -> None:
800
+ self.add_float41(Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
801
+
802
+ def add_router_logit_softcapping(self, value: float) -> None:
803
+ self.add_float41(Keys.LLM.ROUTER_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
804
+
805
+ def add_final_logit_softcapping(self, value: float) -> None:
806
+ self.add_float41(Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
807
+
808
+ def add_expert_count(self, count: int) -> None:
809
+ self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count)
810
+
811
+ def add_expert_used_count(self, count: int) -> None:
812
+ self.add_uint32(Keys.LLM.EXPERT_USED_COUNT.format(arch=self.arch), count)
813
+
814
+ def add_expert_shared_count(self, count: int) -> None:
815
+ self.add_uint32(Keys.LLM.EXPERT_SHARED_COUNT.format(arch=self.arch), count)
816
+
817
+ def add_expert_group_count(self, count: int) -> None:
818
+ self.add_uint32(Keys.LLM.EXPERT_GROUP_COUNT.format(arch=self.arch), count)
819
+
820
+ def add_expert_group_used_count(self, count: int) -> None:
821
+ self.add_uint32(Keys.LLM.EXPERT_GROUP_USED_COUNT.format(arch=self.arch), count)
822
+
823
+ def add_expert_weights_scale(self, value: float) -> None:
824
+ self.add_float41(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)
825
+
826
+ def add_expert_weights_norm(self, value: bool) -> None:
827
+ self.add_bool(Keys.LLM.EXPERT_WEIGHTS_NORM.format(arch=self.arch), value)
828
+
829
+ def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None:
830
+ self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value)
831
+
832
+ def add_expert_group_scale(self, value: float) -> None:
833
+ self.add_float41(Keys.LLM.EXPERT_GROUP_SCALE.format(arch=self.arch), value)
834
+
835
+ def add_experts_per_group(self, count: int) -> None:
836
+ self.add_uint32(Keys.LLM.EXPERTS_PER_GROUP.format(arch=self.arch), count)
837
+
838
+ def add_moe_every_n_layers(self, value: int) -> None:
839
+ self.add_uint32(Keys.LLM.MOE_EVERY_N_LAYERS.format(arch=self.arch), value)
840
+
841
+ def add_nextn_predict_layers(self, count: int) -> None:
842
+ self.add_uint32(Keys.LLM.NEXTN_PREDICT_LAYERS.format(arch=self.arch), count)
843
+
844
+ def add_swin_norm(self, value: bool) -> None:
845
+ self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value)
846
+
847
+ def add_rescale_every_n_layers(self, count: int) -> None:
848
+ self.add_uint32(Keys.LLM.RESCALE_EVERY_N_LAYERS.format(arch=self.arch), count)
849
+
850
+ def add_time_mix_extra_dim(self, dim: int) -> None:
851
+ self.add_uint32(Keys.LLM.TIME_MIX_EXTRA_DIM.format(arch=self.arch), dim)
852
+
853
+ def add_time_decay_extra_dim(self, dim: int) -> None:
854
+ self.add_uint32(Keys.LLM.TIME_DECAY_EXTRA_DIM.format(arch=self.arch), dim)
855
+
856
+ def add_residual_scale(self, value: float) -> None:
857
+ self.add_float41(Keys.LLM.RESIDUAL_SCALE.format(arch=self.arch), value)
858
+
859
+ def add_embedding_scale(self, value: float) -> None:
860
+ self.add_float41(Keys.LLM.EMBEDDING_SCALE.format(arch=self.arch), value)
861
+
862
+ def add_wkv_head_size(self, size: int) -> None:
863
+ self.add_uint32(Keys.WKV.HEAD_SIZE.format(arch=self.arch), size)
864
+
865
+ def add_token_shift_count(self, count: int) -> None:
866
+ self.add_uint32(Keys.LLM.TOKEN_SHIFT_COUNT.format(arch=self.arch), count)
867
+
868
+ def add_interleave_moe_layer_step(self, value: int) -> None:
869
+ self.add_uint32(Keys.LLM.INTERLEAVE_MOE_LAYER_STEP.format(arch=self.arch), value)
870
+
871
+ def add_layer_norm_eps(self, value: float) -> None:
872
+ self.add_float41(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)
873
+
874
+ def add_layer_norm_rms_eps(self, value: float) -> None:
875
+ self.add_float41(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value)
876
+
877
+ def add_group_norm_eps(self, value: float) -> None:
878
+ self.add_float41(Keys.Attention.GROUPNORM_EPS.format(arch=self.arch), value)
879
+
880
+ def add_group_norm_groups(self, value: int) -> None:
881
+ self.add_uint32(Keys.Attention.GROUPNORM_GROUPS.format(arch=self.arch), value)
882
+
883
+ def add_causal_attention(self, value: bool) -> None:
884
+ self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)
885
+
886
+ def add_q_lora_rank(self, length: int) -> None:
887
+ self.add_uint32(Keys.Attention.Q_LORA_RANK.format(arch=self.arch), length)
888
+
889
+ def add_kv_lora_rank(self, length: int) -> None:
890
+ self.add_uint32(Keys.Attention.KV_LORA_RANK.format(arch=self.arch), length)
891
+
892
+ def add_decay_lora_rank(self, length: int) -> None:
893
+ self.add_uint32(Keys.Attention.DECAY_LORA_RANK.format(arch=self.arch), length)
894
+
895
+ def add_iclr_lora_rank(self, length: int) -> None:
896
+ self.add_uint32(Keys.Attention.ICLR_LORA_RANK.format(arch=self.arch), length)
897
+
898
+ def add_value_residual_mix_lora_rank(self, length: int) -> None:
899
+ self.add_uint32(Keys.Attention.VALUE_RESIDUAL_MIX_LORA_RANK.format(arch=self.arch), length)
900
+
901
+ def add_rope_freq_base_swa(self, value: float) -> None:
902
+ self.add_float41(Keys.Rope.FREQ_BASE_SWA.format(arch=self.arch), value)
903
+
904
+ def add_gate_lora_rank(self, length: int) -> None:
905
+ self.add_uint32(Keys.Attention.GATE_LORA_RANK.format(arch=self.arch), length)
906
+
907
+ def add_relative_attn_buckets_count(self, value: int) -> None:
908
+ self.add_uint32(Keys.Attention.REL_BUCKETS_COUNT.format(arch=self.arch), value)
909
+
910
+ def add_sliding_window(self, value: int) -> None:
911
+ self.add_uint32(Keys.Attention.SLIDING_WINDOW.format(arch=self.arch), value)
912
+
913
+ def add_attention_scale(self, value: float) -> None:
914
+ self.add_float41(Keys.Attention.SCALE.format(arch=self.arch), value)
915
+
916
+ def add_attn_output_scale(self, value: float) -> None:
917
+ self.add_float41(Keys.Attention.OUTPUT_SCALE.format(arch=self.arch), value)
918
+
919
+ def add_attn_temperature_length(self, value: int) -> None:
920
+ self.add_uint32(Keys.Attention.TEMPERATURE_LENGTH.format(arch=self.arch), value)
921
+
922
+ def add_attn_temperature_scale(self, value: float) -> None:
923
+ self.add_float41(Keys.Attention.TEMPERATURE_SCALE.format(arch=self.arch), value)
924
+
925
+ def add_pooling_type(self, value: PoolingType) -> None:
926
+ self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
927
+
928
+ def add_num_deepstack_layers(self, count: int) -> None:
929
+ self.add_uint32(Keys.LLM.NUM_DEEPSTACK_LAYERS.format(arch=self.arch), count)
930
+
931
+ def add_rope_dimension_count(self, count: int) -> None:
932
+ self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)
933
+
934
+ def add_rope_dimension_sections(self, dims: Sequence[int]) -> None:
935
+ self.add_array(Keys.Rope.DIMENSION_SECTIONS.format(arch=self.arch), dims)
936
+
937
+ def add_rope_freq_base(self, value: float) -> None:
938
+ self.add_float41(Keys.Rope.FREQ_BASE.format(arch=self.arch), value)
939
+
940
+ def add_rope_scaling_type(self, value: RopeScalingType) -> None:
941
+ self.add_string(Keys.Rope.SCALING_TYPE.format(arch=self.arch), value.value)
942
+
943
+ def add_rope_scaling_factor(self, value: float) -> None:
944
+ self.add_float41(Keys.Rope.SCALING_FACTOR.format(arch=self.arch), value)
945
+
946
+ def add_rope_scaling_attn_factors(self, value: float) -> None:
947
+ self.add_float41(Keys.Rope.SCALING_ATTN_FACTOR.format(arch=self.arch), value)
948
+
949
+ def add_rope_scaling_orig_ctx_len(self, value: int) -> None:
950
+ self.add_uint32(Keys.Rope.SCALING_ORIG_CTX_LEN.format(arch=self.arch), value)
951
+
952
+ def add_rope_scaling_finetuned(self, value: bool) -> None:
953
+ self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value)
954
+
955
+ def add_rope_scaling_yarn_log_mul(self, value: float) -> None:
956
+ self.add_float41(Keys.Rope.SCALING_YARN_LOG_MUL.format(arch=self.arch), value)
957
+
958
+ def add_rope_scaling_yarn_ext_factor(self, value: float) -> None:
959
+ self.add_float41(Keys.Rope.SCALING_YARN_EXT_FACTOR.format(arch=self.arch), value)
960
+
961
+ def add_rope_scaling_yarn_attn_factor(self, value: float) -> None:
962
+ self.add_float41(Keys.Rope.SCALING_YARN_ATTN_FACTOR.format(arch=self.arch), value)
963
+
964
+ def add_rope_scaling_yarn_beta_fast(self, value: float) -> None:
965
+ self.add_float41(Keys.Rope.SCALING_YARN_BETA_FAST.format(arch=self.arch), value)
966
+
967
+ def add_rope_scaling_yarn_beta_slow(self, value: float) -> None:
968
+ self.add_float41(Keys.Rope.SCALING_YARN_BETA_SLOW.format(arch=self.arch), value)
969
+
970
+ def add_ssm_conv_kernel(self, value: int) -> None:
971
+ self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value)
972
+
973
+ def add_ssm_inner_size(self, value: int) -> None:
974
+ self.add_uint32(Keys.SSM.INNER_SIZE.format(arch=self.arch), value)
975
+
976
+ def add_ssm_state_size(self, value: int) -> None:
977
+ self.add_uint32(Keys.SSM.STATE_SIZE.format(arch=self.arch), value)
978
+
979
+ def add_ssm_time_step_rank(self, value: int) -> None:
980
+ self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value)
981
+
982
+ def add_ssm_group_count(self, value: int) -> None:
983
+ self.add_uint32(Keys.SSM.GROUP_COUNT.format(arch=self.arch), value)
984
+
985
+ def add_ssm_dt_b_c_rms(self, value: bool) -> None:
986
+ self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)
987
+
988
+ def add_tokenizer_model(self, model: str) -> None:
989
+ self.add_string(Keys.Tokenizer.MODEL, model)
990
+
991
+ def add_tokenizer_pre(self, pre: str) -> None:
992
+ self.add_string(Keys.Tokenizer.PRE, pre)
993
+
994
+ def add_token_list(self, tokens: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None:
995
+ self.add_array(Keys.Tokenizer.LIST, tokens)
996
+
997
+ def add_token_merges(self, merges: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None:
998
+ self.add_array(Keys.Tokenizer.MERGES, merges)
999
+
1000
+ def add_token_types(self, types: Sequence[TokenType] | Sequence[int]) -> None:
1001
+ self.add_array(Keys.Tokenizer.TOKEN_TYPE, types)
1002
+
1003
+ def add_token_type_count(self, value: int) -> None:
1004
+ self.add_uint32(Keys.Tokenizer.TOKEN_TYPE_COUNT, value)
1005
+
1006
+ def add_token_scores(self, scores: Sequence[float]) -> None:
1007
+ self.add_array(Keys.Tokenizer.SCORES, scores)
1008
+
1009
+ def add_bos_token_id(self, id: int) -> None:
1010
+ self.add_uint32(Keys.Tokenizer.BOS_ID, id)
1011
+
1012
+ def add_eos_token_id(self, id: int) -> None:
1013
+ self.add_uint32(Keys.Tokenizer.EOS_ID, id)
1014
+
1015
+ def add_unk_token_id(self, id: int) -> None:
1016
+ self.add_uint32(Keys.Tokenizer.UNK_ID, id)
1017
+
1018
+ def add_sep_token_id(self, id: int) -> None:
1019
+ self.add_uint32(Keys.Tokenizer.SEP_ID, id)
1020
+
1021
+ def add_pad_token_id(self, id: int) -> None:
1022
+ self.add_uint32(Keys.Tokenizer.PAD_ID, id)
1023
+
1024
+ def add_mask_token_id(self, id: int) -> None:
1025
+ self.add_uint32(Keys.Tokenizer.MASK_ID, id)
1026
+
1027
+ def add_add_bos_token(self, value: bool) -> None:
1028
+ self.add_bool(Keys.Tokenizer.ADD_BOS, value)
1029
+
1030
+ def add_add_eos_token(self, value: bool) -> None:
1031
+ self.add_bool(Keys.Tokenizer.ADD_EOS, value)
1032
+
1033
+ def add_add_sep_token(self, value: bool) -> None:
1034
+ self.add_bool(Keys.Tokenizer.ADD_SEP, value)
1035
+
1036
+ def add_add_space_prefix(self, value: bool) -> None:
1037
+ self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)
1038
+
1039
+ def add_remove_extra_whitespaces(self, value: bool) -> None:
1040
+ self.add_bool(Keys.Tokenizer.REMOVE_EXTRA_WS, value)
1041
+
1042
    def add_precompiled_charsmap(self, charsmap: bytes) -> None:
        """Store ``charsmap`` under ``Keys.Tokenizer.PRECOMPILED_CHARSMAP``."""
        # The bytes object is handed to add_array as-is; how add_array encodes it is
        # not visible from here.
        self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap)
1044
+
1045
+ def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
1046
+ if not isinstance(value, str):
1047
+ template_default = None
1048
+ template_names = set()
1049
+
1050
+ for choice in value:
1051
+ name = choice.get('name', '')
1052
+ template = choice.get('template')
1053
+
1054
+ # Allowing non-alphanumerical characters in template name is probably not a good idea, so filter it
1055
+ name = ''.join((c if c in ascii_letters + digits else '_' for c in name))
1056
+
1057
+ if name and template is not None:
1058
+ if name == 'default':
1059
+ template_default = template
1060
+ else:
1061
+ template_names.add(name)
1062
+ self.add_string(Keys.Tokenizer.CHAT_TEMPLATE_N.format(name=name), template)
1063
+
1064
+ if template_names:
1065
+ self.add_array(Keys.Tokenizer.CHAT_TEMPLATES, list(template_names))
1066
+
1067
+ if template_default is None:
1068
+ return
1069
+
1070
+ value = template_default
1071
+
1072
+ self.add_string(Keys.Tokenizer.CHAT_TEMPLATE, value)
1073
+
1074
+ def add_eot_token_id(self, id: int) -> None:
1075
+ self.add_uint32(Keys.Tokenizer.EOT_ID, id)
1076
+
1077
+ def add_eom_token_id(self, id: int) -> None:
1078
+ self.add_uint32(Keys.Tokenizer.EOM_ID, id)
1079
+
1080
+ def add_classifier_output_labels(self, labels: Sequence[str]) -> None:
1081
+ self.add_array(Keys.Classifier.OUTPUT_LABELS.format(arch=self.arch), labels)
1082
+
1083
+ # for vision models
1084
+
1085
+ def add_clip_has_vision_encoder(self, value: bool) -> None:
1086
+ self.add_bool(Keys.Clip.HAS_VISION_ENCODER, value)
1087
+
1088
+ def add_clip_has_audio_encoder(self, value: bool) -> None:
1089
+ self.add_bool(Keys.Clip.HAS_AUDIO_ENCODER, value)
1090
+
1091
+ def add_clip_projector_type(self, value: str) -> None:
1092
+ self.add_string(Keys.Clip.PROJECTOR_TYPE, value)
1093
+
1094
+ def add_clip_vision_projector_type(self, value: str) -> None:
1095
+ self.add_string(Keys.ClipVision.PROJECTOR_TYPE, value)
1096
+
1097
+ def add_vision_projection_dim(self, value: int) -> None:
1098
+ self.add_uint32(Keys.ClipVision.PROJECTION_DIM, value)
1099
+
1100
+ def add_vision_patch_size(self, value: int) -> None:
1101
+ self.add_uint32(Keys.ClipVision.PATCH_SIZE, value)
1102
+
1103
+ def add_vision_embedding_length(self, value: int) -> None:
1104
+ self.add_uint32(Keys.ClipVision.EMBEDDING_LENGTH, value)
1105
+
1106
def add_vision_feed_forward_length(self, value: int) -> None:
    """Store ``value`` under ``Keys.ClipVision.FEED_FORWARD_LENGTH`` via ``self.add_uint32``."""
    self.add_uint32(Keys.ClipVision.FEED_FORWARD_LENGTH, value)
1108
+
1109
def add_vision_block_count(self, value: int) -> None:
    """Store ``value`` under ``Keys.ClipVision.BLOCK_COUNT`` via ``self.add_uint32``."""
    self.add_uint32(Keys.ClipVision.BLOCK_COUNT, value)
1111
+
1112
def add_vision_head_count(self, value: int) -> None:
    """Store ``value`` under ``Keys.ClipVision.Attention.HEAD_COUNT`` via ``self.add_uint32``."""
    self.add_uint32(Keys.ClipVision.Attention.HEAD_COUNT, value)
1114
+
1115
def add_vision_attention_layernorm_eps(self, value: float) -> None:
    """Store ``value`` under ``Keys.ClipVision.Attention.LAYERNORM_EPS``."""
    # NOTE(review): this calls ``add_float41``; the sibling setters in this
    # section use ``add_uint32``/``add_bool``/``add_string``/``add_array``, so
    # this looks like it may have been meant to be ``add_float32``. Confirm
    # that ``add_float41`` is actually defined on this class.
    self.add_float41(Keys.ClipVision.Attention.LAYERNORM_EPS, value)
1117
+
1118
def add_vision_image_size(self, value: int) -> None:
    """Store ``value`` under ``Keys.ClipVision.IMAGE_SIZE`` via ``self.add_uint32``."""
    self.add_uint32(Keys.ClipVision.IMAGE_SIZE, value)
1120
+
1121
def add_vision_preproc_image_size(self, value: int) -> None:
    """Store ``value`` under ``Keys.ClipVision.PREPROC_IMAGE_SIZE`` via ``self.add_uint32``."""
    self.add_uint32(Keys.ClipVision.PREPROC_IMAGE_SIZE, value)
1123
+
1124
def add_vision_image_mean(self, values: Sequence[float]) -> None:
    """Store ``values`` as an array under ``Keys.ClipVision.IMAGE_MEAN``."""
    self.add_array(Keys.ClipVision.IMAGE_MEAN, values)
1126
+
1127
def add_vision_image_std(self, values: Sequence[float]) -> None:
    """Store ``values`` as an array under ``Keys.ClipVision.IMAGE_STD``."""
    self.add_array(Keys.ClipVision.IMAGE_STD, values)
1129
+
1130
def add_vision_spatial_merge_size(self, value: int) -> None:
    """Store ``value`` under ``Keys.ClipVision.SPATIAL_MERGE_SIZE`` via ``self.add_uint32``."""
    self.add_uint32(Keys.ClipVision.SPATIAL_MERGE_SIZE, value)
1132
+
1133
def add_vision_use_gelu(self, value: bool) -> None:
    """Store ``value`` under ``Keys.ClipVision.USE_GELU`` via ``self.add_bool``."""
    self.add_bool(Keys.ClipVision.USE_GELU, value)
1135
+
1136
def add_vision_use_silu(self, value: bool) -> None:
    """Store ``value`` under ``Keys.ClipVision.USE_SILU`` via ``self.add_bool``."""
    self.add_bool(Keys.ClipVision.USE_SILU, value)
1138
+
1139
def add_vision_projector_scale_factor(self, value: int) -> None:
    """Store ``value`` under ``Keys.ClipVision.Projector.SCALE_FACTOR`` via ``self.add_uint32``."""
    self.add_uint32(Keys.ClipVision.Projector.SCALE_FACTOR, value)
1141
+
1142
def add_vision_n_wa_pattern(self, value: int) -> None:
    """Add window attention pattern interval for vision models.

    This defines the pattern interval for window attention vs full attention layers.
    For example, if n_wa_pattern=4, then layers 3, 7, 11, ... use full attention,
    while other layers use window attention.

    Used by models like Qwen2.5-VL where full attention layers follow a regular pattern.

    Args:
        value: The pattern interval, written as a uint32 under
            ``Keys.ClipVision.N_WA_PATTERN``.
    """
    self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value)
1152
+
1153
def add_vision_wa_layer_indexes(self, layers: Sequence[int]) -> None:
    """Add explicit layer indexes that use full attention in vision models.

    This specifies the exact layer indices (0-based) that should use full attention
    instead of window attention. All other layers will use window attention.

    Args:
        layers: List of layer indices that use full attention (e.g., [3, 7, 11, 15]).
            Written as an array under ``Keys.ClipVision.WA_LAYER_INDEXES``.

    Used by models like YoutuVL where full attention layers are explicitly specified
    rather than following a regular pattern.

    Difference from add_vision_n_wa_pattern:
    - n_wa_pattern: Defines a regular interval pattern (every Nth layer uses full attention)
    - wa_layer_indexes: Explicitly lists which layers use full attention (irregular pattern)
    """
    self.add_array(Keys.ClipVision.WA_LAYER_INDEXES, layers)
1170
+
1171
def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None:
    """Store ``layers`` as an array under ``Keys.ClipVision.IS_DEEPSTACK_LAYERS``."""
    self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers)
1173
+
1174
def add_vision_window_size(self, value: int) -> None:
    """Store ``value`` under ``Keys.ClipVision.WINDOW_SIZE`` via ``self.add_uint32``."""
    self.add_uint32(Keys.ClipVision.WINDOW_SIZE, value)
1176
+
1177
+ # audio models
1178
+
1179
def add_clip_audio_projector_type(self, value: str) -> None:
    """Store ``value`` under ``Keys.ClipAudio.PROJECTOR_TYPE`` via ``self.add_string``."""
    self.add_string(Keys.ClipAudio.PROJECTOR_TYPE, value)
1181
+
1182
def add_audio_projection_dim(self, value: int) -> None:
    """Store ``value`` under ``Keys.ClipAudio.PROJECTION_DIM`` via ``self.add_uint32``."""
    self.add_uint32(Keys.ClipAudio.PROJECTION_DIM, value)
1184
+
1185
def add_audio_embedding_length(self, value: int) -> None:
    """Store ``value`` under ``Keys.ClipAudio.EMBEDDING_LENGTH`` via ``self.add_uint32``."""
    self.add_uint32(Keys.ClipAudio.EMBEDDING_LENGTH, value)
1187
+
1188
def add_audio_feed_forward_length(self, value: int) -> None:
    """Store ``value`` under ``Keys.ClipAudio.FEED_FORWARD_LENGTH`` via ``self.add_uint32``."""
    self.add_uint32(Keys.ClipAudio.FEED_FORWARD_LENGTH, value)
1190
+
1191
def add_audio_block_count(self, value: int) -> None:
    """Store ``value`` under ``Keys.ClipAudio.BLOCK_COUNT`` via ``self.add_uint32``."""
    self.add_uint32(Keys.ClipAudio.BLOCK_COUNT, value)
1193
+
1194
def add_audio_head_count(self, value: int) -> None:
    """Store ``value`` under ``Keys.ClipAudio.Attention.HEAD_COUNT`` via ``self.add_uint32``."""
    self.add_uint32(Keys.ClipAudio.Attention.HEAD_COUNT, value)
1196
+
1197
def add_audio_attention_layernorm_eps(self, value: float) -> None:
    """Store ``value`` under ``Keys.ClipAudio.Attention.LAYERNORM_EPS``."""
    # NOTE(review): this calls ``add_float41``; the sibling setters in this
    # section use ``add_uint32``/``add_bool``/``add_string``/``add_array``, so
    # this looks like it may have been meant to be ``add_float32``. Confirm
    # that ``add_float41`` is actually defined on this class.
    self.add_float41(Keys.ClipAudio.Attention.LAYERNORM_EPS, value)
1199
+
1200
def add_audio_num_mel_bins(self, value: int) -> None:
    """Store ``value`` under ``Keys.ClipAudio.NUM_MEL_BINS`` via ``self.add_uint32``."""
    self.add_uint32(Keys.ClipAudio.NUM_MEL_BINS, value)
1202
+
1203
def add_audio_stack_factor(self, value: int) -> None:
    """Store ``value`` under ``Keys.ClipAudio.Projector.STACK_FACTOR`` via ``self.add_uint32``."""
    self.add_uint32(Keys.ClipAudio.Projector.STACK_FACTOR, value)
1205
+
1206
def add_xielu_alpha_p(self, values: Sequence[float]) -> None:
    """Store ``values`` as an array under ``Keys.xIELU.ALPHA_P``."""
    self.add_array(Keys.xIELU.ALPHA_P, values)
1208
+
1209
def add_xielu_alpha_n(self, values: Sequence[float]) -> None:
    """Store ``values`` as an array under ``Keys.xIELU.ALPHA_N``."""
    self.add_array(Keys.xIELU.ALPHA_N, values)
1211
+
1212
def add_xielu_beta(self, values: Sequence[float]) -> None:
    """Store ``values`` as an array under ``Keys.xIELU.BETA``."""
    self.add_array(Keys.xIELU.BETA, values)
1214
+
1215
def add_xielu_eps(self, values: Sequence[float]) -> None:
    """Store ``values`` as an array under ``Keys.xIELU.EPS``."""
    self.add_array(Keys.xIELU.EPS, values)
1217
+
1218
+ # diffusion models
1219
+
1220
def add_diffusion_shift_logits(self, value: bool) -> None:
    """Store ``value`` under ``Keys.Diffusion.SHIFT_LOGITS`` via ``self.add_bool``."""
    self.add_bool(Keys.Diffusion.SHIFT_LOGITS, value)
1222
+
1223
+ def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
1224
+ pack_prefix = ''
1225
+ if not skip_pack_prefix:
1226
+ pack_prefix = '<' if self.endianess == GGUFEndian.LITTLE else '>'
1227
+ return struct.pack(f'{pack_prefix}{fmt}', value)
1228
+
1229
def _pack_val(self, val: Any, vtype: GGUFValueType, add_vtype: bool, sub_type: GGUFValueType | None = None) -> bytes:
    """Serialize one metadata value into its packed byte form.

    Args:
        val: The value to serialize.
        vtype: The value type that selects the encoding.
        add_vtype: When true, the packed ``vtype`` (format ``"I"``) is
            emitted before the value itself.
        sub_type: Element type to use for arrays; when ``None`` it is
            ``UINT8`` for ``bytes`` and otherwise derived from the first
            element.

    Returns:
        The accumulated buffer (a ``bytearray``).

    Raises:
        ValueError: If an array value is not a sequence, is empty, or has
            mixed element types, or if ``vtype`` is not handled.
    """
    out = bytearray()

    if add_vtype:
        out += self._pack("I", vtype)

    # Types with an entry in the simple packing table map straight to a struct format.
    simple_fmt = self._simple_value_packing.get(vtype)
    if simple_fmt is not None:
        out += self._pack(simple_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL)
        return out

    if vtype == GGUFValueType.STRING:
        # Strings are written as a "Q"-packed byte length followed by the bytes.
        raw = val.encode("utf-8") if isinstance(val, str) else val
        out += self._pack("Q", len(raw))
        out += raw
        return out

    if vtype != GGUFValueType.ARRAY:
        raise ValueError("Invalid GGUF metadata value type or value")

    if not isinstance(val, Sequence):
        raise ValueError("Invalid GGUF metadata array, expecting sequence")

    if len(val) == 0:
        raise ValueError("Invalid GGUF metadata array. Empty array")

    if sub_type is not None:
        elem_type = sub_type
    elif isinstance(val, bytes):
        elem_type = GGUFValueType.UINT8
    else:
        elem_type = GGUFValueType.get_type(val[0])
        if any(GGUFValueType.get_type(item) is not elem_type for item in val[1:]):
            raise ValueError("All items in a GGUF array should be of the same type")

    # Array header: element type, then element count, then each element without a type tag.
    out += self._pack("I", elem_type)
    out += self._pack("Q", len(val))
    for item in val:
        out += self._pack_val(item, elem_type, add_vtype=False)

    return out
1266
+
1267
+ @staticmethod
1268
+ def format_n_bytes_to_str(num: int) -> str:
1269
+ if num == 0:
1270
+ return "negligible - metadata only"
1271
+ fnum = float(num)
1272
+ for unit in ("", "K", "M", "G"):
1273
+ if abs(fnum) < 1000.0:
1274
+ return f"{fnum:3.1f}{unit}"
1275
+ fnum /= 1000.0
1276
+ return f"{fnum:.1f}T - over 1TB, split recommended"
quants.py ADDED
@@ -0,0 +1,1443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #BRUNO BECKER / OFFELLIA 2026
2
+ #brunoconta1980@gmail.com
3
+ #brunoconta1980@hotmail.com
4
+ # X @Brunoxuser
5
+ #3301
6
+
7
+ #* ᛒ (B), ᛖ (E), ᚳ (C), ᚳ (K), ᛖ (E), ᚱ (R).
8
+ #* 17, 18, 5, 5, 18, 4.
9
+ #* ᚩ (O), ᚠ (F), ᚠ (F), ᛖ (E), ᛚ (L), ᛚ (L), ᛁ (I), ᚪ (A).
10
+ #* 3, 0, 0, 18, 20, 20, 10, 24.
11
+
12
+ from __future__ import annotations
13
+ from abc import ABC, abstractmethod
14
+ from typing import Any, Callable, Sequence
15
+ from math import log2, ceil
16
+
17
+ from numpy.typing import DTypeLike
18
+
19
+ from .constants import GGML_QUANT_SIZES, GGMLQuantizationType, QK_K
20
+ from .lazy import LazyNumpyTensor
21
+
22
+ import numpy as np
23
+
24
+
25
def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]:
    """Convert an element-count shape into the byte shape of its quantized form.

    Only the last dimension changes: it is divided into blocks of the
    quantization type's block size, each occupying ``type_size`` bytes.

    Raises:
        ValueError: If the last dimension is not a multiple of the block size.
    """
    block_size, type_size = GGML_QUANT_SIZES[quant_type]
    blocks_per_row, leftover = divmod(shape[-1], block_size)
    if leftover != 0:
        raise ValueError(f"Quantized tensor row size ({shape[-1]}) is not a multiple of {quant_type.name} block size ({block_size})")
    row_bytes = blocks_per_row * type_size
    return (*shape[:-1], row_bytes)
30
+
31
+
32
def quant_shape_from_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]:
    """Convert the byte shape of a quantized tensor back into its element shape.

    Only the last dimension changes: every ``type_size`` bytes correspond to
    one block of ``block_size`` elements.

    Raises:
        ValueError: If the last dimension is not a multiple of the type size.
    """
    block_size, type_size = GGML_QUANT_SIZES[quant_type]
    blocks_per_row, leftover = divmod(shape[-1], type_size)
    if leftover != 0:
        raise ValueError(f"Quantized tensor bytes per row ({shape[-1]}) is not a multiple of {quant_type.name} type size ({type_size})")
    row_elems = blocks_per_row * block_size
    return (*shape[:-1], row_elems)
37
+
38
+
39
+ # This is faster than np.vectorize and np.apply_along_axis because it works on more than one row at a time
40
+ def _apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.ndarray, otype: DTypeLike, oshape: tuple[int, ...]) -> np.ndarray:
41
+ rows = arr.reshape((-1, arr.shape[-1]))
42
+ osize = 1
43
+ for dim in oshape:
44
+ osize *= dim
45
+ out = np.empty(shape=osize, dtype=otype)
46
+ # compute over groups of 16 rows (arbitrary, but seems good for performance)
47
+ n_groups = (rows.shape[0] // 16) or 1
48
+ np.concatenate([func(group).ravel() for group in np.array_split(rows, n_groups)], axis=0, out=out)
49
+ return out.reshape(oshape)
50
+
51
+
52
# round away from zero
# ref: https://stackoverflow.com/a/59143326/22827863
def np_roundf(n: np.ndarray) -> np.ndarray:
    """Round to the nearest integer, with halves going away from zero."""
    magnitude = abs(n)
    whole = np.floor(magnitude)
    # The doubled fractional part floors to 1 exactly when the fraction is >= 0.5.
    bump = np.floor(2 * (magnitude - whole))
    rounded_magnitude = whole + bump
    return np.sign(n) * rounded_magnitude
59
+
60
+
61
# Raised by the quantization classes' ``quantize`` when a tensor's shape cannot be quantized to the target type.
class QuantError(Exception): ...
62
+
63
+
64
# Registry mapping each quantization type to the class implementing it;
# populated by __Quant.__init_subclass__ and consulted by quantize()/dequantize().
_type_traits: dict[GGMLQuantizationType, type[__Quant]] = {}
65
+
66
+
67
def quantize(data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
    """Convert ``data`` to the representation of ``qtype``.

    F32 and F16 are plain dtype casts; any other type is delegated to the
    class registered for it in ``_type_traits``.

    Raises:
        NotImplementedError: If no implementation is registered for ``qtype``.
    """
    if qtype == GGMLQuantizationType.F32:
        return data.astype(np.float32, copy=False)
    if qtype == GGMLQuantizationType.F16:
        return data.astype(np.float16, copy=False)

    impl = _type_traits.get(qtype)
    if impl is None:
        raise NotImplementedError(f"Quantization for {qtype.name} is not yet implemented")
    return impl.quantize(data)
76
+
77
+
78
def dequantize(data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
    """Convert data stored as ``qtype`` back to float32.

    F32 data is reinterpreted in place, F16 data is reinterpreted and cast;
    any other type is delegated to the class registered in ``_type_traits``.

    Raises:
        NotImplementedError: If no implementation is registered for ``qtype``.
    """
    if qtype == GGMLQuantizationType.F32:
        return data.view(np.float32)
    if qtype == GGMLQuantizationType.F16:
        return data.view(np.float16).astype(np.float32)

    impl = _type_traits.get(qtype)
    if impl is None:
        raise NotImplementedError(f"Dequantization for {qtype.name} is not yet implemented")
    return impl.dequantize(data)
87
+
88
+
89
class __Quant(ABC):
    """Base class for per-type quantization/dequantization implementations.

    Subclasses are declared with a ``qtype=...`` class keyword; doing so
    records the type's block and byte sizes and registers the subclass in
    ``_type_traits``. The classes are used purely through classmethods and
    are not meant to be instantiated.
    """

    qtype: GGMLQuantizationType
    block_size: int
    type_size: int

    # Optional lookup grid, decoded lazily from ``grid_hex`` by init_grid().
    grid: np.ndarray[Any, np.dtype[np.float32]] | None = None
    grid_shape: tuple[int, int] = (0, 0)
    grid_map: tuple[int | float, ...] = ()
    grid_hex: bytes | None = None

    def __init__(self):
        # Raise (not return) the error: returning a value from __init__ only
        # produces a generic "__init__() should return None" TypeError and
        # hides the intended message.
        raise TypeError("Quant conversion classes can't have instances")

    def __init_subclass__(cls, qtype: GGMLQuantizationType) -> None:
        """Record sizes for ``qtype``, build lazy wrappers and register ``cls``."""
        cls.qtype = qtype
        cls.block_size, cls.type_size = GGML_QUANT_SIZES[qtype]
        # Per-subclass lazy variants replace the placeholder classmethods below.
        cls.__quantize_lazy = LazyNumpyTensor._wrap_fn(
            cls.__quantize_array,
            meta_noop=(np.uint8, cls.__shape_to_bytes)
        )
        cls.__dequantize_lazy = LazyNumpyTensor._wrap_fn(
            cls.__dequantize_array,
            meta_noop=(np.float32, cls.__shape_from_bytes)
        )
        assert qtype not in _type_traits
        _type_traits[qtype] = cls

    @classmethod
    def init_grid(cls):
        """Decode ``grid_hex`` into ``cls.grid`` once; no-op if already done or absent."""
        if cls.grid is not None or cls.grid_hex is None:
            return

        bits_per_elem = ceil(log2(len(cls.grid_map)))
        assert bits_per_elem != 0, cls.qtype.name
        elems_per_byte = 8 // bits_per_elem

        grid = np.frombuffer(cls.grid_hex, dtype=np.uint8)
        # decode hexadecimal chars from grid
        grid = grid.reshape((-1, 2))
        grid = (np.where(grid > 0x40, grid + 9, grid) & 0x0F) << np.array([4, 0], dtype=np.uint8).reshape((1, 2))
        grid = grid[..., 0] | grid[..., 1]
        # unpack the grid values
        grid = grid.reshape((-1, 1)) >> np.array([i for i in range(0, 8, 8 // elems_per_byte)], dtype=np.uint8).reshape((1, elems_per_byte))
        grid = (grid & ((1 << bits_per_elem) - 1)).reshape((-1, 1))
        # Map each unpacked index to its value in grid_map.
        grid_map = np.array(cls.grid_map, dtype=np.float32).reshape((1, -1))
        grid = np.take_along_axis(grid_map, grid, axis=-1)
        cls.grid = grid.reshape((1, 1, *cls.grid_shape))

    @classmethod
    @abstractmethod
    def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        """Quantize ``(n_blocks, block_size)`` float32 data to ``(n_blocks, type_size)`` uint8."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        """Dequantize ``(n_blocks, type_size)`` uint8 data to ``(n_blocks, block_size)`` float32."""
        raise NotImplementedError

    @classmethod
    def quantize_rows(cls, rows: np.ndarray) -> np.ndarray:
        """Quantize rows by splitting them into blocks and calling ``quantize_blocks``."""
        rows = rows.astype(np.float32, copy=False)
        shape = rows.shape
        n_blocks = rows.size // cls.block_size
        blocks = rows.reshape((n_blocks, cls.block_size))
        blocks = cls.quantize_blocks(blocks)
        assert blocks.dtype == np.uint8
        assert blocks.shape[-1] == cls.type_size
        return blocks.reshape(cls.__shape_to_bytes(shape))

    @classmethod
    def dequantize_rows(cls, rows: np.ndarray) -> np.ndarray:
        """Dequantize rows by splitting them into blocks and calling ``dequantize_blocks``."""
        rows = rows.view(np.uint8)
        shape = rows.shape
        n_blocks = rows.size // cls.type_size
        blocks = rows.reshape((n_blocks, cls.type_size))
        blocks = cls.dequantize_blocks(blocks)
        assert blocks.dtype == np.float32
        assert blocks.shape[-1] == cls.block_size
        return blocks.reshape(cls.__shape_from_bytes(shape))

    @classmethod
    def __shape_to_bytes(cls, shape: Sequence[int]):
        return quant_shape_to_byte_shape(shape, cls.qtype)

    @classmethod
    def __shape_from_bytes(cls, shape: Sequence[int]):
        return quant_shape_from_byte_shape(shape, cls.qtype)

    @classmethod
    def __quantize_array(cls, array: np.ndarray) -> np.ndarray:
        return _apply_over_grouped_rows(cls.quantize_rows, arr=array, otype=np.uint8, oshape=cls.__shape_to_bytes(array.shape))

    @classmethod
    def __dequantize_array(cls, array: np.ndarray) -> np.ndarray:
        # Make sure the lookup grid (if this type has one) is decoded first.
        cls.init_grid()
        return _apply_over_grouped_rows(cls.dequantize_rows, arr=array, otype=np.float32, oshape=cls.__shape_from_bytes(array.shape))

    # Placeholder; each subclass gets a real implementation in __init_subclass__.
    @classmethod
    def __quantize_lazy(cls, lazy_tensor: LazyNumpyTensor, /) -> Any:
        pass

    # Placeholder; each subclass gets a real implementation in __init_subclass__.
    @classmethod
    def __dequantize_lazy(cls, lazy_tensor: LazyNumpyTensor, /) -> Any:
        pass

    @classmethod
    def can_quantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> bool:
        """Return whether the tensor's last dimension is a multiple of the block size."""
        return tensor.shape[-1] % cls.block_size == 0

    @classmethod
    def quantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> np.ndarray:
        """Quantize an eager or lazy tensor.

        Raises:
            QuantError: If ``can_quantize`` rejects the tensor's shape.
        """
        if not cls.can_quantize(tensor):
            raise QuantError(f"Can't quantize tensor with shape {tensor.shape} to {cls.qtype.name}")
        if isinstance(tensor, LazyNumpyTensor):
            return cls.__quantize_lazy(tensor)
        else:
            return cls.__quantize_array(tensor)

    @classmethod
    def dequantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> np.ndarray:
        """Dequantize an eager or lazy tensor to float32."""
        if isinstance(tensor, LazyNumpyTensor):
            return cls.__dequantize_lazy(tensor)
        else:
            return cls.__dequantize_array(tensor)
213
+
214
+
215
class BF16(__Quant, qtype=GGMLQuantizationType.BF16):
    """Conversion between float32 values and their 16-bit BF16 byte form."""

    @classmethod
    # same as ggml_compute_fp32_to_bf16 in ggml-impl.h
    def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        # Work on the raw 32-bit patterns of the float32 input.
        n = blocks.view(np.uint32)
        # force nan to quiet
        n = np.where((n & 0x7fffffff) > 0x7f800000, (n & np.uint32(0xffff0000)) | np.uint32(64 << 16), n)
        # round to nearest even
        n = (np.uint64(n) + (0x7fff + ((n >> 16) & 1))) >> 16
        # Keep the resulting 16 bits and expose them as bytes.
        return n.astype(np.uint16).view(np.uint8)

    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        # Put the stored 16 bits into the upper half of a 32-bit word and reinterpret as float32.
        return (blocks.view(np.int16).astype(np.int32) << 16).view(np.float32)
229
+
230
+
231
class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0):
    """4-bit quantization with one float16 scale per block.

    Block bytes as written/read here: 2 bytes of float16 scale ``d``, then
    ``block_size // 2`` bytes each holding two 4-bit values.
    """

    @classmethod
    def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        n_blocks = blocks.shape[0]

        # Signed value with the largest magnitude in each block.
        imax = abs(blocks).argmax(axis=-1, keepdims=True)
        max = np.take_along_axis(blocks, imax, axis=-1)

        d = max / -8
        with np.errstate(divide="ignore"):
            # Inverse scale; blocks whose scale is zero use 0 instead of inf.
            id = np.where(d == 0, 0, 1 / d)
        # Shift by 8.5 and truncate, then clamp to the 4-bit range.
        qs = np.trunc((blocks * id) + np.float32(8.5), dtype=np.float32).astype(np.uint8).clip(0, 15)

        # First half of the block goes in the low nibbles, second half in the high nibbles.
        qs = qs.reshape((n_blocks, 2, cls.block_size // 2))
        qs = qs[..., 0, :] | (qs[..., 1, :] << np.uint8(4))

        d = d.astype(np.float16).view(np.uint8)

        return np.concatenate([d, qs], axis=-1)

    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        n_blocks = blocks.shape[0]

        # Leading 2 bytes are the float16 scale; the rest are packed nibbles.
        d, qs = np.hsplit(blocks, [2])

        d = d.view(np.float16).astype(np.float32)

        # Extract low nibbles then high nibbles, and remove the +8 offset.
        qs = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
        qs = (qs & np.uint8(0x0F)).reshape((n_blocks, -1)).astype(np.int8) - np.int8(8)

        return (d * qs.astype(np.float32))
263
+
264
+
265
class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1):
    """4-bit quantization with a float16 scale and a float16 minimum per block.

    Block bytes as written/read here: 2 bytes scale ``d``, 2 bytes minimum
    ``m``, then ``block_size // 2`` bytes each holding two 4-bit values.
    """

    @classmethod
    def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        n_blocks = blocks.shape[0]

        max = blocks.max(axis=-1, keepdims=True)
        min = blocks.min(axis=-1, keepdims=True)

        # Spread the block's value range over the 16 representable levels.
        d = (max - min) / 15
        with np.errstate(divide="ignore"):
            # Inverse scale; constant blocks (d == 0) use 0 instead of inf.
            id = np.where(d == 0, 0, 1 / d)
        qs = np.trunc((blocks - min) * id + np.float32(0.5), dtype=np.float32).astype(np.uint8).clip(0, 15)

        # First half of the block goes in the low nibbles, second half in the high nibbles.
        qs = qs.reshape((n_blocks, 2, cls.block_size // 2))
        qs = qs[..., 0, :] | (qs[..., 1, :] << np.uint8(4))

        d = d.astype(np.float16).view(np.uint8)
        m = min.astype(np.float16).view(np.uint8)

        return np.concatenate([d, m, qs], axis=-1)

    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        n_blocks = blocks.shape[0]

        # Split off the scale, then the minimum; the remainder is packed nibbles.
        d, rest = np.hsplit(blocks, [2])
        m, qs = np.hsplit(rest, [2])

        d = d.view(np.float16).astype(np.float32)
        m = m.view(np.float16).astype(np.float32)

        qs = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
        qs = (qs & np.uint8(0x0F)).reshape((n_blocks, -1)).astype(np.float32)

        return (d * qs) + m
300
+
301
+
302
class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
    """5-bit quantization with one float16 scale per block.

    Block bytes as written/read here: 2 bytes scale ``d``, 4 bytes ``qh``
    holding the fifth bit of each value, then ``block_size // 2`` bytes of
    packed low nibbles.
    """

    @classmethod
    def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        n_blocks = blocks.shape[0]

        # Signed value with the largest magnitude in each block.
        imax = abs(blocks).argmax(axis=-1, keepdims=True)
        max = np.take_along_axis(blocks, imax, axis=-1)

        d = max / -16
        with np.errstate(divide="ignore"):
            # Inverse scale; blocks whose scale is zero use 0 instead of inf.
            id = np.where(d == 0, 0, 1 / d)
        q = np.trunc((blocks * id) + np.float32(16.5), dtype=np.float32).astype(np.uint8).clip(0, 31)

        # Low 4 bits: first half in low nibbles, second half in high nibbles.
        qs = q.reshape((n_blocks, 2, cls.block_size // 2))
        qs = (qs[..., 0, :] & np.uint8(0x0F)) | (qs[..., 1, :] << np.uint8(4))

        # Fifth bit of each of the 32 values, packed little-endian into 4 bytes.
        qh = np.packbits(q.reshape((n_blocks, 1, 32)) >> np.uint8(4), axis=-1, bitorder="little").reshape(n_blocks, 4)

        d = d.astype(np.float16).view(np.uint8)

        return np.concatenate([d, qh, qs], axis=-1)

    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        n_blocks = blocks.shape[0]

        d, rest = np.hsplit(blocks, [2])
        qh, qs = np.hsplit(rest, [4])

        d = d.view(np.float16).astype(np.float32)
        qh = qh.view(np.uint32)

        # Bit i of the 32-bit word is the high bit of value i.
        qh = qh.reshape((n_blocks, 1)) >> np.array([i for i in range(32)], dtype=np.uint32).reshape((1, 32))
        ql = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
        qh = (qh & np.uint32(0x01)).astype(np.uint8)
        ql = (ql & np.uint8(0x0F)).reshape((n_blocks, -1))

        # Recombine to 5 bits and remove the +16 offset.
        qs = (ql | (qh << np.uint8(4))).astype(np.int8) - np.int8(16)

        return (d * qs.astype(np.float32))
342
+
343
+
344
class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1):
    """5-bit quantization with a float16 scale and a float16 minimum per block.

    Block bytes as written/read here: 2 bytes scale ``d``, 2 bytes minimum
    ``m``, 4 bytes ``qh`` holding the fifth bit of each value, then
    ``block_size // 2`` bytes of packed low nibbles.
    """

    @classmethod
    def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        n_blocks = blocks.shape[0]

        max = blocks.max(axis=-1, keepdims=True)
        min = blocks.min(axis=-1, keepdims=True)

        # Spread the block's value range over the 32 representable levels.
        d = (max - min) / 31
        with np.errstate(divide="ignore"):
            # Inverse scale; constant blocks (d == 0) use 0 instead of inf.
            id = np.where(d == 0, 0, 1 / d)
        q = np.trunc((blocks - min) * id + np.float32(0.5), dtype=np.float32).astype(np.uint8).clip(0, 31)

        # Low 4 bits: first half in low nibbles, second half in high nibbles.
        qs = q.reshape((n_blocks, 2, cls.block_size // 2))
        qs = (qs[..., 0, :] & np.uint8(0x0F)) | (qs[..., 1, :] << np.uint8(4))

        # Fifth bit of each of the 32 values, packed little-endian into 4 bytes.
        qh = np.packbits(q.reshape((n_blocks, 1, 32)) >> np.uint8(4), axis=-1, bitorder="little").reshape(n_blocks, 4)

        d = d.astype(np.float16).view(np.uint8)
        m = min.astype(np.float16).view(np.uint8)

        return np.concatenate([d, m, qh, qs], axis=-1)

    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        n_blocks = blocks.shape[0]

        d, rest = np.hsplit(blocks, [2])
        m, rest = np.hsplit(rest, [2])
        qh, qs = np.hsplit(rest, [4])

        d = d.view(np.float16).astype(np.float32)
        m = m.view(np.float16).astype(np.float32)
        qh = qh.view(np.uint32)

        # Bit i of the 32-bit word is the high bit of value i.
        qh = qh.reshape((n_blocks, 1)) >> np.array([i for i in range(32)], dtype=np.uint32).reshape((1, 32))
        ql = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
        qh = (qh & np.uint32(0x01)).astype(np.uint8)
        ql = (ql & np.uint8(0x0F)).reshape((n_blocks, -1))

        qs = (ql | (qh << np.uint8(4))).astype(np.float32)

        return (d * qs) + m
387
+
388
+
389
class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0):
    """8-bit quantization with one float16 scale per block.

    Block bytes as written/read here: 2 bytes scale ``d`` followed by one
    signed byte per value.
    """

    @classmethod
    # Implementation of Q8_0 with bit-exact same results as reference implementation in ggml-quants.c
    def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:

        # Scale so the largest magnitude in the block maps to 127.
        d = abs(blocks).max(axis=1, keepdims=True) / 127
        with np.errstate(divide="ignore"):
            # Inverse scale; all-zero blocks use 0 instead of inf.
            id = np.where(d == 0, 0, 1 / d)
        qs = np_roundf(blocks * id)

        # (n_blocks, 2)
        d = d.astype(np.float16).view(np.uint8)
        # (n_blocks, block_size)
        qs = qs.astype(np.int8).view(np.uint8)

        return np.concatenate([d, qs], axis=1)

    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        # Leading 2 bytes are the float16 scale; the rest are int8 values.
        d, x = np.split(blocks, [2], axis=1)
        d = d.view(np.float16).astype(np.float32)
        x = x.view(np.int8).astype(np.float32)

        return (x * d)
413
+
414
+
415
class Q2_K(__Quant, qtype=GGMLQuantizationType.Q2_K):
    """Dequantization of Q2_K blocks (2-bit values with per-sub-block scale and min).

    Block bytes as read here: ``QK_K // 16`` scale bytes, ``QK_K // 4`` bytes
    of packed 2-bit values, then float16 ``d`` and float16 ``dmin``.
    """

    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        n_blocks = blocks.shape[0]

        scales, rest = np.hsplit(blocks, [QK_K // 16])
        qs, rest = np.hsplit(rest, [QK_K // 4])
        d, dmin = np.hsplit(rest, [2])

        d = d.view(np.float16).astype(np.float32)
        dmin = dmin.view(np.float16).astype(np.float32)

        # Each scale byte holds a 4-bit scale (low nibble) and a 4-bit min (high nibble).
        # (n_blocks, 16, 1)
        dl = (d * (scales & 0xF).astype(np.float32)).reshape((n_blocks, QK_K // 16, 1))
        ml = (dmin * (scales >> 4).astype(np.float32)).reshape((n_blocks, QK_K // 16, 1))

        shift = np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1))

        # Four 2-bit values per byte, extracted over groups of 32 bytes.
        qs = (qs.reshape((n_blocks, -1, 1, 32)) >> shift) & np.uint8(3)

        qs = qs.reshape((n_blocks, QK_K // 16, 16)).astype(np.float32)

        qs = dl * qs - ml

        return qs.reshape((n_blocks, -1))
440
+
441
+
442
class Q3_K(__Quant, qtype=GGMLQuantizationType.Q3_K):
    """Dequantization of Q3_K blocks (2 low bits plus a high-bit mask per value).

    Block bytes as read here: ``QK_K // 8`` bytes of high-bit mask,
    ``QK_K // 4`` bytes of packed 2-bit values, 12 bytes of packed 6-bit
    scales, then float16 ``d``.
    """

    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        n_blocks = blocks.shape[0]

        hmask, rest = np.hsplit(blocks, [QK_K // 8])
        qs, rest = np.hsplit(rest, [QK_K // 4])
        scales, d = np.hsplit(rest, [12])

        d = d.view(np.float16).astype(np.float32)

        # The scales are packed at 6-bit each in this pattern:
        #  0: IIIIAAAA
        #  1: JJJJBBBB
        #  2: KKKKCCCC
        #  3: LLLLDDDD
        #  4: MMMMEEEE
        #  5: NNNNFFFF
        #  6: OOOOGGGG
        #  7: PPPPHHHH
        #  8: MMIIEEAA
        #  9: NNJJFFBB
        # 10: OOKKGGCC
        # 11: PPLLHHDD
        lscales, hscales = np.hsplit(scales, [8])
        lscales = lscales.reshape((n_blocks, 1, 8)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 2, 1))
        lscales = lscales.reshape((n_blocks, 16))
        hscales = hscales.reshape((n_blocks, 1, 4)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 4, 1))
        hscales = hscales.reshape((n_blocks, 16))
        # Join the low 4 bits with the high 2 bits, then remove the +32 offset.
        scales = (lscales & np.uint8(0x0F)) | ((hscales & np.uint8(0x03)) << np.uint8(4))
        scales = (scales.astype(np.int8) - np.int8(32)).astype(np.float32)

        dl = (d * scales).reshape((n_blocks, 16, 1))

        ql = qs.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1))
        qh = hmask.reshape(n_blocks, -1, 1, 32) >> np.array([i for i in range(8)], dtype=np.uint8).reshape((1, 1, 8, 1))
        ql = ql.reshape((n_blocks, 16, QK_K // 16)) & np.uint8(3)
        qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & np.uint8(1))
        qh = qh ^ np.uint8(1)  # strangely, the offset is zero when the bitmask is 1
        # A cleared mask bit subtracts 4 from the 2-bit value.
        q = (ql.astype(np.int8) - (qh << np.uint8(2)).astype(np.int8)).astype(np.float32)

        return (dl * q).reshape((n_blocks, QK_K))
484
+
485
+
486
class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K):
    """Dequantization of Q4_K blocks (4-bit values with 6-bit sub-block scales and mins).

    Block bytes as read here: float16 ``d``, float16 ``dmin``,
    ``K_SCALE_SIZE`` bytes of packed scales/mins, then packed 4-bit values.
    """

    # Number of bytes holding the packed 6-bit scales and mins.
    K_SCALE_SIZE = 12

    @staticmethod
    def get_scale_min(scales: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Unpack the 12 scale bytes per block into 8 scales and 8 mins.

        Returns:
            ``(sc, min)``, each of shape ``(n_blocks, 8)``.
        """
        n_blocks = scales.shape[0]
        scales = scales.view(np.uint8)
        ### Unpacking the following: ###
        #  0 EEAAAAAA
        #  1 FFBBBBBB
        #  2 GGCCCCCC
        #  3 HHDDDDDD
        #  4 eeaaaaaa
        #  5 ffbbbbbb
        #  6 ggcccccc
        #  7 hhdddddd
        #  8 eeeeEEEE
        #  9 ffffFFFF
        # 10 ggggGGGG
        # 11 hhhhHHHH
        scales = scales.reshape((n_blocks, 3, 4))
        d, m, m_d = np.split(scales, 3, axis=-2)

        # First four entries are the low 6 bits; the last four combine a nibble
        # from bytes 8-11 with the top 2 bits of bytes 0-3 (scales) or 4-7 (mins).
        sc = np.concatenate([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], axis=-1)
        min = np.concatenate([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], axis=-1)

        return (sc.reshape((n_blocks, 8)), min.reshape((n_blocks, 8)))

    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        n_blocks = blocks.shape[0]

        d, rest = np.hsplit(blocks, [2])
        dmin, rest = np.hsplit(rest, [2])
        scales, qs = np.hsplit(rest, [cls.K_SCALE_SIZE])

        d = d.view(np.float16).astype(np.float32)
        dmin = dmin.view(np.float16).astype(np.float32)

        sc, m = Q4_K.get_scale_min(scales)

        # Effective per-sub-block scale and offset.
        d = (d * sc.astype(np.float32)).reshape((n_blocks, -1, 1))
        dm = (dmin * m.astype(np.float32)).reshape((n_blocks, -1, 1))

        # Two 4-bit values per byte, extracted over groups of 32 bytes.
        qs = qs.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
        qs = (qs & np.uint8(0x0F)).reshape((n_blocks, -1, 32)).astype(np.float32)

        return (d * qs - dm).reshape((n_blocks, QK_K))
534
+
535
+
536
class Q5_K(__Quant, qtype=GGMLQuantizationType.Q5_K):
    """Dequantization of Q5_K blocks (5-bit values with 6-bit sub-block scales and mins).

    Block bytes as read here: float16 ``d``, float16 ``dmin``,
    ``Q4_K.K_SCALE_SIZE`` bytes of packed scales/mins, ``QK_K // 8`` bytes of
    high bits, then packed low nibbles.
    """

    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        n_blocks = blocks.shape[0]

        d, rest = np.hsplit(blocks, [2])
        dmin, rest = np.hsplit(rest, [2])
        scales, rest = np.hsplit(rest, [Q4_K.K_SCALE_SIZE])
        qh, qs = np.hsplit(rest, [QK_K // 8])

        d = d.view(np.float16).astype(np.float32)
        dmin = dmin.view(np.float16).astype(np.float32)

        # Scales and mins use the same packing as Q4_K.
        sc, m = Q4_K.get_scale_min(scales)

        d = (d * sc.astype(np.float32)).reshape((n_blocks, -1, 1))
        dm = (dmin * m.astype(np.float32)).reshape((n_blocks, -1, 1))

        ql = qs.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
        qh = qh.reshape((n_blocks, -1, 1, 32)) >> np.array([i for i in range(8)], dtype=np.uint8).reshape((1, 1, 8, 1))
        ql = (ql & np.uint8(0x0F)).reshape((n_blocks, -1, 32))
        qh = (qh & np.uint8(0x01)).reshape((n_blocks, -1, 32))
        # Recombine the low nibble with its fifth bit.
        q = (ql | (qh << np.uint8(4))).astype(np.float32)

        return (d * q - dm).reshape((n_blocks, QK_K))
561
+
562
+
563
class Q6_K(__Quant, qtype=GGMLQuantizationType.Q6_K):
    """Dequantization of Q6_K blocks (6-bit values with int8 sub-block scales).

    Block bytes as read here: ``QK_K // 2`` bytes of low nibbles,
    ``QK_K // 4`` bytes of packed high 2-bit parts, ``QK_K // 16`` int8
    scales, then float16 ``d``.
    """

    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        n_blocks = blocks.shape[0]

        ql, rest = np.hsplit(blocks, [QK_K // 2])
        qh, rest = np.hsplit(rest, [QK_K // 4])
        scales, d = np.hsplit(rest, [QK_K // 16])

        scales = scales.view(np.int8).astype(np.float32)
        d = d.view(np.float16).astype(np.float32)
        # Effective scale per group of 16 values.
        d = (d * scales).reshape((n_blocks, QK_K // 16, 1))

        ql = ql.reshape((n_blocks, -1, 1, 64)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
        ql = (ql & np.uint8(0x0F)).reshape((n_blocks, -1, 32))
        qh = qh.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1))
        qh = (qh & np.uint8(0x03)).reshape((n_blocks, -1, 32))
        # Recombine to 6 bits and remove the +32 offset.
        q = (ql | (qh << np.uint8(4))).astype(np.int8) - np.int8(32)
        q = q.reshape((n_blocks, QK_K // 16, -1)).astype(np.float32)

        return (d * q).reshape((n_blocks, QK_K))
584
+
585
+
586
class TQ1_0(__Quant, qtype=GGMLQuantizationType.TQ1_0):
    """Ternary quantization that packs base-3 digits into bytes."""

    @classmethod
    def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        """Pack blocks of floats into TQ1_0 bytes.

        Each value is divided by the block's absolute maximum, rounded
        with ``np_roundf`` and shifted by one to become a non-negative
        digit. Digits are combined five (or four) at a time with powers
        of three, and the per-block scale is appended as float16 bytes.

        Args:
            blocks: 2-D float array, one block per row.

        Returns:
            uint8 array: the packed digit bytes followed by the two
            scale bytes of every block.
        """
        block_count = blocks.shape[0]

        scale = abs(blocks).max(axis=-1, keepdims=True)
        with np.errstate(divide="ignore"):
            # All-zero blocks get an inverse scale of 0 instead of inf.
            inv_scale = np.where(scale == 0, 0, 1 / scale)
        digits = np_roundf(blocks * inv_scale)
        digits = (digits.astype(np.int8) + np.int8(1)).astype(np.uint8)

        weights5 = np.array([81, 27, 9, 3, 1], dtype=np.uint8).reshape((1, 1, 5, 1))
        weights4 = np.array([81, 27, 9, 3], dtype=np.uint8).reshape((1, 1, 4, 1))

        # First 160 digits: five digits per byte, in groups of 32 bytes.
        head = digits[..., :(32 * 5)].reshape((block_count, -1, 5, 32)) * weights5
        head = np.sum(head, axis=-2).reshape((block_count, -1))

        # Next 80 digits: five digits per byte, in groups of 16 bytes.
        middle = digits[..., (32 * 5):(48 * 5)].reshape((block_count, -1, 5, 16)) * weights5
        middle = np.sum(middle, axis=-2).reshape((block_count, -1))

        # Remaining digits: four digits per byte.
        tail = digits[..., (48 * 5):].reshape((block_count, -1, 4, 4)) * weights4
        tail = np.sum(tail, axis=-2).reshape((block_count, -1))

        packed = np.concatenate([head, middle, tail], axis=-1)
        # Rescale from [0, 243) to the full byte range, rounding up.
        packed = (packed.astype(np.uint16) * 256 + (243 - 1)) // 243
        packed = packed.astype(np.uint8)

        scale_bytes = scale.astype(np.float16).view(np.uint8)

        return np.concatenate([packed, scale_bytes], axis=-1)

    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        """Expand packed TQ1_0 blocks into float32 values.

        Args:
            blocks: 2-D array with one packed block per row
                (presumably uint8 bytes -- confirm against the caller).

        Returns:
            float32 array holding ``scale * digit`` with digits in
            ``{-1, 0, 1}``, one row per block.
        """
        block_count = blocks.shape[0]

        main_len = (QK_K - 4 * QK_K // 64) // 5
        main_bytes, extra_bytes, raw_scale = np.hsplit(
            blocks, [main_len, main_len + QK_K // 64]
        )

        scale = raw_scale.view(np.float16).astype(np.float32)

        weights5 = np.array([1, 3, 9, 27, 81], dtype=np.uint8).reshape((1, 1, 5, 1))
        weights4 = np.array([1, 3, 9, 27], dtype=np.uint8).reshape((1, 1, 4, 1))

        # The uint8 products wrap modulo 256; the digit is recovered below
        # from the top bits of (byte * 3).
        head = main_bytes[..., :32].reshape((block_count, -1, 1, 32)) * weights5
        head = head.reshape((block_count, -1))
        middle = main_bytes[..., 32:].reshape((block_count, -1, 1, 16)) * weights5
        middle = middle.reshape((block_count, -1))
        tail = extra_bytes.reshape((block_count, -1, 1, 4)) * weights4
        tail = tail.reshape((block_count, -1))

        digits = np.concatenate([head, middle, tail], axis=-1)
        digits = ((digits.astype(np.uint16) * 3) >> 8).astype(np.int8) - np.int8(1)

        return (scale * digits.astype(np.float32))
632
+
633
+
634
class TQ2_0(__Quant, qtype=GGMLQuantizationType.TQ2_0):
    """Ternary quantization storing each value in 2 bits."""

    @classmethod
    def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        """Pack blocks of floats into TQ2_0 bytes.

        Each value is divided by the block's absolute maximum, rounded
        with ``np_roundf`` and shifted by one; four such 2-bit codes are
        OR-ed into every output byte. The per-block scale is appended
        as float16 bytes.

        Args:
            blocks: 2-D float array, one block per row.

        Returns:
            uint8 array: packed codes followed by the two scale bytes
            of every block.
        """
        block_count = blocks.shape[0]

        scale = abs(blocks).max(axis=-1, keepdims=True)
        with np.errstate(divide="ignore"):
            # All-zero blocks get an inverse scale of 0 instead of inf.
            inv_scale = np.where(scale == 0, 0, 1 / scale)
        codes = np_roundf(blocks * inv_scale)
        codes = (codes.astype(np.int8) + np.int8(1)).astype(np.uint8)

        bit_offsets = np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1))
        shifted = codes.reshape((block_count, -1, 4, 32)) << bit_offsets
        # Merge the four shifted planes into single bytes.
        packed = np.bitwise_or.reduce(shifted, axis=-2)
        packed = packed.reshape((block_count, -1))

        scale_bytes = scale.astype(np.float16).view(np.uint8)

        return np.concatenate([packed, scale_bytes], axis=-1)

    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        """Expand packed TQ2_0 blocks into float32 values.

        Args:
            blocks: 2-D array with one packed block per row
                (presumably uint8 bytes -- confirm against the caller).

        Returns:
            float32 array holding ``scale * (code - 1)`` for every 2-bit
            code, one row per block.
        """
        block_count = blocks.shape[0]

        packed, raw_scale = np.hsplit(blocks, [QK_K // 4])

        scale = raw_scale.view(np.float16).astype(np.float32)

        bit_offsets = np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1))
        codes = packed.reshape((block_count, -1, 1, 32)) >> bit_offsets
        codes = (codes & 0x03).reshape((block_count, -1)).astype(np.int8) - np.int8(1)

        return (scale * codes.astype(np.float32))
665
+
666
+
667
class MXFP4(__Quant, qtype=GGMLQuantizationType.MXFP4):
    """4-bit codes indexing ``kvalues`` with one shared exponent byte per block."""

    # e2m1 values (doubled)
    # ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
    kvalues = (0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12)

    @staticmethod
    def e8m0_to_fp32_half(x: np.ndarray) -> np.ndarray:
        """Build float32 values from exponent bytes.

        See ggml_e8m0_to_fp32_half in ggml-impl.h.

        Args:
            x: array of exponent bytes.

        Returns:
            float32 reinterpretation of the computed uint32 bit patterns.
        """
        # Exponents 0 and 1 get an explicit mantissa-bit pattern.
        small_bits = np.uint32(0x00200000) << np.uint32(x)
        # Otherwise (x - 1) is placed directly in the exponent field.
        regular_bits = np.uint32(x - 1) << np.uint32(23)
        bits = np.where(x < 2, small_bits, regular_bits)
        return bits.view(np.float32)

    @classmethod
    def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        """Pack blocks of floats into MXFP4 bytes.

        A shared exponent is derived from each block's absolute maximum;
        every value is then mapped to the ``kvalues`` entry with the
        smallest absolute error after scaling.

        Args:
            blocks: 2-D float array, one block of ``cls.block_size``
                values per row.

        Returns:
            Array holding the exponent byte followed by
            ``cls.block_size // 2`` bytes of paired 4-bit indices.
        """
        block_count = blocks.shape[0]
        half_size = cls.block_size // 2

        peak = abs(blocks).max(axis=-1, keepdims=True)

        with np.errstate(divide="ignore"):
            # log2(0) is -inf; those entries are replaced by exponent 0.
            exponent = np.where(peak > 0, np.floor(np.log2(peak)) - 2 + 127, 0).astype(np.uint8)

        scale = cls.e8m0_to_fp32_half(exponent)

        table = np.array(cls.kvalues, dtype=np.int8).reshape((1, 1, 16))

        # Absolute error of every candidate code for every value.
        candidates = scale.reshape((block_count, 1, 1)) * table.astype(np.float32)
        targets = blocks.reshape((block_count, cls.block_size, 1))
        errors = np.abs(candidates - targets)
        nearest = np.argmin(errors, axis=-1, keepdims=True)

        # First half of the block -> low nibbles, second half -> high nibbles.
        halves = nearest.reshape(block_count, 2, half_size).astype(np.uint8)
        packed = halves[:, 0] | (halves[:, 1] << np.uint8(4))
        packed = packed.reshape((block_count, half_size))

        return np.concatenate([exponent, packed], axis=-1)

    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        """Expand packed MXFP4 blocks into float32 values.

        Args:
            blocks: 2-D array with one packed block per row
                (presumably uint8 bytes -- confirm against the caller).

        Returns:
            float32 array of shape ``(n_blocks, cls.block_size)``.
        """
        block_count = blocks.shape[0]

        exponent, packed = np.hsplit(blocks, [1])

        scale = cls.e8m0_to_fp32_half(exponent)

        nibble_shifts = np.array([0, 4], dtype=np.uint8).reshape((1, 2, 1))
        indices = packed.reshape((block_count, 1, cls.block_size // 2)) >> nibble_shifts
        indices = (indices & np.uint8(0x0F)).view(np.int8)

        table = np.array(cls.kvalues, dtype=np.int8).reshape(1, 1, 16)
        values = np.take_along_axis(table, indices, axis=-1).reshape((block_count, cls.block_size))

        return (scale * values.astype(np.float32))
716
+
717
+
718
+ class IQ2_XXS(__Quant, qtype=GGMLQuantizationType.IQ2_XXS):
719
+ ksigns: bytes = (
720
+ b"\x00\x81\x82\x03\x84\x05\x06\x87\x88\x09\x0a\x8b\x0c\x8d\x8e\x0f"
721
+ b"\x90\x11\x12\x93\x14\x95\x96\x17\x18\x99\x9a\x1b\x9c\x1d\x1e\x9f"
722
+ b"\xa0\x21\x22\xa3\x24\xa5\xa6\x27\x28\xa9\xaa\x2b\xac\x2d\x2e\xaf"
723
+ b"\x30\xb1\xb2\x33\xb4\x35\x36\xb7\xb8\x39\x3a\xbb\x3c\xbd\xbe\x3f"
724
+ b"\xc0\x41\x42\xc3\x44\xc5\xc6\x47\x48\xc9\xca\x4b\xcc\x4d\x4e\xcf"
725
+ b"\x50\xd1\xd2\x53\xd4\x55\x56\xd7\xd8\x59\x5a\xdb\x5c\xdd\xde\x5f"
726
+ b"\x60\xe1\xe2\x63\xe4\x65\x66\xe7\xe8\x69\x6a\xeb\x6c\xed\xee\x6f"
727
+ b"\xf0\x71\x72\xf3\x74\xf5\xf6\x77\x78\xf9\xfa\x7b\xfc\x7d\x7e\xff"
728
+ )
729
+
730
+ # iq2xxs_grid, but with each byte of the original packed in 2 bits,
731
+ # by mapping 0x08 to 0, 0x19 to 1, and 0x2b to 2.
732
+ grid_shape = (256, 8)
733
+ grid_map = (0x08, 0x19, 0x2b)
734
+ grid_hex = (
735
+ b"00000200050008000a00110014002000220028002a0041004400500058006100"
736
+ b"6400800082008a00a20001010401100115014001840198010002020222028202"
737
+ b"010404041004210424044004420448046004810484049004a404000502050805"
738
+ b"200546056905800591050906100640068406a406000805080808140828084108"
739
+ b"440850085208880804094009020a140a01100410101021104010601084109010"
740
+ b"951000110811201150115a118011241245120014081420142514491480141815"
741
+ b"6215001616160118041810184018811800190519a019511a002002200a204420"
742
+ b"6120802082202921482100220222012404241024402456240025412564259026"
743
+ b"082820289428442a014004401040184021402440404048405640604081408440"
744
+ b"9040004120416141804185410142104248425642684200440844204480449944"
745
+ b"124524450046014804481048404845480049584961498249454a904a00500850"
746
+ b"1150195020508050885004514251a4519152905492540a550156545600581158"
747
+ b"195864584059085a046010604060686000615561186260620064056410651265"
748
+ b"84654268008002800a8041808280048118814081118201840484108415844084"
749
+ b"608400854685948509864086608602880489118a0490109024904090a1901691"
750
+ b"8091459200942294449451958198209902a050a085a009a100a218a450a804a9"
751
+ )
752
+
753
+ @classmethod
754
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
755
+ n_blocks = blocks.shape[0]
756
+
757
+ d, qs = np.hsplit(blocks, [2])
758
+
759
+ d = d.view(np.float16).astype(np.float32)
760
+
761
+ qs = qs.view(np.uint32).reshape(n_blocks, -1, 2)
762
+
763
+ db = d * (np.float32(0.5) + (qs[..., 1] >> 28).astype(np.float32)) * np.float32(0.25)
764
+ db = db.reshape((n_blocks, -1, 1, 1))
765
+
766
+ # get the sign indices and unpack the bits
767
+ signs = qs[..., 1].reshape((n_blocks, -1, 1)) >> np.array([0, 7, 14, 21], dtype=np.uint32).reshape((1, 1, 4))
768
+ ksigns = np.frombuffer(cls.ksigns, dtype=np.uint8).reshape((1, 1, 1, 128))
769
+ signs = (signs & np.uint32(0x7F)).reshape((n_blocks, -1, 4, 1))
770
+ signs = np.take_along_axis(ksigns, signs, axis=-1)
771
+ signs = signs.reshape((n_blocks, -1, 4, 1)) >> np.array([i for i in range(8)], dtype=np.uint8).reshape((1, 1, 1, 8))
772
+ signs = signs & np.uint8(0x01)
773
+ signs = np.where(signs == 0, np.float32(1), np.float32(-1))
774
+ signs = signs.reshape((n_blocks, -1, 4, 8))
775
+
776
+ assert cls.grid is not None
777
+ grid = np.take_along_axis(cls.grid, qs[..., 0].copy().view(np.uint8).reshape((n_blocks, -1, 1, 1)), axis=-2)
778
+ grid = grid.reshape((n_blocks, -1, 4, 8))
779
+
780
+ return (db * grid * signs).reshape((n_blocks, -1))
781
+
782
+
783
+ class IQ2_XS(__Quant, qtype=GGMLQuantizationType.IQ2_XS):
784
+ # iq2xs_grid, but with each byte of the original packed in 2 bits,
785
+ # by mapping 0x08 to 0, 0x19 to 1, and 0x2b to 2.
786
+ grid_shape = (512, 8)
787
+ grid_map = (0x08, 0x19, 0x2b)
788
+ grid_hex = (
789
+ b"00000200050008000a0011001400160019002000220025002800410044004600"
790
+ b"49005000520055005800610064008000820085008800910094009900a0000101"
791
+ b"04010601090110011201150118011a0121012401400142014501480151015401"
792
+ b"6001680181018401900100020202050208021102140220024102440250025502"
793
+ b"80028a0201040404060409041004120415041804210424044004420445044804"
794
+ b"5104540456046004810484049004000502050505080511051405200541054405"
795
+ b"500561058005010604061006260640064206840600080208050808080a081108"
796
+ b"14082008250841084408500858088008a008aa08010904091009400981098909"
797
+ b"000a200a280a960aa00a01100410061009101010121015101810211024104010"
798
+ b"4210451048105110541060106a10811084109010001102110511081111111411"
799
+ b"2011411144115011801194119611011204120612101240126012001402140514"
800
+ b"0814111414142014411444144914501464148014011504151015401500161416"
801
+ b"49160118041810181218401854188618001905196619511aa91a002002200520"
802
+ b"08200a201120142020204120442050208020a020012104211021402148216521"
803
+ b"002222228022a82201240424102429244024002541255225992501261a26a626"
804
+ b"002808280a28202855288828a22868299029082a202a822a882a8a2a01400440"
805
+ b"0640094010401240154018402140244040404240454048404a40514054406040"
806
+ b"6540814084409040004102410541084111411441204141414441504180418541"
807
+ b"a241014204421042124229424042004402440544084411441444194420444144"
808
+ b"4444504480449444014504451045244540459a4500460a464446504601480448"
809
+ b"1048404845485448624800491149444950496949044a00500250055008501150"
810
+ b"145020502850415044505050805001510451105115514051425100524452aa52"
811
+ b"0154045410542154405460548154a154005508558055885521566856a1560058"
812
+ b"14584158505899581a5940594259855a0160046010604060546062608660a960"
813
+ b"006124624a62926200641664106540654565a46501686a682569066a546a626a"
814
+ b"00800280058008801180148020802a8041804480508080808280a880aa800181"
815
+ b"0481068110814081518159810082208280828282a082a8820184048410841284"
816
+ b"158440846084898400854485a58518866a860088088825885a8880888288a888"
817
+ b"0689228a808a888a968aa88a0190049010904090569084900091229164915692"
818
+ b"89920094059444945094589429959095929541965198a6984999159a609a00a0"
819
+ b"02a008a00aa020a02aa0a0a051a159a1a6a100a202a208a22aa280a2a0a240a4"
820
+ b"95a465a698a60aa820a822a828a8a0a8a8a804a984a986a928aa2aaa91aaaaaa"
821
+ )
822
+
823
+ @classmethod
824
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
825
+ n_blocks = blocks.shape[0]
826
+
827
+ d, rest = np.hsplit(blocks, [2])
828
+ qs, scales = np.hsplit(rest, [2 * QK_K // 8])
829
+
830
+ d = d.view(np.float16).astype(np.float32)
831
+ qs = qs.view(np.uint16)
832
+
833
+ scales = scales.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2))
834
+ scales = (scales & 0x0F).reshape((n_blocks, -1))
835
+ db = d * (np.float32(0.5) + scales) * np.float32(0.25)
836
+ db = db.reshape((n_blocks, -1, 1, 1))
837
+
838
+ # get the sign indices and unpack the bits
839
+ signs = np.frombuffer(IQ2_XXS.ksigns, dtype=np.uint8).reshape(1, 1, 128)
840
+ signs = np.take_along_axis(signs, (qs >> 9).reshape((n_blocks, -1, 1)), axis=-1)
841
+ signs = signs.reshape((n_blocks, -1, 1)) >> np.array([i for i in range(8)], dtype=np.uint8).reshape((1, 1, 8))
842
+ signs = signs & np.uint8(0x01)
843
+ signs = np.where(signs == 0, np.float32(1), np.float32(-1))
844
+ signs = signs.reshape((n_blocks, -1, 2, 8))
845
+
846
+ assert cls.grid is not None
847
+ grid = np.take_along_axis(cls.grid, (qs & np.uint16(511)).reshape((n_blocks, -1, 1, 1)), axis=-2)
848
+ grid = grid.reshape((n_blocks, -1, 2, 8))
849
+
850
+ return (db * grid * signs).reshape((n_blocks, -1))
851
+
852
+
853
+ class IQ2_S(__Quant, qtype=GGMLQuantizationType.IQ2_S):
854
+ # iq2s_grid, but with each byte of the original packed in 2 bits,
855
+ # by mapping 0x08 to 0, 0x19 to 1, and 0x2b to 2.
856
+ grid_shape = (1024, 8)
857
+ grid_map = (0x08, 0x19, 0x2b)
858
+ grid_hex = (
859
+ b"00000200050008000a0011001400160019002000220025002800410044004600"
860
+ b"490050005200550058006100640066006900800082008500880091009400a000"
861
+ b"a500aa0001010401060109011001120115011801210124014001420145014801"
862
+ b"510154015601590160016501680181018401900192019501a101a40100020202"
863
+ b"050208021102140220022a02410244024602490250025502800285028a029402"
864
+ b"a202010404040604090410041204150418042104240426042904400442044504"
865
+ b"48044a0451045404560459046004620465048104840486048904900495049804"
866
+ b"a104a40400050205050508050a05110514051605190520052505280541054405"
867
+ b"46054905500552055505580561056405800582058505880591059405a0050106"
868
+ b"0406060609061006150640064506480651065406600681068406900600080208"
869
+ b"050808081108140816081908200825082a084108440846084908500852085508"
870
+ b"580861086408800885089408aa08010904091009120915091809210940094509"
871
+ b"480951095409600981099009000a110a140a220a280a2a0a500a990a01100410"
872
+ b"0610091010101210151018102110241026104010421045104810511054105610"
873
+ b"59106010621065106810811084108610901095109810a110a410001102110511"
874
+ b"08110a1111111411161119112011221125112811411144114611491150115211"
875
+ b"5511581161116411801182118511881191119411011204120912101215122112"
876
+ b"2412401245125112541281128412901200140214051408141114141416141914"
877
+ b"2014251428144114441446144914501452145514581461146414801482148514"
878
+ b"881491149414a014011504150615091510151215151518152115241540154215"
879
+ b"4515481551155415601581158415901500160516081611161416201641164416"
880
+ b"50168016aa160118041806180918101815181818211840184218451848185118"
881
+ b"541860188118841800190219051908191119141920194119441950196919a219"
882
+ b"041a101a401a561a00200220052008201120142016201920202025202a204120"
883
+ b"4420502052205520642080208a209420aa200121042110211221152121214021"
884
+ b"4221452151215421602181218421902100220a22222228222a22442250228822"
885
+ b"8a22a82201240424062409241024152418242124242440244224452448245124"
886
+ b"5424602481248424902400250525082511251425202541254425502566258025"
887
+ b"0126042610264026592600280528112814284128442850288a28aa2801290429"
888
+ b"102995290a2a222a642a882a8a2a014004400640094010401240154018401a40"
889
+ b"21402440264040404240454048404a4051405440564059406040624065408140"
890
+ b"8440904095409840a140a4400041024105410841114114411641194120412241"
891
+ b"2541414144414641494150415241554158416141644180418241854188419141"
892
+ b"9441a04101420442104212421542184224424042454248425142544260428142"
893
+ b"844200440244054408440a441144144416441944204422442544284441444444"
894
+ b"46444944504452445544584461446444804482448544884491449444a0440145"
895
+ b"0445064509451045124515451845214524454045424545454845514554456045"
896
+ b"6a4581458445904500460246054608461146144620464146444650468046a546"
897
+ b"0148044809481048124815481848214824484048424845484848514854486048"
898
+ b"84489048004902490549084911491449204941494449504980499649014a044a"
899
+ b"104a404a00500250055008501150145016501950205022502550285041504450"
900
+ b"4650495050505250555058506150645080508250855088509150945001510451"
901
+ b"0651095110511251155118512151245140514251455148515151545160518151"
902
+ b"8451905100520552085211521452205241524452505269528052015404540654"
903
+ b"0954105412541554185421542454405442544554485451545454605481548454"
904
+ b"9054005502550555085511551455205541554455505580550156045610562656"
905
+ b"405600580258055808581158145820584158445850585a588058015904591059"
906
+ b"4059005a195a855aa85a01600460066010601260156018602160246040604560"
907
+ b"4860516054606060846090600061026105610861116114612061416144615061"
908
+ b"806199610462106240625662a162006405640864116414642064416444645064"
909
+ b"806401650465106540654a656865926500669466016804681068656898680069"
910
+ b"2a69426aa16a0080028005800880118014801980208025804180448050805280"
911
+ b"5580588061808080858091809480018104810981108112811581188121812481"
912
+ b"408142814581488151815481818184819081a981008205820a82118214824182"
913
+ b"4482508201840484068409841084128415841884218440844284458448845184"
914
+ b"5484608481848484908400850285058508851185148520854185448550858085"
915
+ b"8a85018604861086298640860088058811881488418844885088a28801890489"
916
+ b"40896589228a588a5a8a828aa28a019004900990109012901590189024904090"
917
+ b"4290459048905190549060908190849090900091059111911491419144915091"
918
+ b"5a910192049210924092a6920094029405940894119414942094419444945094"
919
+ b"8094969401950495109540959895a19500964696649601980498109826984098"
920
+ b"a998009949995299909a00a005a00aa014a022a02aa041a044a050a0a2a0aaa0"
921
+ b"40a165a102a20aa222a228a22aa282a288a28aa2a8a201a404a410a440a489a4"
922
+ b"a4a400a519a551a60aa828a8a2a854a986a908aa0aaa20aa22aa28aa88aaaaaa"
923
+ )
924
+
925
+ @classmethod
926
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
927
+ n_blocks = blocks.shape[0]
928
+
929
+ d, rest = np.hsplit(blocks, [2])
930
+ qs, rest = np.hsplit(rest, [QK_K // 8])
931
+ signs, rest = np.hsplit(rest, [QK_K // 8])
932
+ qh, scales = np.hsplit(rest, [QK_K // 32])
933
+
934
+ d = d.view(np.float16).astype(np.float32)
935
+
936
+ scales = scales.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2))
937
+ scales = (scales & 0x0F).reshape((n_blocks, -1))
938
+ db = d * (np.float32(0.5) + scales) * np.float32(0.25)
939
+ db = db.reshape((n_blocks, -1, 1, 1))
940
+
941
+ # unpack the sign bits
942
+ signs = signs.reshape((n_blocks, -1, 1)) >> np.array([i for i in range(8)], dtype=np.uint8).reshape((1, 1, 8))
943
+ signs = signs & np.uint8(0x01)
944
+ signs = np.where(signs == 0, np.float32(1), np.float32(-1))
945
+ signs = signs.reshape((n_blocks, -1, 2, 8))
946
+
947
+ qh = qh.reshape((n_blocks, -1, 1)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4))
948
+ qs = qs.astype(np.uint16) | ((qh & 0x03).astype(np.uint16) << 8).reshape((n_blocks, -1))
949
+
950
+ assert cls.grid is not None
951
+ grid = np.take_along_axis(cls.grid, qs.reshape((n_blocks, -1, 1, 1)), axis=-2)
952
+ grid = grid.reshape((n_blocks, -1, 2, 8))
953
+
954
+ return (db * grid * signs).reshape((n_blocks, -1))
955
+
956
+
957
+ class IQ3_XXS(__Quant, qtype=GGMLQuantizationType.IQ3_XXS):
958
+ grid_shape = (256, 4)
959
+ grid_map = (0x04, 0x0c, 0x14, 0x1c, 0x24, 0x2c, 0x34, 0x3e)
960
+ grid_hex = (
961
+ b"0000020004001100130017002000220031004200730075000101030110011201"
962
+ b"2101250130013201410154017001000202020402110220022202310233023702"
963
+ b"5102570275020103070310031203250370031304370444045704730475040105"
964
+ b"0705320552053506640610071407160743076107011003101010121021102310"
965
+ b"3010321034104710501000110211111120112211011203121012121221123012"
966
+ b"7212001302132013311346136613011405145014201524154615711505162217"
967
+ b"4017002002201120132020202220262031204220012103210521102112212121"
968
+ b"3021632167217021002202221122172220222222372240225522012310231423"
969
+ b"7023742335245324032527254125742501270327162745270130103012302130"
970
+ b"2330503065307230003102312031313144314631013203321032253252327232"
971
+ b"1133333330344734723400350635223555351436363663363337603704401740"
972
+ b"3540374053405740744120423742404260426642074345430444514464442545"
973
+ b"4345704505471047124730471250415070500051065126515551145232527252"
974
+ b"0253535310542354275472540255315550562457425724604460466064602161"
975
+ b"6161176264623063366344640565526533660367216703700570077010703270"
976
+ b"5270267140711272457252720073157333736073217441740075027524753076"
977
+ )
978
+
979
+ @classmethod
980
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
981
+ n_blocks = blocks.shape[0]
982
+
983
+ d, rest = np.hsplit(blocks, [2])
984
+ qs, scales = np.hsplit(rest, [QK_K // 4])
985
+
986
+ d = d.view(np.float16).astype(np.float32)
987
+ scales = scales.view(np.uint32)
988
+
989
+ db = d * (np.float32(0.5) + (scales >> 28).astype(np.float32)) * np.float32(0.5)
990
+ db = db.reshape((n_blocks, -1, 1, 1))
991
+
992
+ # get the sign indices and unpack the bits
993
+ signs = scales.reshape((n_blocks, -1, 1)) >> np.array([0, 7, 14, 21], dtype=np.uint32).reshape((1, 1, 4))
994
+ ksigns = np.frombuffer(IQ2_XXS.ksigns, dtype=np.uint8).reshape((1, 1, 1, 128))
995
+ signs = (signs & np.uint32(0x7F)).reshape((n_blocks, -1, 4, 1))
996
+ signs = np.take_along_axis(ksigns, signs, axis=-1)
997
+ signs = signs.reshape((n_blocks, -1, 4, 1)) >> np.array([i for i in range(8)], dtype=np.uint8).reshape((1, 1, 1, 8))
998
+ signs = signs & np.uint8(0x01)
999
+ signs = np.where(signs == 0, np.float32(1), np.float32(-1))
1000
+ signs = signs.reshape((n_blocks, -1, 4, 8))
1001
+
1002
+ assert cls.grid is not None
1003
+ grid = np.take_along_axis(cls.grid, qs.reshape((n_blocks, -1, 1, 1)), axis=-2)
1004
+ grid = grid.reshape((n_blocks, -1, 4, 8))
1005
+
1006
+ return (db * grid * signs).reshape((n_blocks, -1))
1007
+
1008
+
1009
+ class IQ3_S(__Quant, qtype=GGMLQuantizationType.IQ3_S):
1010
+ grid_shape = (512, 4)
1011
+ grid_map = (0x01, 0x03, 0x05, 0x07, 0x09, 0x0b, 0x0d, 0x0f)
1012
+ grid_hex = (
1013
+ b"0000010002000500070010001100120014001600200021002500330040004200"
1014
+ b"4500470051005300600062007100740077000001010102010401100111011501"
1015
+ b"2001230127013101350144016101650172010002010205020702100213021602"
1016
+ b"2102250230023402420245024702510253027002730203031103150320032203"
1017
+ b"3103330336034403500352036703710375030004130417042104240432044004"
1018
+ b"4304510470040205040520052205260533054105450547056605730506061106"
1019
+ b"1306310652067106000702070407200722072607330750075407001001100210"
1020
+ b"0410101011101310151017102010221031103410361054105610611072100011"
1021
+ b"0111031106111011141121113011331141115011521170117611001212121512"
1022
+ b"1712201224123212401243125512601272120113041307131013131321132713"
1023
+ b"3013341341136213701303140514121414143114331442144614501454140115"
1024
+ b"1015131521153015321551152016241627164416461601170317101712172117"
1025
+ b"3517411762177017002001200320052007201020122014201620212023202720"
1026
+ b"3020322041204320452050205220672070207320752000210221102113211721"
1027
+ b"2221252131213421422151210122042207222122232230223722412253225722"
1028
+ b"7122742200230223052311232223242331233323422350236623012407242024"
1029
+ b"2324322435244124722475240425112522253725402553257025002602260726"
1030
+ b"2126552661260527112726273027432750270230113013301530173022303130"
1031
+ b"3330353042304430473051306330713001310331053114312131233140316031"
1032
+ b"7231763100321232203232323432503201331033143321332333273330334133"
1033
+ b"4333473355337333033411341634223431345234603464340135103512352535"
1034
+ b"3235443556357335163641360137033720372237353700400440124020402440"
1035
+ b"2740324041405040704002410741114113412241304135414341514155410142"
1036
+ b"0342104215422142334240425742624270420443114313432043224331433543"
1037
+ b"0044024424443744404471440545074521456245134634466046104715473047"
1038
+ b"4347514702501050145022504050445047505250665074500151035105511251"
1039
+ b"2151325172510052115223523052365253520253075310532753445351536553"
1040
+ b"7353015404542054325446541255265551555355425602570457225711601360"
1041
+ b"1560316033606060006120612761646112623462426255626262706200631463"
1042
+ b"2163406325644364626400650365346560650566406611671367007004700770"
1043
+ b"2070227036704070547062700271117124714371457101720472107216722172"
1044
+ b"3072517202733273357353730174057413742074507422754275027631760077"
1045
+ )
1046
+
1047
+ @classmethod
1048
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
1049
+ n_blocks = blocks.shape[0]
1050
+
1051
+ d, rest = np.hsplit(blocks, [2])
1052
+ qs, rest = np.hsplit(rest, [QK_K // 4])
1053
+ qh, rest = np.hsplit(rest, [QK_K // 32])
1054
+ signs, scales = np.hsplit(rest, [QK_K // 8])
1055
+
1056
+ d = d.view(np.float16).astype(np.float32)
1057
+
1058
+ scales = scales.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2))
1059
+ scales = (scales & 0x0F).reshape((n_blocks, -1))
1060
+ db = d * (1 + 2 * scales)
1061
+ db = db.reshape((n_blocks, -1, 1, 1))
1062
+
1063
+ # unpack the sign bits
1064
+ signs = signs.reshape((n_blocks, -1, 1)) >> np.array([i for i in range(8)], dtype=np.uint8).reshape((1, 1, 8))
1065
+ signs = signs & np.uint8(0x01)
1066
+ signs = np.where(signs == 0, np.float32(1), np.float32(-1))
1067
+ signs = signs.reshape((n_blocks, -1, 4, 8))
1068
+
1069
+ qh = qh.reshape((n_blocks, -1, 1)) >> np.array([i for i in range(8)], dtype=np.uint8)
1070
+ qh = (qh & 0x01).astype(np.uint16).reshape((n_blocks, -1))
1071
+ qs = qs.astype(np.uint16) | (qh << 8)
1072
+
1073
+ assert cls.grid is not None
1074
+ grid = np.take_along_axis(cls.grid, qs.reshape((n_blocks, -1, 1, 1)), axis=-2)
1075
+ grid = grid.reshape((n_blocks, -1, 4, 8))
1076
+
1077
+ return (db * grid * signs).reshape((n_blocks, -1))
1078
+
1079
+
1080
+ class IQ1_S(__Quant, qtype=GGMLQuantizationType.IQ1_S):
1081
+ # iq1s_grid, with each byte packed into 2 bits
1082
+ # -1, 0, 1 <=> 0, 1, 2
1083
+ grid_shape = (2048, 8)
1084
+ grid_map = (-1, 0, 1)
1085
+ grid_hex = (
1086
+ b"00000200050008000a00110015002000220028002a0045005100540056006500"
1087
+ b"8000820088008a009500a000a200a800aa000401050111011401160119011a01"
1088
+ b"2501410146014901520155015a0161016401660168018501910194019601a501"
1089
+ b"0002020208020a0215022002220228022a024502510259026402690280028202"
1090
+ b"88028a02910295029902a002a202a802aa021104140416042504410449045504"
1091
+ b"5a046404650491049904a5040105040505050605150518051a05290540054505"
1092
+ b"4a0550055105540555055605590560056205650568056a058105910595059805"
1093
+ b"9a05a105a405a505a605a9051406190641064406500652065506580660066106"
1094
+ b"6606690685069106940699060008020808080a0815082008220828082a084508"
1095
+ b"5108560865088008820888088a089508a008a208a808aa080509110914091909"
1096
+ b"2409250941095009510955096109640969099109940996099909a509000a020a"
1097
+ b"080a0a0a150a200a220a280a2a0a450a510a590a610a650a800a820a850a880a"
1098
+ b"8a0a950aa00aa20aa80aaa0a1010111014101910241025104110441050105510"
1099
+ b"58106110641065106910911094109610a110a510011104110611091110111211"
1100
+ b"1511181121112411291145114a11501151115211541155115611591160116511"
1101
+ b"841192119511a111a41111121412161225124012461249125212551258125a12"
1102
+ b"641266128512911294129612a512011406140914141415141814191421142614"
1103
+ b"41144514461448144a1451145414551456145914621465146814841489149014"
1104
+ b"94149514981499149a14a114a414a514a914021505150a151115141515151615"
1105
+ b"191520152215251528152a154115441545154615511552155415551556155915"
1106
+ b"5a1561156415651566156915801582158415851588158a159015911594159515"
1107
+ b"961599159a15a015a215a51501160416051606161516161618161a1621162616"
1108
+ b"401642164416451648164a165116551656165816591661166416651668166916"
1109
+ b"6a1686168a1692169516a416a916111816182518411844184618491850185518"
1110
+ b"58185a1860186118641866186918851891189418a5181019121915191a192119"
1111
+ b"25194219441945194819511954195519561959195a19601965196a1989199119"
1112
+ b"921995199819a119a619a919091a161a241a261a441a461a491a501a521a551a"
1113
+ b"581a611a661a691a851a911a961a9a1a0020022008200a201520202022202520"
1114
+ b"28202a20452051205920612065208020822088208a209520a020a220a520a820"
1115
+ b"aa2005211121142119212521422144214921552158215a216121642165216621"
1116
+ b"8521902196219921a521012208220a22112215222022222228222a2245225122"
1117
+ b"562259226522812288228a2291229522a022a222a822aa220524142416241924"
1118
+ b"252444244524462449245224552458245a2466248524912494249924a124a524"
1119
+ b"0925152521252925402545254825512554255525592562256525682589259025"
1120
+ b"9425952598259a25a125a425a625a92505261026122619262526412649265526"
1121
+ b"6026612669268426862690269a260028022808280a2815282028222828282a28"
1122
+ b"45285128542865288028822888288a28a028a228a828aa280929112914291929"
1123
+ b"2529462949295229552961296429662969298529902996299929a429a529002a"
1124
+ b"022a082a0a2a202a222a282a2a2a452a512a562a592a652a802a822a882a8a2a"
1125
+ b"952aa02aa22aa82aaa2a054011401640254049405240554058405a4061406440"
1126
+ b"664094409940a140a6400041014104410641094112411541164118411a412141"
1127
+ b"26412941454148414a41514154415541564159415a41654168416a4181418441"
1128
+ b"8641904192419541a041a141a241054211421442164225424142524255425a42"
1129
+ b"6442694289429442a5420144154419442944454448444a445144544455445644"
1130
+ b"61446244654468446a44814486448944904492449544a044a144a94401450245"
1131
+ b"05450a4511451445154516451945204525452a45414544454545464549455045"
1132
+ b"5145544555455645584559456145644565456645694582458445854588459145"
1133
+ b"94459545964599459a45a545a845aa450146054609461446154618461a462146"
1134
+ b"2446294640464246454648465046514652465546564659466246654668468146"
1135
+ b"85468a4694469546a146a446a6460548114815481a4825484248494850485548"
1136
+ b"5848614864486648694885489148944896489948a5480149054906490a491049"
1137
+ b"144915491849214924492649404945494a495149524954495549564959496049"
1138
+ b"6249654966496a49864989499249954996499849a149a449a649a949164a444a"
1139
+ b"464a494a554a584a5a4a644a694a944aa54a0150045005500650095012501550"
1140
+ b"1a50215024502950405045504850515054505550565059506550685086508950"
1141
+ b"95509850a050a150a650a9500551085109510a51115114511551165118511951"
1142
+ b"20512551265128512a5141514451455146514951505151515251545155515651"
1143
+ b"585159515a51615164516551665169518251855191519451955196519951a051"
1144
+ b"a551aa5101520652125215521a5221522452425245524a525152545255525652"
1145
+ b"595262526552855290529252955299529a52a452045405541154145415541654"
1146
+ b"185419542154255428542a54415444544554465449544a545054515454545554"
1147
+ b"5654585459545a54615462546454655466546954805488548a54915494549554"
1148
+ b"96549954a154a454a554aa540155025504550555065509551055115512551455"
1149
+ b"1555165519551a55215524552555265529554055415542554455455546554855"
1150
+ b"4955505551555255545555555655585559555a55605561556455655566556855"
1151
+ b"69556a5581558455855589558a559055915594559555965598559955a155a455"
1152
+ b"a555a655a9550056015602560456065608560956115614561556185619562056"
1153
+ b"2156225624562556265628562956415645564656485649564a56505651565256"
1154
+ b"545655565656585659565a566156645665566956825685568656885689568a56"
1155
+ b"915695569a56a256a556a656a856a95604580558065809581058155818582158"
1156
+ b"2a58455848584a58515854585558565858585958605862586458655882588958"
1157
+ b"9058925895589858a158a9580159025905590a59115914591559165919592559"
1158
+ b"41594459455946594959505951595259545955595659585959595a5961596459"
1159
+ b"655966596959815985598959915994599559965998599959a559045a085a155a"
1160
+ b"1a5a205a255a265a295a455a485a495a515a555a565a585a595a625a655a685a"
1161
+ b"6a5a815a8a5a925a955a965a985a9a5aa15a0560146016601960256044605060"
1162
+ b"5560566058605a60616064606660696081609660a56001610461066109611261"
1163
+ b"15612161226126612961456149615161556156615961656166616a6184618a61"
1164
+ b"92619561a161a661a96111621662196240624162466255625662586260628562"
1165
+ b"91629662a56211641264156416641a6421642664296440644264456448644a64"
1166
+ b"516454645564566459645a646064626465648464856489649064926494649564"
1167
+ b"966498649a64a164a464a964056508650a651165156516651965446545654665"
1168
+ b"496550655165546555655665596561656465656566656965866589658a659165"
1169
+ b"9565966599659a65a265a565a665a86502660966156620662666286629664066"
1170
+ b"456648664a66516654665566566658665a666066656668668066826685668a66"
1171
+ b"9466966698669966a066a466a666aa661668196825684168526855685a686168"
1172
+ b"6968856891689868a66801690469106915692169246926692969406941694569"
1173
+ b"4669486951695469556956695969606965696a69826984698a699569a169a469"
1174
+ b"a569a969116a166a186a416a446a496a506a556a586a5a6a646a656a696a866a"
1175
+ b"946a986a9a6aa66a0080028008800a802080228028802a804580508051805480"
1176
+ b"5680598065808080828088808a809580a080a280a880aa800581118114811681"
1177
+ b"1981258141814481498150815281558156815881598164816681698185818981"
1178
+ b"948196819981a5810082028208820a8215822082228228822a82518254825982"
1179
+ b"65828082828288828a829582a082a282a882aa82148419844184448451845584"
1180
+ b"5a846184648469849484998401850985128515851a8526852985408541854585"
1181
+ b"4885518554855585568559855a856585668568856a8581858485868589859085"
1182
+ b"928595859885a68511861686198625864186448649864a865086558659865a86"
1183
+ b"618666866a86858691869a86a4860088028808880a8815882088228828882a88"
1184
+ b"41884588518854885988658869888088828888888a889588a088a288a888aa88"
1185
+ b"05890689118914891689258941894489468949895089528955895a8961896489"
1186
+ b"858996899989a589008a028a088a0a8a158a208a228a288a2a8a458a518a548a"
1187
+ b"568a808a828a888a8a8a958aa08aa28aa88aaa8a059011901690189019902590"
1188
+ b"419046904990559058905a9069906a9085909190949096909990a59001910491"
1189
+ b"069109911091159118911a912191249126912991409145915091519154915591"
1190
+ b"569159916291659184918691929195919891a191a491a691a991059211921492"
1191
+ b"19922592449246924992509252925592589266926992859294929692a9920194"
1192
+ b"04940694109415941894269440944a9451945494559456945894599460946194"
1193
+ b"62946594849486949294949495949894a194a9940095059508950a9510951195"
1194
+ b"14951595169519952195259529952a9541954495459546954995509551955295"
1195
+ b"549555955695589559955a956195649565956695699581958595889591959295"
1196
+ b"94959595969599959a95a095a295a595a895aa95019604961096159619962096"
1197
+ b"2696299645964896499651965296559656965996659668968296849689968a96"
1198
+ b"929694969596a496a696a9960598169819982598419846985098529855985698"
1199
+ b"5a98649865988598919896989998a59804990699099910991299159918991a99"
1200
+ b"209921992499269940994299459948994a995199549955995699599962996599"
1201
+ b"66996a99819984999099929995999a99a199a699059a159a259a449a469a499a"
1202
+ b"509a559a589a619a859a919a949a959a969a00a002a008a00aa015a020a022a0"
1203
+ b"28a02aa045a051a054a056a059a080a082a088a08aa095a0a0a0a2a0a8a0aaa0"
1204
+ b"05a109a111a114a116a119a11aa146a149a151a155a158a15aa161a164a185a1"
1205
+ b"90a192a196a199a102a208a20aa210a219a222a228a22aa245a251a256a259a2"
1206
+ b"65a280a282a288a28aa295a2a0a2a2a2a8a2aaa219a425a441a444a450a454a4"
1207
+ b"55a458a45aa461a465a466a468a469a485a406a509a510a512a515a518a526a5"
1208
+ b"29a542a545a551a554a555a556a559a565a56aa581a584a585a586a589a592a5"
1209
+ b"95a598a505a611a616a61aa621a625a644a646a64aa652a655a656a658a660a6"
1210
+ b"62a686a690a695a696a699a6a1a6a4a6a6a600a802a808a80aa820a822a828a8"
1211
+ b"2aa851a854a856a859a880a882a888a88aa895a8a0a8a2a8a8a8aaa805a914a9"
1212
+ b"19a921a925a941a950a955a95aa961a966a969a990a996a900aa02aa08aa0aaa"
1213
+ b"20aa22aa28aa2aaa51aa54aa56aa80aa82aa88aa8aaa95aaa0aaa2aaa8aaaaaa"
1214
+ )
1215
+
1216
+ delta = np.float32(0.125)
1217
+
1218
@classmethod
def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
    """Decode packed blocks into float values, one output row per block.

    Each row of ``blocks`` is read as: 2 bytes viewed as a float16 scale,
    then ``QK_K // 8`` bytes of low grid-index bits, then the remaining
    bytes viewed as uint16 words that carry the high index bits, a 3-bit
    sub-scale and a flag choosing the sign of ``cls.delta``.
    """
    n_blocks = blocks.shape[0]

    scale_bytes, payload = np.hsplit(blocks, [2])
    index_lo, packed_hi = np.hsplit(payload, [QK_K // 8])

    base_scale = scale_bytes.view(np.float16).astype(np.float32)
    packed_hi = packed_hi.view(np.uint16)

    # Bits 12..14 of every uint16 word hold a sub-scale s, applied as 2*s + 1.
    sub_scale = (packed_hi >> 12) & 7
    dl = (base_scale * (2 * sub_scale + 1)).reshape((n_blocks, -1, 1, 1))

    # The top bit of the word selects +delta (bit clear) or -delta (bit set).
    top_bit_clear = (packed_hi & np.uint16(0x8000)) == 0
    offset = np.where(top_bit_clear, cls.delta, -cls.delta)
    offset = offset.reshape((n_blocks, -1, 1, 1))

    # Four 3-bit fields (shifts 0/3/6/9) become bits 8..10 of the grid index.
    field_shifts = np.array([0, 3, 6, 9], dtype=np.uint16).reshape((1, 1, 4))
    index_hi = packed_hi.reshape((n_blocks, -1, 1)) >> field_shifts
    grid_index = index_lo.astype(np.uint16) | ((index_hi & 7) << 8).reshape((n_blocks, -1))

    assert cls.grid is not None
    rows = np.take_along_axis(cls.grid, grid_index.reshape((n_blocks, -1, 1, 1)), axis=-2)
    rows = rows.reshape((n_blocks, -1, 4, 8))

    return (dl * (rows + offset)).reshape((n_blocks, -1))
1241
+
1242
+
1243
class IQ1_M(__Quant, qtype=GGMLQuantizationType.IQ1_M):
    """IQ1_M block decoder; reuses the IQ1_S grid tables and delta offset."""

    # Grid definition and offset magnitude are taken directly from IQ1_S.
    grid_shape = IQ1_S.grid_shape
    grid_map = IQ1_S.grid_map
    grid_hex = IQ1_S.grid_hex

    delta = IQ1_S.delta

    # Okay *this* type is weird. It's the only one which stores the f16 scales in multiple parts.
    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        """Decode packed blocks into float values, one output row per block.

        Each row of ``blocks`` is read as ``QK_K // 8`` bytes of low grid-index
        bits, ``QK_K // 16`` bytes whose nibbles carry the high index bits and
        a sign flag, and trailing bytes viewed as four uint16 scale words.
        """
        n_blocks = blocks.shape[0]

        index_lo, tail = np.hsplit(blocks, [QK_K // 8])
        packed_hi, scale_words = np.hsplit(tail, [QK_K // 16])

        # The f16 scale is packed across multiple bytes: the top nibble of each
        # of the four uint16 words supplies one nibble of the float16 bit pattern.
        scale_words = scale_words.view(np.uint16)
        nibble_shifts = np.array([12, 8, 4, 0], dtype=np.uint16).reshape((1, 4))
        nibbles = (scale_words.reshape((n_blocks, 4)) & np.uint16(0xF000)) >> nibble_shifts
        f16_bits = nibbles[..., 0] | nibbles[..., 1] | nibbles[..., 2] | nibbles[..., 3]
        base_scale = f16_bits.view(np.float16).astype(np.float32).reshape((n_blocks, 1))

        # Each scale word also holds four 3-bit sub-scales s, applied as 2*s + 1.
        field_shifts = np.array([0, 3, 6, 9], dtype=np.uint16).reshape((1, 1, 4))
        sub_scales = scale_words.reshape(n_blocks, -1, 1) >> field_shifts
        sub_scales = (sub_scales & 0x07).reshape((n_blocks, -1))
        dl = (base_scale * (2 * sub_scales + 1)).reshape((n_blocks, -1, 2, 1, 1))

        # Split every qh byte into its low and high nibble.
        nibble_split = np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2))
        packed_hi = packed_hi.reshape((n_blocks, -1, 1)) >> nibble_split
        index_hi = (packed_hi & 0x07).astype(np.uint16) << 8
        grid_index = index_lo.astype(np.uint16) | index_hi.reshape((n_blocks, -1))

        # Bit 3 of each nibble selects +delta (clear) or -delta (set).
        offset = np.where((packed_hi & 0x08) == 0, cls.delta, -cls.delta)
        offset = offset.reshape((n_blocks, -1, 2, 2, 1))

        assert cls.grid is not None
        rows = np.take_along_axis(cls.grid, grid_index.reshape((n_blocks, -1, 1, 1)), axis=-2)
        rows = rows.reshape((n_blocks, -1, 2, 2, 8))

        return (dl * (rows + offset)).reshape((n_blocks, -1))
1280
+
1281
+
1282
class IQ4_NL(__Quant, qtype=GGMLQuantizationType.IQ4_NL):
    """IQ4_NL block decoder: 4-bit codes looked up in a 16-entry table."""

    # Lookup table indexed by the 4-bit code stored for each value.
    kvalues = (-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113)

    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        """Decode packed blocks into float32 values, one output row per block.

        Each row of ``blocks`` is read as 2 bytes viewed as a float16 scale
        followed by bytes that each pack two 4-bit ``kvalues`` indices.
        """
        n_blocks = blocks.shape[0]

        scale_bytes, packed = np.hsplit(blocks, [2])
        scale = scale_bytes.view(np.float16).astype(np.float32)

        # Low nibbles of all bytes come first, then the high nibbles.
        half_block = cls.block_size // 2
        nibble_shifts = np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
        codes = packed.reshape((n_blocks, -1, 1, half_block)) >> nibble_shifts
        codes = (codes & np.uint8(0x0F)).reshape((n_blocks, -1, 1))

        table = np.array(cls.kvalues, dtype=np.int8).reshape(1, 1, 16)
        values = np.take_along_axis(table, codes, axis=-1)
        values = values.astype(np.float32).reshape((n_blocks, -1))

        return scale * values
1301
+
1302
+
1303
class IQ4_XS(__Quant, qtype=GGMLQuantizationType.IQ4_XS):
    """IQ4_XS block decoder: IQ4_NL codes with 6-bit per-group sub-scales."""

    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        """Decode packed blocks into float32 values, one output row per block.

        Each row of ``blocks`` is read as: 2 bytes viewed as a float16 scale,
        2 bytes viewed as a uint16 holding the top 2 bits of each sub-scale,
        ``QK_K // 64`` bytes holding the low 4 bits of the sub-scales, and the
        remaining bytes packing two 4-bit ``IQ4_NL.kvalues`` indices each.
        """
        n_blocks = blocks.shape[0]

        scale_bytes, tail = np.hsplit(blocks, [2])
        high_bytes, tail = np.hsplit(tail, [2])
        low_bytes, packed = np.hsplit(tail, [QK_K // 64])

        base_scale = scale_bytes.view(np.float16).astype(np.float32)
        high_word = high_bytes.view(np.uint16)

        nibble_shifts = np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2))
        pair_shifts = np.array(list(range(0, 2 * (QK_K // 32), 2)), dtype=np.uint16).reshape((1, -1, 1))

        low_bits = low_bytes.reshape((n_blocks, -1, 1)) >> nibble_shifts
        high_bits = high_word.reshape((n_blocks, 1, -1)) >> pair_shifts
        low_bits = low_bits.reshape((n_blocks, -1)) & np.uint8(0x0F)
        high_bits = high_bits.reshape((n_blocks, -1)).astype(np.uint8) & np.uint8(0x03)

        # Combine into a 6-bit value and re-centre it by subtracting 32.
        sub_scales = (low_bits | (high_bits << np.uint8(4))).astype(np.int8) - np.int8(32)
        dl = (base_scale * sub_scales.astype(np.float32)).reshape((n_blocks, -1, 1))

        # Within each 16-byte group the low nibbles come first, then the high ones.
        code_shifts = np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
        codes = packed.reshape((n_blocks, -1, 1, 16)) >> code_shifts
        codes = codes.reshape((n_blocks, -1, 32, 1)) & np.uint8(0x0F)

        table = np.array(IQ4_NL.kvalues, dtype=np.int8).reshape((1, 1, 1, -1))
        values = np.take_along_axis(table, codes, axis=-1)
        values = values.astype(np.float32).reshape((n_blocks, -1, 32))

        return (dl * values).reshape((n_blocks, -1))
1330
+
1331
+ # =============================================================================
1332
+ # Helicoidal-Zeta Kernel (Bruno Becker) — OFFELLIA Architecture
1333
+ # =============================================================================
1334
+ from dataclasses import dataclass, field
1335
+ import numpy as np
1336
+
1337
# mpmath is an optional dependency: only zeta_signature() needs it, so a
# missing install must not break importing this module. Only ImportError is
# caught, so any other failure while importing mpmath is not silently hidden.
try:
    from mpmath import mp, zeta as _mp_zeta
except ImportError:
    mp = None
    _mp_zeta = None


def _require_mpmath() -> None:
    """Raise ImportError when the optional mpmath dependency is unavailable.

    Returns None when both ``mp`` and ``_mp_zeta`` were imported successfully.
    """
    if mp is None or _mp_zeta is None:
        raise ImportError("mpmath is required for zeta_signature(). Install with: pip install mpmath")
1346
+
1347
@dataclass
class _PrimeCache:
    """Incrementally grown cache of the first primes, in ascending order.

    ``primes`` is expected to be a gap-free prefix of the prime sequence
    (2, 3, 5, ...); trial division in ``_is_prime`` relies on that. An empty
    seed is accepted and is initialised to ``[2]``.
    """

    primes: list[int] = field(default_factory=lambda: [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41])

    def __post_init__(self):
        # Guard against an empty seed, which would otherwise fail on primes[-1].
        if not self.primes:
            self.primes.append(2)
        # Largest integer already examined; the search resumes after it.
        self._checked_upto = self.primes[-1]

    @staticmethod
    def _is_prime(k: int, primes: list[int]) -> bool:
        """Trial-divide ``k`` by ``primes``; True when no divisor is found.

        The answer is only reliable when ``primes`` contains every prime up
        to ``sqrt(k)`` or is exhausted just below it.
        """
        if k < 2:
            return False
        for p in primes:
            if p * p > k:
                return True
            if k % p == 0:
                return False
        return True

    def nth_prime(self, n: int) -> int:
        """Return the n-th prime (1-based), extending the cache on demand.

        Raises:
            ValueError: if ``n`` is less than 1.
        """
        if n <= 0:
            raise ValueError("n must be >= 1")
        candidate = self._checked_upto
        while len(self.primes) < n:
            # Step 2 -> 3 by one, then visit odd numbers only. Stepping by two
            # straight from an even start would skip 3 for a seed of [2].
            candidate += 1 if candidate % 2 == 0 else 2
            if self._is_prime(candidate, self.primes):
                self.primes.append(candidate)
        self._checked_upto = candidate
        return self.primes[n - 1]
1371
+
1372
@dataclass
class HelicoidalZetaCore:
    """Maps an integer index to helix coordinates, a modular weight and zeta(1/2 + i*n).

    Attributes:
        zeta_dps: decimal precision used by mpmath while evaluating zeta.
        delta_modulus: default modulus used by ``delta_m``.
        delta_else: value ``delta_m`` returns when the index is not a multiple
            of the modulus.
        use_primes: when True, every index ``n`` is first replaced by the
            n-th prime.
    """

    zeta_dps: int = 25
    delta_modulus: int = 42
    delta_else: float = 0.42
    use_primes: bool = False
    _prime_cache: _PrimeCache = field(default_factory=_PrimeCache)

    def __post_init__(self) -> None:
        # Golden ratio. mpmath precision is deliberately not touched here; it
        # is applied locally inside zeta_signature() instead of process-wide.
        self.phi = (1.0 + float(np.sqrt(5.0))) / 2.0

    def _n_to_eval(self, n: int) -> int:
        """Return ``n`` itself, or the n-th prime when ``use_primes`` is set."""
        if not self.use_primes:
            return int(n)
        return int(self._prime_cache.nth_prime(int(n)))

    def _angle(self, nn: int) -> float:
        """Return the helix angle ``2*pi*phi*nn`` for an already-resolved index."""
        return 2.0 * np.pi * self.phi * nn

    def Fn(self, n: int) -> float:
        """Return ``sin(2*pi*phi*n)**2`` for the resolved index."""
        nn = self._n_to_eval(n)
        return float(np.sin(self._angle(nn)) ** 2)

    def coords(self, n: int) -> np.ndarray:
        """Return float32 ``[x, y, z]``: radius ``sin(t)**2`` at angle ``t``, height ``n``."""
        nn = self._n_to_eval(n)
        t = self._angle(nn)
        r = float(np.sin(t) ** 2)
        x = r * float(np.cos(t))
        y = r * float(np.sin(t))
        z = float(nn)
        return np.array([x, y, z], dtype=np.float32)

    def delta_m(self, n: int, m: int | None = None) -> float:
        """Return 1.0 when the resolved index is a multiple of ``m``, else ``delta_else``.

        ``m`` defaults to ``delta_modulus``. A modulus of 0 raises
        ZeroDivisionError.
        """
        nn = self._n_to_eval(n)
        mm = int(self.delta_modulus if m is None else m)
        return 1.0 if (nn % mm) == 0 else float(self.delta_else)

    def zeta_signature(self, n: int) -> np.ndarray:
        """Return float32 ``[re, im]`` of zeta(0.5 + i*n) computed with mpmath.

        Raises:
            ImportError: if mpmath is not installed.
        """
        _require_mpmath()
        nn = self._n_to_eval(n)
        # Scope the working precision to this call so the global mpmath
        # context is restored afterwards and other mpmath users are unaffected.
        with mp.workdps(int(self.zeta_dps)):
            val = _mp_zeta(mp.mpc(0.5, float(nn)))
            return np.array([float(val.real), float(val.imag)], dtype=np.float32)

    def math_embedding(self, n: int) -> np.ndarray:
        """Return the 7-element embedding ``[x*d, y*d, z*d, r, theta, re(zeta), im(zeta)]``.

        ``d`` is ``delta_m(n)``. Raises ImportError when mpmath is missing,
        because the zeta components cannot be computed.
        """
        nn = self._n_to_eval(n)
        c = self.coords(n)
        theta = self._angle(nn)
        r = float(np.sin(theta) ** 2)
        delta = self.delta_m(n)
        zeta_vals = self.zeta_signature(n)
        return np.concatenate([c * delta, np.array([r, theta], dtype=np.float32), zeta_vals])

    def transform(self, x: np.ndarray, n_val: int) -> np.ndarray:
        """Scale ``x`` by a factor derived from the embedding mean; also prints a trace line.

        The factor is ``1 / (1 + |mean| / 100)``, capped at 0.78.
        """
        emb = self.math_embedding(n_val)
        raw_scale = float(np.mean(emb))
        # Conservative OFFELLIA variant: the scale never exceeds 0.78.
        final_scale = min(0.78, 1.0 / (1.0 + abs(raw_scale) / 100.0))
        print(f"[OFFELLIA ZETA] n={n_val} | raw_mean={raw_scale:.6f} → final_scale={final_scale:.6f}")
        return x * final_scale
1433
+
1434
# --- MODULE-LEVEL FUNCTION (OUTSIDE THE CLASS) ---
def helicoidal_zeta_scale(n_val: int, *, use_primes: bool = False, zeta_dps: int = 25) -> float:
    """Return the mean of the HelicoidalZetaCore embedding for ``n_val``.

    A fresh core is built on every call from ``use_primes`` and ``zeta_dps``;
    all other core settings keep their defaults.
    """
    kernel = HelicoidalZetaCore(zeta_dps=zeta_dps, use_primes=use_primes)
    embedding = kernel.math_embedding(int(n_val))
    mean_value = np.mean(embedding)
    return float(mean_value)
1439
+
1440
+ __all__ = [
1441
+ "HelicoidalZetaCore",
1442
+ "helicoidal_zeta_scale",
1443
+ ]