File size: 3,004 Bytes
fcca8c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""Tests for quantization utility functions."""

# pylint: disable=invalid-name

import torch
from bitsandbytes.functional import QuantState

from axolotl.kernels.quantize import dequantize


def test_dequantize_null_state():
    """Test that dequantize returns input unchanged when quant_state is None"""
    W = torch.randn(32, 32)
    assert torch.equal(dequantize(W, None), W)


def test_dequantize_shape_preservation():
    """Test that dequantization preserves expected shapes"""
    shape = (32, 32)
    W = torch.randn(shape, device="cuda")

    quant_state = QuantState(
        absmax=torch.ones(shape[0], device="cuda"),
        shape=shape,
        code=torch.randint(0, 15, shape, device="cuda"),
        dtype=torch.float16,
        blocksize=32,
        quant_type="nf4",
        offset=torch.zeros(shape[0], dtype=torch.int32, device="cuda"),
        state2=QuantState(
            absmax=torch.ones(shape[0], device="cuda"),
            shape=shape,
            code=torch.randint(0, 15, shape, device="cuda"),
            dtype=torch.float16,
            blocksize=32,
            quant_type="nf4",
            offset=None,
            state2=None,
        ),
    )

    result = dequantize(W, quant_state)
    assert result.shape == shape
    assert result.dtype == torch.float16
    assert result.device == W.device


def test_dequantize_transposed():
    """Test that transposed input produces transposed output"""
    shape = (32, 32)
    W = torch.randn(1, shape[1], device="cuda")  # Transposed input

    quant_state = QuantState(
        absmax=torch.ones(1),
        shape=shape,
        code=torch.randint(0, 15, shape),
        dtype=torch.float16,
        blocksize=32,
        quant_type="nf4",
        offset=torch.zeros(1, dtype=torch.int32),
        state2=QuantState(
            absmax=torch.ones(1),
            shape=shape,
            code=torch.randint(0, 15, shape),
            dtype=torch.float16,
            blocksize=32,
            quant_type="nf4",
            offset=None,
            state2=None,
        ),
    )

    result = dequantize(W, quant_state)
    assert result.shape[0] == shape[0]


def test_dequantize_output_tensor():
    """Test dequantization with provided output tensor"""
    shape = (32, 32)
    W = torch.randn(shape, device="cuda")
    out = torch.empty(shape, dtype=torch.float16, device="cuda")

    quant_state = QuantState(
        absmax=torch.ones(shape[0]),
        shape=shape,
        code=torch.randint(0, 15, shape),
        dtype=torch.float16,
        blocksize=32,
        quant_type="nf4",
        offset=torch.zeros(shape[0], dtype=torch.int32),
        state2=QuantState(
            absmax=torch.ones(shape[0]),
            shape=shape,
            code=torch.randint(0, 15, shape),
            dtype=torch.float16,
            blocksize=32,
            quant_type="nf4",
            offset=None,
            state2=None,
        ),
    )

    result = dequantize(W, quant_state, out=out)
    assert result is out