File size: 7,219 Bytes
346e086
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
from dataclasses import dataclass, fields
from typing import Type

import torch
from triton.tools.tensor_descriptor import TensorDescriptor

from .reduction_details.reduce_bitmatrix import clear_sums, sum_bitmatrix_rows
from .target_info import cuda_capability_geq
from .tensor_details.layout import Layout, StridedLayout


@dataclass
class Storage:
    """
    Raw tensor buffer paired with the layout that describes how it is arranged.

    `data` is the backing `torch.Tensor`. `layout` defaults to a
    `StridedLayout` built from `data.shape` when it is left as None.
    """

    data: torch.Tensor
    layout: Layout = None

    def __post_init__(self):
        assert isinstance(self.data, torch.Tensor)
        # fall back to a plain strided layout when none was given
        if self.layout is None:
            self.layout = StridedLayout(self.data.shape)

    @property
    def device(self):
        """Device of the backing tensor."""
        return self.data.device

    def is_tma_compliant(self):
        """
        Report whether the backing tensor passes the TMA checks below.

        Returns False when the CUDA capability is below 9.0, when the rank is
        not 2, 3 or 5, or when any stride other than the first unit stride is
        not a multiple of 128 bits; returns True otherwise.
        """
        # TMAs didn't exist until Hopper
        if not cuda_capability_geq(9, 0):
            return False
        # TMAs only exist for 2D, 3D, 5D inputs
        if len(self.data.shape) not in [2, 3, 5]:
            return False
        # TMAs need at most one stride equal to 1
        # and all other strides divisible by 16 bytes (128 bits)
        strides = list(self.data.stride())
        try:
            major_dim = strides.index(1)
        except ValueError:
            # no unit stride: every dimension has to pass the alignment check
            major_dim = -1
        # uint8 storage is counted as 4 bits per element
        # NOTE(review): presumably because uint8 buffers hold packed 4-bit values -- confirm
        bits = 4 if self.data.dtype == torch.uint8 else self.data.element_size() * 8
        return all(stride * bits % 128 == 0 for i, stride in enumerate(strides) if i != major_dim)

    def make_tma(self, block_shape, transpose=False):
        """
        Build a `TensorDescriptor` over the backing tensor.

        `block_shape` gives the block size per dimension; it is copied, so the
        caller's sequence is never modified. `transpose` is currently ignored:
        its value is recomputed from the strides (see the TODO below).
        """
        # Private copy: the uint8 branch below assigns into the list, which
        # used to halve an entry of the caller's own `block_shape`. Copying
        # also allows tuples to be passed.
        block_shape = list(block_shape)
        strides = list(self.data.stride())
        shape = list(self.data.shape)
        # TODO
        # there is an issue w/ column-major TMA; we transpose instead
        transpose = self.data.stride()[-1] != 1
        if transpose:
            # swap the two innermost dimensions everywhere
            block_shape = block_shape[:-2] + [block_shape[-1], block_shape[-2]]
            shape = shape[:-2] + [shape[-1], shape[-2]]
            strides = strides[:-2] + [strides[-1], strides[-2]]
        if self.data.dtype == torch.uint8 and self.layout.name is None:
            # physical block size is half logical block size along packed dimension
            indx = strides.index(1)
            block_shape[indx] = block_shape[indx] // 2
            # Pad the inner shape to 128 for mxfp4 weights; TMA requires this when the compiler uses
            # CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B.
            pad = 128
            shape[-1] = (shape[-1] + pad - 1) // pad * pad
        block_shape = self.layout.swizzle_block_shape(block_shape)
        return TensorDescriptor(self.data, shape, strides, block_shape)


@dataclass
class IntegerType:
    """Describes an integer element type by its width in bits."""

    # number of bits occupied by one element
    bitwidth: int


