Gemini
feat: add detailed logging
01d5a5d
#!/usr/bin/env python3
# Test gguf.quants so that it exactly matches the C implementation of the (de)quantization
# NOTE: this is kind of a mess, but at least it worked for initially testing the Python implementations.
from __future__ import annotations
import argparse
from math import prod
import os
import sys
from pathlib import Path
import ctypes
import logging
import numpy as np
# Necessary to load the local gguf package: when a "gguf-py" directory exists
# three levels above this file, this file's grandparent directory is put at
# the front of sys.path. Setting NO_LOCAL_GGUF in the environment skips this.
if "NO_LOCAL_GGUF" not in os.environ:
    if (Path(__file__).parent.parent.parent / "gguf-py").exists():
        sys.path.insert(0, str(Path(__file__).parent.parent))
import gguf
from gguf.constants import GGMLQuantizationType
# Logger used by the tensor comparison and the test driver in this script.
logger = logging.getLogger("test-quants")
# ctypes pointer type used for the `float *` arguments handed to libggml.
c_float_p = ctypes.POINTER(ctypes.c_float)
class ggml_init_params(ctypes.Structure):
    """ctypes description of the parameter struct given to ``ggml_init``.

    Instances are passed by value; fields are declared in the order the
    structure lays them out.
    """

    _fields_ = [
        # size_t mem_size
        ("mem_size", ctypes.c_size_t),
        # void * mem_buffer
        ("mem_buffer", ctypes.c_void_p),
        # bool no_alloc
        ("no_alloc", ctypes.c_bool),
    ]
