"""Tests for the ``dequantize`` quantization utility in ``axolotl.kernels.quantize``."""
|
|
| |
|
|
import unittest

import torch
from bitsandbytes.functional import QuantState

from axolotl.kernels.quantize import dequantize
|
|
|
|
def test_dequantize_null_state():
    """With no quant_state there is nothing to dequantize: the input must come back unchanged."""
    weight = torch.randn(32, 32)
    unchanged = dequantize(weight, None)
    assert torch.equal(unchanged, weight)
|
|
|
|
def test_dequantize_shape_preservation():
    """Test that dequantization preserves expected shapes.

    Builds a ``QuantState`` with a nested ``state2`` for a ``(32, 32)``
    weight, with every tensor placed on CUDA. It then checks that
    ``dequantize`` returns a result with the state's ``shape``, a ``float16``
    dtype (matching the ``dtype`` stored in the state), and the same device
    as ``W``.

    ``absmax`` is all ones and ``code`` holds random integers rather than a
    real quantization. Only result metadata (shape, dtype, device) is
    asserted, not values.

    Skipped when no CUDA device is available.
    """
    # Every tensor below is allocated with device="cuda". Report a skip
    # rather than an error on machines without CUDA.
    if not torch.cuda.is_available():
        raise unittest.SkipTest("requires a CUDA device")

    shape = (32, 32)
    W = torch.randn(shape, device="cuda")

    quant_state = QuantState(
        absmax=torch.ones(shape[0], device="cuda"),
        shape=shape,
        code=torch.randint(0, 15, shape, device="cuda"),
        dtype=torch.float16,
        blocksize=32,
        quant_type="nf4",
        offset=torch.zeros(shape[0], dtype=torch.int32, device="cuda"),
        state2=QuantState(
            absmax=torch.ones(shape[0], device="cuda"),
            shape=shape,
            code=torch.randint(0, 15, shape, device="cuda"),
            dtype=torch.float16,
            blocksize=32,
            quant_type="nf4",
            offset=None,
            state2=None,
        ),
    )

    result = dequantize(W, quant_state)
    assert result.shape == shape
    assert result.dtype == torch.float16
    assert result.device == W.device
|
|
|
|
def test_dequantize_transposed():
    """Test that transposed input produces transposed output.

    ``W`` is a single row of shape ``(1, shape[1])``, which this test treats
    as the transposed layout. The ``QuantState`` tensors are created without
    a ``device`` argument (the default device, normally CPU) while ``W`` is
    on CUDA.

    Checks that the result has the transposed shape ``(shape[1], shape[0])``
    and ``float16`` dtype, matching the ``dtype`` stored in the state.

    NOTE(review): ``shape`` is square, so the shape check cannot tell a
    transposed result apart from a non-transposed one. Pinning the
    transpose itself would need a non-square ``shape`` and a correspondingly
    sized ``W``.

    Skipped when no CUDA device is available.
    """
    # W is allocated with device="cuda". Report a skip rather than an error
    # on machines without CUDA.
    if not torch.cuda.is_available():
        raise unittest.SkipTest("requires a CUDA device")

    shape = (32, 32)
    W = torch.randn(1, shape[1], device="cuda")

    quant_state = QuantState(
        absmax=torch.ones(1),
        shape=shape,
        code=torch.randint(0, 15, shape),
        dtype=torch.float16,
        blocksize=32,
        quant_type="nf4",
        offset=torch.zeros(1, dtype=torch.int32),
        state2=QuantState(
            absmax=torch.ones(1),
            shape=shape,
            code=torch.randint(0, 15, shape),
            dtype=torch.float16,
            blocksize=32,
            quant_type="nf4",
            offset=None,
            state2=None,
        ),
    )

    result = dequantize(W, quant_state)
    # Assert the full transposed shape, not just the leading dimension.
    assert result.shape == (shape[1], shape[0])
    assert result.dtype == torch.float16
|
|
|
|
def test_dequantize_output_tensor():
    """Test dequantization with provided output tensor.

    Pre-allocates ``out`` on CUDA with the state's ``shape`` and ``float16``
    dtype. Checks that ``dequantize`` returns that very tensor object
    (identity, not mere equality) when it is passed as ``out``.

    The ``QuantState`` tensors are created without a ``device`` argument
    (the default device, normally CPU) while ``W`` and ``out`` are on CUDA.

    Skipped when no CUDA device is available.
    """
    # W and out are allocated with device="cuda". Report a skip rather than
    # an error on machines without CUDA.
    if not torch.cuda.is_available():
        raise unittest.SkipTest("requires a CUDA device")

    shape = (32, 32)
    W = torch.randn(shape, device="cuda")
    out = torch.empty(shape, dtype=torch.float16, device="cuda")

    quant_state = QuantState(
        absmax=torch.ones(shape[0]),
        shape=shape,
        code=torch.randint(0, 15, shape),
        dtype=torch.float16,
        blocksize=32,
        quant_type="nf4",
        offset=torch.zeros(shape[0], dtype=torch.int32),
        state2=QuantState(
            absmax=torch.ones(shape[0]),
            shape=shape,
            code=torch.randint(0, 15, shape),
            dtype=torch.float16,
            blocksize=32,
            quant_type="nf4",
            offset=None,
            state2=None,
        ),
    )

    result = dequantize(W, quant_state, out=out)
    # Identity check: the caller-provided buffer must be returned.
    assert result is out
|
|