@dataclass
class FloatType:
    """
    Describes a floating-point element type by its bit budget.

    After construction, `bitwidth` holds the total width: the exponent and
    mantissa widths plus `int(is_signed)` for the sign.
    """

    bitwidth_exponent: int
    bitwidth_mantissa: int
    is_signed: bool

    def __post_init__(self):
        # `bitwidth` is a plain attribute, not a dataclass field
        sign_bits = int(self.is_signed)
        total = sign_bits + self.bitwidth_exponent
        total = total + self.bitwidth_mantissa
        self.bitwidth = total


# single-bit integer element type
BIT = IntegerType(bitwidth=1)
# signed float with 2 exponent bits and 1 mantissa bit
FP4 = FloatType(
    bitwidth_exponent=2,
    bitwidth_mantissa=1,
    is_signed=True,
)


def bitwidth(type: IntegerType | FloatType | torch.dtype):
    """
    Return the width in bits of one element of `type`.

    For a `torch.dtype` this is `itemsize * 8`; any other argument is expected
    to expose a `bitwidth` attribute, which is returned unchanged.
    """
    if not isinstance(type, torch.dtype):
        return type.bitwidth
    # torch reports element sizes in bytes
    return type.itemsize * 8


@dataclass
class Tensor:
    """
    Logical tensor: a `Storage` plus an element type and a logical shape.

    `storage` may be given as a raw `torch.Tensor`, in which case it is wrapped
    in a `Storage`. `dtype` defaults to the dtype of the storage buffer.
    `shape` defaults to the storage buffer's shape and must be given explicitly
    for element types narrower than 8 bits. Each `shape` entry is an `int` or
    an object whose `numel()` is 1. `shape_max` holds one `int` per dimension;
    entries left as None are filled from `shape`.
    """

    storage: Storage | torch.Tensor
    dtype: IntegerType | FloatType | torch.dtype = None
    shape: list[int] | None = None
    shape_max: list[int] | None = None

    def __post_init__(self):
        # set storage
        if isinstance(self.storage, torch.Tensor):
            self.storage = Storage(self.storage)
        # initialize dtype
        if self.dtype is None:
            self.dtype = self.storage.data.dtype
        if bitwidth(self.dtype) < 8 and self.shape is None:
            raise ValueError("shape must be provided for sub-byte types")
        # initialize shape
        if self.shape is None:
            self.shape = list(self.storage.data.shape)

        def is_int(s):
            return isinstance(s, int)

        def is_item(s):
            return hasattr(s, "numel") and s.numel() == 1

        # validate shape: all elements must be `int` or numel-1 `torch.Tensor`
        assert all(is_int(s) or is_item(s) for s in self.shape)
        # initialize shape_max
        if self.shape_max is None:
            self.shape_max = [None] * len(self.shape)
        else:
            # Private copy: the loop below writes into the list, which used to
            # overwrite the None entries of the caller's own `shape_max`.
            self.shape_max = list(self.shape_max)
        for i, (s, smax) in enumerate(zip(self.shape, self.shape_max)):
            if smax is not None and not is_int(smax):
                raise ValueError(f"shape_max[{i}] must be `int` or `None`; got {type(smax)}")
            if smax is None:
                self.shape_max[i] = s
        # validate shape_max: all elements must be `int`
        # (a non-int `shape` entry therefore needs an explicit `shape_max` entry)
        assert all(map(is_int, self.shape_max))

    # torch compatibility layer
    @property
    def ndim(self):
        """Number of logical dimensions (length of `shape`)."""
        return len(self.shape)

    @property
    def device(self):
        """Device of the underlying storage."""
        return self.storage.device

    def stride(self, i=None):
        """Strides of the storage buffer, or the stride of dimension `i` when given."""
        return self.storage.data.stride() if i is None else self.storage.data.stride(i)

    def data_ptr(self):
        """Data pointer of the storage buffer."""
        return self.storage.data.data_ptr()

    def numel(self):
        """Element count of the storage buffer (not derived from `shape`)."""
        return self.storage.data.numel()

    def element_size(self):
        """Element width in whole bytes; integer division makes this 0 for sub-byte types."""
        return bitwidth(self.dtype) // 8

    @property
    def data(self):
        """The raw buffer: `storage.data` for a `Storage`, otherwise `storage` itself."""
        t = self.storage
        return t.data if isinstance(t, Storage) else t

    def dim(self):
        """Same as `ndim`, as a method."""
        return self.ndim

    def size(self, i=None):
        """The logical `shape`, or its entry `i` when given."""
        if i is None:
            return self.shape
        return self.shape[i]


