"""Tests for quantization utility functions."""
# pylint: disable=invalid-name
import torch
from bitsandbytes.functional import QuantState
from axolotl.kernels.quantize import dequantize
def test_dequantize_null_state():
    """Check that dequantize is a pass-through when no quant_state is given."""
    weights = torch.randn(32, 32)

    # With quant_state=None the returned values must equal the input.
    passthrough = dequantize(weights, None)

    assert torch.equal(passthrough, weights)
def test_dequantize_shape_preservation():
    """Check the shape, dtype and device of a dequantized result."""
    device = "cuda"
    dims = (32, 32)
    n_rows = dims[0]

    W = torch.randn(dims, device=device)

    # Fields of the outer state; the nested state is attached afterwards so
    # tensors are still created in the same order as they are consumed.
    outer_fields = {
        "absmax": torch.ones(n_rows, device=device),
        "shape": dims,
        "code": torch.randint(0, 15, dims, device=device),
        "dtype": torch.float16,
        "blocksize": 32,
        "quant_type": "nf4",
        "offset": torch.zeros(n_rows, dtype=torch.int32, device=device),
    }
    nested_state = QuantState(
        absmax=torch.ones(n_rows, device=device),
        shape=dims,
        code=torch.randint(0, 15, dims, device=device),
        dtype=torch.float16,
        blocksize=32,
        quant_type="nf4",
        offset=None,
        state2=None,
    )
    quant_state = QuantState(state2=nested_state, **outer_fields)

    result = dequantize(W, quant_state)

    # The output must follow the quant state's shape/dtype and stay on W's device.
    assert result.shape == dims
    assert result.dtype == torch.float16
    assert result.device == W.device
def test_dequantize_transposed():
    """Test that transposed input produces transposed output.

    The input ``W`` has a leading dimension of 1, which is the layout this
    test treats as "transposed". All quantization-state tensors are created
    on the same CUDA device as ``W`` so that ``dequantize`` never has to mix
    CPU and CUDA tensors.
    """
    device = "cuda"
    shape = (32, 32)
    W = torch.randn(1, shape[1], device=device)  # Transposed input

    quant_state = QuantState(
        absmax=torch.ones(1, device=device),
        shape=shape,
        code=torch.randint(0, 15, shape, device=device),
        dtype=torch.float16,
        blocksize=32,
        quant_type="nf4",
        offset=torch.zeros(1, dtype=torch.int32, device=device),
        state2=QuantState(
            absmax=torch.ones(1, device=device),
            shape=shape,
            code=torch.randint(0, 15, shape, device=device),
            dtype=torch.float16,
            blocksize=32,
            quant_type="nf4",
            offset=None,
            state2=None,
        ),
    )

    result = dequantize(W, quant_state)

    # NOTE(review): `shape` is square, so this check cannot tell a transposed
    # result from an untransposed one -- consider a non-square shape once the
    # transposed-output contract of dequantize is confirmed.
    assert result.shape[0] == shape[0]
def test_dequantize_output_tensor():
    """Test dequantization with provided output tensor.

    A pre-allocated ``out`` buffer is passed to ``dequantize`` and the test
    checks that the very same tensor object is returned. The weight, the
    output buffer and every quantization-state tensor are created on the same
    CUDA device so the call never mixes CPU and CUDA tensors.
    """
    device = "cuda"
    shape = (32, 32)
    W = torch.randn(shape, device=device)
    out = torch.empty(shape, dtype=torch.float16, device=device)

    quant_state = QuantState(
        absmax=torch.ones(shape[0], device=device),
        shape=shape,
        code=torch.randint(0, 15, shape, device=device),
        dtype=torch.float16,
        blocksize=32,
        quant_type="nf4",
        offset=torch.zeros(shape[0], dtype=torch.int32, device=device),
        state2=QuantState(
            absmax=torch.ones(shape[0], device=device),
            shape=shape,
            code=torch.randint(0, 15, shape, device=device),
            dtype=torch.float16,
            blocksize=32,
            quant_type="nf4",
            offset=None,
            state2=None,
        ),
    )

    result = dequantize(W, quant_state, out=out)

    # Identity (not just equality): the caller's buffer must be reused.
    assert result is out
|