class GGMLQuants:
    """ctypes front-end for the (de)quantization routines exported by libggml.

    The constructor loads the shared library, declares the C prototypes of
    every function used by this class and calls ``ggml_init`` once.
    ``quantize`` and ``dequantize`` then run the C implementation on NumPy
    arrays so the results can be compared with the Python implementation.
    """

    libggml: ctypes.CDLL

    # Name suffixes of the ``dequantize_row_<name>`` symbols that are resolved
    # (and therefore required to exist) when the library is loaded.
    _DEQUANT_ROW_TYPES = (
        "q4_0",
        "q4_1",
        "q5_0",
        "q5_1",
        "q8_0",
        "q2_K",
        "q3_K",
        "q4_K",
        "q5_K",
        "q6_K",
        "tq1_0",
        "tq2_0",
        "iq2_xxs",
        "iq2_xs",
        "iq2_s",
        "iq3_xxs",
        "iq3_s",
        "iq1_s",
        "iq1_m",
        "iq4_nl",
        "iq4_xs",
    )

    def __init__(self, libggml: Path):
        """Load ``libggml`` and declare the prototypes of the functions used.

        Args:
            libggml: path of the shared library to load.

        Raises:
            OSError: if the shared library cannot be loaded.
            AttributeError: if an expected symbol is missing from it.
        """
        self.libggml = ctypes.CDLL(str(libggml))
        lib = self.libggml

        # size_t ggml_quantize_chunk(
        #     enum ggml_type type,
        #     const float * src,
        #     void * dst,
        #     int64_t start,
        #     int64_t nrows,
        #     int64_t n_per_row,
        #     const float * imatrix)
        lib.ggml_quantize_chunk.restype = ctypes.c_size_t
        lib.ggml_quantize_chunk.argtypes = (
            ctypes.c_int,
            c_float_p,
            ctypes.c_void_p,
            ctypes.c_int64,
            ctypes.c_int64,
            ctypes.c_int64,
            c_float_p,
        )

        lib.ggml_quantize_requires_imatrix.restype = ctypes.c_bool
        lib.ggml_quantize_requires_imatrix.argtypes = (ctypes.c_int,)

        # Resolve every known dequantization routine up front so a missing
        # symbol is reported at load time rather than in the middle of a run.
        for type_name in self._DEQUANT_ROW_TYPES:
            self._dequantize_row_func(type_name)

        uint16_p = ctypes.POINTER(ctypes.c_uint16)
        for to_fp32_row in (lib.ggml_fp16_to_fp32_row, lib.ggml_bf16_to_fp32_row):
            to_fp32_row.restype = None
            to_fp32_row.argtypes = (uint16_p, c_float_p, ctypes.c_int64)

        lib.ggml_init.argtypes = (ggml_init_params,)
        # ggml_init returns a pointer. Without an explicit restype ctypes
        # assumes a C int, which does not fit a 64-bit address.
        lib.ggml_init.restype = ctypes.c_void_p
        # The returned context is not needed by this class and is discarded.
        lib.ggml_init(ggml_init_params(1 * 1024 * 1024, 0, False))

    def _dequantize_row_func(self, name: str) -> ctypes._NamedFuncPointer:
        """Look up ``dequantize_row_<name>`` and declare its C prototype."""
        func = getattr(self.libggml, "dequantize_row_" + name)
        func.restype = None
        func.argtypes = (ctypes.c_void_p, c_float_p, ctypes.c_int64)
        return func

    def dequantize(self, tensor: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
        """Dequantize ``tensor`` to float32 with the C implementation.

        Args:
            tensor: the raw bytes of the quantized data.
            qtype: the quantization type ``tensor`` is encoded with.

        Returns:
            A float32 array. For ``F32`` this is a view of ``tensor``;
            otherwise it is a newly allocated C-ordered array.
        """
        if qtype == GGMLQuantizationType.F32:
            # no-op: the bytes already are float32 values
            return tensor.view(np.float32)

        # The C code walks a flat buffer, so the source must be C-contiguous.
        # This is a no-copy pass-through when it already is.
        src = np.ascontiguousarray(tensor)
        result = np.zeros(
            gguf.quant_shape_from_byte_shape(tensor.shape, qtype),
            dtype=np.float32,
            order="C",
        )
        dst = result.ctypes.data_as(c_float_p)

        if qtype == GGMLQuantizationType.F16:
            self.libggml.ggml_fp16_to_fp32_row(
                src.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), dst, result.size
            )
        elif qtype == GGMLQuantizationType.BF16:
            self.libggml.ggml_bf16_to_fp32_row(
                src.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), dst, result.size
            )
        else:
            # The C symbols of the K-quants end in an upper-case "K"
            # (e.g. dequantize_row_q4_K), every other name is lower-case.
            lw_qname = qtype.name.lower()
            if lw_qname.endswith("k"):
                lw_qname = lw_qname[:-1] + "K"
            # Going through the helper declares the prototype even for types
            # that were not resolved by __init__.
            dequant_func = self._dequantize_row_func(lw_qname)
            dequant_func(src.ctypes.data_as(ctypes.c_void_p), dst, result.size)
        return result

    def quantize(self, data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
        """Quantize ``data`` to ``qtype`` with the C implementation.

        Args:
            data: the values to quantize; the last axis is the row length.
            qtype: the target quantization type.

        Returns:
            A uint8 array holding the quantized bytes.
        """
        # ggml_quantize_chunk reads a flat `const float *` buffer. Make sure
        # that is what it gets (no copy for C-contiguous float32 input).
        src = np.ascontiguousarray(data, dtype=np.float32)
        result = np.zeros(
            gguf.quant_shape_to_byte_shape(data.shape, qtype), dtype=np.uint8, order="C"
        )
        if self.libggml.ggml_quantize_requires_imatrix(qtype.value):
            # TODO: is a column-wise sum of squares appropriate?
            # Kept in a local so the buffer stays alive during the C call.
            imatrix = np.sum((src * src).reshape((-1, src.shape[-1])), axis=0)
            qw = imatrix.ctypes.data_as(c_float_p)
        else:
            imatrix = None
            qw = ctypes.cast(0, c_float_p)  # NULL: no importance matrix
        result_size = self.libggml.ggml_quantize_chunk(
            qtype.value,
            src.ctypes.data_as(c_float_p),
            result.ctypes.data_as(ctypes.c_void_p),
            0,
            prod(data.shape[:-1]),
            data.shape[-1],
            qw,
        )
        assert result.size == result_size
        return result
def compare_tensors(
    t1: np.ndarray, t2: np.ndarray, qtype: GGMLQuantizationType
) -> bool:
    """Return whether two tensors are bit-for-bit identical.

    When the tensors differ, details are logged at debug level: the number of
    differing blocks, the block with the most differing bits and the total
    number of differing bits.

    Args:
        t1: the tensor under test.
        t2: the reference tensor.
        qtype: quantization type, used to group the data into blocks
            (elements per block for float32 data, bytes per block otherwise).

    Returns:
        True when the values are equal or the underlying bits are identical
        (value comparison fails on NaNs even when the bits match), else False.
    """
    if np.array_equal(t1, t2):
        return True

    # Tensors holding a different amount of data can never match. Bail out
    # before the bitwise comparison, which would otherwise raise or silently
    # broadcast one operand over the other.
    if t1.size != t2.size or t1.dtype.itemsize != t2.dtype.itemsize:
        logger.debug(
            f"Tensor layouts differ: {t1.shape} ({t1.dtype}) vs {t2.shape} ({t2.dtype})"
        )
        return False

    block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
    # One row per block: dequantized data is grouped by element count,
    # quantized data by byte count.
    row_len = block_size if t1.dtype == np.float32 else type_size
    t1 = t1.reshape((-1, row_len))
    t2 = t2.reshape((-1, row_len))

    x = t1.view(np.uint8) ^ t2.view(np.uint8)
    # Number of differing bits in each block.
    diff_bits = np.count_nonzero(np.unpackbits(x, axis=-1), axis=-1)
    num_bad_blocks = np.count_nonzero(diff_bits, axis=0)
    if num_bad_blocks == 0:
        logger.debug("Bits are equal, but arrays don't match, likely contains NANs")
        return True

    logger.debug(
        f"{num_bad_blocks} bad blocks ({100 * num_bad_blocks / x.shape[0]:.6f}%)"
    )
    bad_block_id = np.argmax(diff_bits, axis=0)
    logger.debug(f"Worst block id: {bad_block_id}")
    logger.debug(
        f"Sample bad block ({diff_bits[bad_block_id]} differing bits):\n{t1[bad_block_id]}\nReference:\n{t2[bad_block_id]}"
    )
    sum_diff_bits = np.sum(diff_bits)
    logger.debug(
        f"{sum_diff_bits} bits differ ({100 * sum_diff_bits / (x.size * 8):.6f}%)"
    )
    return False
def _probe_support(func, sample: np.ndarray, qtype: GGMLQuantizationType) -> bool:
    """Return whether ``func(sample, qtype)`` is implemented for ``qtype``.

    A NotImplementedError means "not supported"; an AssertionError is logged
    and re-raised.
    """
    try:
        func(sample, qtype)
    except NotImplementedError:
        return False
    except AssertionError as e:
        logger.error(f"Error with {qtype.name}: {e}")
        raise
    return True


def _report_match(matches: bool, subject: str) -> bool:
    """Log the outcome of one comparison and hand ``matches`` back."""
    if matches:
        logger.info(f"{subject} matches exactly ✅")
    else:
        logger.error(f"{subject} does not match ❌")
    return matches


def do_test(libggml_path: Path, quick: bool = False) -> bool:
    """Compare the Python (de)quantization with the C one for every type.

    For F16 and each type in ``gguf.quants._type_traits`` that the Python
    package implements, random float32 data is quantized with both
    implementations, and both C-quantized data and random bytes are
    dequantized with both implementations. Every comparison is logged.

    Args:
        libggml_path: path of the libggml shared library to test against.
        quick: when True, skip the C quantization that would only be needed
            to produce input for the first dequantization comparison.

    Returns:
        True when every comparison matched, False otherwise.

    Raises:
        AssertionError: if probing the Python implementation for a type
            trips an assertion.
    """
    ggml_quants = GGMLQuants(libggml_path)

    # Print bytes as hex and allow up to 4 * 256 elements before NumPy
    # starts summarizing arrays in the debug output.
    np.set_printoptions(
        precision=None,
        threshold=(4 * 256) + 1,
        formatter={"int": lambda n: "0x%02X" % n},
    )

    r = np.random.randn(8, 1024, 1024).astype(np.float32, copy=False)
    all_match = True

    for qtype in (GGMLQuantizationType.F16, *gguf.quants._type_traits.keys()):
        # Probe with a single block to find out what Python implements.
        has_dequantize = _probe_support(
            gguf.dequantize,
            np.zeros((gguf.GGML_QUANT_SIZES[qtype][1]), dtype=np.uint8),
            qtype,
        )
        has_quantize = _probe_support(
            gguf.quantize,
            np.zeros((gguf.GGML_QUANT_SIZES[qtype][0]), dtype=np.float32),
            qtype,
        )
        if not has_dequantize and not has_quantize:
            continue

        logger.info(f"Testing {qtype.name}")

        rc = r.copy(order="C")
        ggq = None

        if has_quantize:
            logger.debug(f"Quantizing to {qtype.name} with Python")
            pyq = gguf.quants.quantize(rc, qtype)

            logger.debug(f"Quantizing to {qtype.name} with C")
            ggq = ggml_quants.quantize(rc, qtype)

            if qtype == GGMLQuantizationType.F16:
                # Compare as raw bytes, like the C result.
                pyq = pyq.view(np.uint8)
            if not _report_match(
                compare_tensors(pyq, ggq, qtype), f"Quantization to {qtype.name}"
            ):
                all_match = False

        if has_dequantize:
            if ggq is None and not quick:
                logger.debug(f"Quantizing to {qtype.name} with C")
                ggq = ggml_quants.quantize(rc, qtype)

            if ggq is not None:
                logger.debug(f"Dequantizing from {qtype.name} with Python")
                pydq = gguf.quants.dequantize(ggq, qtype)
                logger.debug(f"Dequantizing from {qtype.name} with C")
                ggdq = ggml_quants.dequantize(ggq, qtype)

                if not _report_match(
                    compare_tensors(pydq, ggdq, qtype),
                    f"Dequantization from {qtype.name}",
                ):
                    all_match = False

            # Random float16 bit patterns reinterpreted as quantized bytes.
            # The byte shape is computed for half the row length because the
            # uint8 view of float16 data doubles the last axis.
            rq_shape = gguf.quants.quant_shape_to_byte_shape(
                (8, 1024, 1024 // 2), qtype
            )
            rq = np.random.random(rq_shape).astype(np.float16).view(np.uint8)

            logger.debug(f"Dequantizing random f16 data as {qtype.name} with Python")
            pydq = gguf.quants.dequantize(rq, qtype)
            logger.debug(f"Dequantizing random f16 data as {qtype.name} with C")
            ggdq = ggml_quants.dequantize(rq, qtype)

            if not _report_match(
                compare_tensors(pydq, ggdq, qtype),
                f"Dequantization from random f16 data as {qtype.name}",
            ):
                all_match = False

    return all_match
if __name__ == "__main__":
    # Command-line entry point: parse the options, enable debug logging and
    # run the comparison.
    parser = argparse.ArgumentParser(
        description="Test Python (de)quantization against the reference C implementation"
    )
    parser.add_argument(
        "--libggml",
        type=Path,
        # Default: build/ggml/src/libggml.so, three directory levels above
        # this file.
        default=Path(__file__).parent.parent.parent.joinpath(
            "build", "ggml", "src", "libggml.so"
        ),
        help="The path to libggml.so",
    )
    parser.add_argument(
        "--quick",
        action="store_true",
        help="Don't quantize with C when it's not strictly necessary",
    )
    args = parser.parse_args()

    logging.basicConfig(level=logging.DEBUG)
    do_test(libggml_path=args.libggml, quick=args.quick)