@dataclass
class Bitmatrix(Tensor):
    """
    Represents a boolean matrix in a packed format where each element occupies
    a single bit of memory.

    `scratchpad` is either None or an all-zero array of size >= shape[-1]; it
    travels with the bitmatrix so that `sum()` does not have to launch a
    separate memset kernel to obtain a zeroed output buffer.
    """

    scratchpad: torch.Tensor = None

    def __init__(self, storage, shape, shape_max=None, scratchpad=None):
        # the element type is always BIT; callers only choose shape and scratchpad
        super().__init__(storage, dtype=BIT, shape=shape, shape_max=shape_max)
        self.scratchpad = scratchpad

    def sum(self, partials_block_size):
        """
        Return `sum_bitmatrix_rows` applied to this matrix.

        The first `n_cols` entries of the scratchpad are passed as the output
        buffer; a buffer from `clear_sums` is used when no scratchpad is held.
        """
        _, num_cols = self.shape
        device = self.device
        buffer = clear_sums(num_cols, device) if self.scratchpad is None else self.scratchpad
        column_sums = buffer[:num_cols]
        # Drop the reference once the buffer has been handed out as output;
        # a later sum() call obtains a new buffer from clear_sums instead.
        self.scratchpad = None
        return sum_bitmatrix_rows(self, column_sums, partials_block_size)


def get_layout(tensor: torch.Tensor | Tensor | None):
    """
    Return the layout associated with `tensor`.

    A `Tensor` yields its storage's layout object and None yields None. Any
    other argument yields the `StridedLayout` class itself, not an instance.
    """
    if isinstance(tensor, Tensor):
        return tensor.storage.layout
    return None if tensor is None else StridedLayout


def wrap_torch_tensor(torch_tensor, dtype=None):
    """
    Wrap a `torch.Tensor` in a `Tensor`, optionally reinterpreting its elements.

    `dtype` is the logical element type and defaults to the tensor's own dtype.
    When it differs in width from the storage dtype, the size of the first
    unit-stride dimension is multiplied by
    `bitwidth(storage dtype) // bitwidth(dtype)`.

    Raises ValueError if scaling is needed and no dimension has stride 1.
    """
    if dtype is None:
        dtype = torch_tensor.dtype
    shape = list(torch_tensor.shape)
    # number of logical elements held by one storage element
    ratio = bitwidth(torch_tensor.dtype) // bitwidth(dtype)
    # Only look up the packed dimension when the shape actually changes, so
    # 0-dim tensors and views without a unit stride can still be wrapped.
    if ratio != 1:
        shape[torch_tensor.stride().index(1)] *= ratio
    return Tensor(Storage(torch_tensor), dtype=dtype, shape=shape)


def convert_layout(tensor: Tensor, layout_cls: Type[Layout], **layout_kwargs):
    """
    Return a new `Tensor` holding `tensor`'s data rearranged for `layout_cls`.

    The data is unswizzled with the current layout, then swizzled with a
    `layout_cls` instance built from the unswizzled shape and `layout_kwargs`.
    Every dataclass field of `tensor` other than `storage` is carried over.
    """
    assert isinstance(tensor, Tensor)
    source = tensor.storage
    plain = source.layout.unswizzle_data(source.data)
    target_layout = layout_cls(plain.shape, **layout_kwargs)
    swizzled = target_layout.swizzle_data(plain)
    # forward every field except the storage, which is rebuilt below
    carried = {}
    for field in fields(tensor):
        if field.name == "storage":
            continue
        carried[field.name] = getattr(tensor, field.name)
    return Tensor(Storage(swizzled, target_layout), **carried)