from dataclasses import dataclass, fields
from typing import Type
import torch
from triton.tools.tensor_descriptor import TensorDescriptor
from .reduction_details.reduce_bitmatrix import clear_sums, sum_bitmatrix_rows
from .target_info import cuda_capability_geq
from .tensor_details.layout import Layout, StridedLayout
@dataclass
class Storage:
    """A torch tensor paired with the `Layout` that describes how its logical
    elements are arranged in physical memory."""

    # Backing tensor holding the raw (possibly swizzled/packed) bytes.
    data: torch.Tensor
    # Physical layout; filled with a plain StridedLayout in __post_init__ when omitted.
    layout: Layout | None = None

    def __post_init__(self):
        assert isinstance(self.data, torch.Tensor)
        if self.layout is None:
            # default: elements laid out exactly as the tensor's strides dictate
            self.layout = StridedLayout(self.data.shape)

    @property
    def device(self):
        """Device of the backing tensor."""
        return self.data.device

    def is_tma_compliant(self):
        """Return True if `data` can be addressed via a CUDA TMA descriptor.

        Checks hardware support (Hopper, i.e. sm90+), rank (2D/3D/5D only),
        and stride alignment of every non-contiguous dimension.
        """
        # TMAs didn't exist until Hopper
        if not cuda_capability_geq(9, 0):
            return False
        # TMAs only exist for 2D, 3D, 5D inputs
        if len(self.data.shape) not in [2, 3, 5]:
            return False
        # TMAs need at most one stride equal to 1
        # and all other strides divisble by 16
        strides = list(self.data.stride())
        try:
            major_dim = strides.index(1)
        except ValueError:
            # no unit-stride dimension; every dimension must pass the alignment check
            major_dim = -1
        ndim = self.data.ndim
        # uint8 storage is treated as packed 4-bit data here (two values per byte)
        # — TODO confirm this assumption holds for all uint8 uses
        bitwidth = 4 if self.data.dtype == torch.uint8 else self.data.element_size() * 8
        # every stride except the contiguous one must span a multiple of 128 bits (16 bytes)
        compliant = [strides[i] * bitwidth % 128 == 0 for i in range(ndim) if i != major_dim]
        return all(compliant)

    def make_tma(self, block_shape, transpose=False):
        """Build a `TensorDescriptor` (TMA) over `data` for the given block shape.

        NOTE(review): the `transpose` parameter is ignored — it is unconditionally
        overwritten below from the last stride (see TODO); confirm no caller
        relies on passing it explicitly.
        """
        strides = list(self.data.stride())
        shape = list(self.data.shape)
        # TODO
        # there is an issue w/ column-major TMA; we transpose instead
        transpose = self.data.stride()[-1] != 1
        if transpose:
            # swap the two innermost dims of block shape, shape and strides
            block_shape = block_shape[:-2] + [block_shape[-1], block_shape[-2]]
            shape = shape[:-2] + [shape[-1], shape[-2]]
            strides = strides[:-2] + [strides[-1], strides[-2]]
        if self.data.dtype == torch.uint8 and self.layout.name is None:
            # physical block size is half logical block size along packed dimension
            indx = strides.index(1)
            block_shape[indx] = block_shape[indx] // 2
            # Pad the inner shape to 128 for mxfp4 weights; TMA requires this when the compiler uses
            # CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B.
            pad = 128
            shape[-1] = (shape[-1] + pad - 1) // pad * pad
        block_shape = self.layout.swizzle_block_shape(block_shape)
        return TensorDescriptor(self.data, shape, strides, block_shape)
@dataclass
class IntegerType:
    """Integer element type of an arbitrary bit width (e.g. 1 for `BIT`)."""
    # number of bits occupied by one element
    bitwidth: int
@dataclass
class FloatType:
    """Floating-point element type described by its exponent/mantissa widths
    and whether a sign bit is present; total `bitwidth` is derived."""
    # bits in the exponent field
    bitwidth_exponent: int
    # bits in the mantissa field
    bitwidth_mantissa: int
    # whether the format carries a sign bit
    is_signed: bool

    def __post_init__(self):
        # total width = optional sign bit + exponent bits + mantissa bits
        sign_bits = 1 if self.is_signed else 0
        self.bitwidth = sign_bits + self.bitwidth_exponent + self.bitwidth_mantissa
# 1-bit integer element type, used for packed boolean matrices (Bitmatrix).
BIT = IntegerType(1)
# 4-bit float: sign + 2 exponent bits + 1 mantissa bit (e2m1 / mxfp4 element).
FP4 = FloatType(bitwidth_exponent=2, bitwidth_mantissa=1, is_signed=True)
def bitwidth(type: IntegerType | FloatType | torch.dtype):
    """Return the number of bits occupied by one element of `type`."""
    # our own element types carry the bit count directly; torch dtypes
    # only expose a byte size
    if not isinstance(type, torch.dtype):
        return type.bitwidth
    return type.itemsize * 8
@dataclass
class Tensor:
    """A logical tensor: a `Storage` plus a logical dtype and shape that may
    differ from the storage's physical ones (e.g. sub-byte dtypes pack several
    logical elements per storage element; dynamic dims are numel-1 tensors
    bounded by `shape_max`).
    """

    # underlying data; a bare torch.Tensor is wrapped into a Storage automatically
    storage: Storage | torch.Tensor
    # logical element type; defaults to the storage tensor's dtype
    dtype: IntegerType | FloatType | torch.dtype | None = None
    # logical shape; entries are `int` or numel-1 tensors (dynamic dims)
    shape: list[int] | None = None
    # static upper bounds per dim; required wherever `shape` is dynamic
    shape_max: list[int] | None = None

    def __post_init__(self):
        """Normalize storage/dtype/shape and validate `shape`/`shape_max`.

        Raises:
            ValueError: if `shape` is omitted for a sub-byte dtype, or a
                provided `shape_max` entry is neither `int` nor `None`.
        """
        # set storage
        if isinstance(self.storage, torch.Tensor):
            self.storage = Storage(self.storage)
        # initialize dtype
        if self.dtype is None:
            self.dtype = self.storage.data.dtype
        # sub-byte dtypes pack several logical elements per storage element,
        # so the logical shape cannot be inferred from the storage shape
        if bitwidth(self.dtype) < 8 and self.shape is None:
            raise ValueError("shape must be provided for sub-byte types")
        # initialize shape
        if self.shape is None:
            self.shape = list(self.storage.data.shape)

        # validate shape: all elements must be `int` or numel-1 `torch.Tensor`
        def _is_int(s):
            return isinstance(s, int)

        def _is_item(s):
            return hasattr(s, "numel") and s.numel() == 1

        assert all(_is_int(s) or _is_item(s) for s in self.shape)
        # initialize shape_max
        if self.shape_max is None:
            self.shape_max = [None] * len(self.shape)
        for i, (s, smax) in enumerate(zip(self.shape, self.shape_max)):
            if smax is not None and not _is_int(smax):
                raise ValueError(f"shape_max[{i}] must be `int` or `None`; got {type(smax)}")
            if smax is None:
                # fall back to the static shape entry; a dynamic (tensor) dim
                # with no explicit bound will fail the assert below on purpose
                self.shape_max[i] = s
        # validate shape_max: all elements must be `int`
        assert all(_is_int(s) for s in self.shape_max)

    # torch compatibility layer
    @property
    def ndim(self):
        """Number of logical dimensions."""
        return len(self.shape)

    @property
    def device(self):
        """Device of the underlying storage."""
        return self.storage.device

    def stride(self, i=None):
        """Physical strides of the storage tensor (all of them, or dim `i`)."""
        return self.storage.data.stride() if i is None else self.storage.data.stride(i)

    def data_ptr(self):
        """Raw device pointer of the storage tensor."""
        return self.storage.data.data_ptr()

    def numel(self):
        """Number of *storage* elements (not logical elements)."""
        return self.storage.data.numel()

    def element_size(self):
        """Size in bytes of one logical element (0 for sub-byte dtypes)."""
        return bitwidth(self.dtype) // 8

    @property
    def data(self):
        """The backing torch.Tensor."""
        t = self.storage
        return t.data if isinstance(t, Storage) else t

    def dim(self):
        """Alias of `ndim` (torch compatibility)."""
        return self.ndim

    def size(self, i=None):
        """Logical shape, or the size of dim `i`, torch-style."""
        if i is None:
            return self.shape
        return self.shape[i]
@dataclass
class Bitmatrix(Tensor):
    """A boolean matrix stored in packed form, one bit per element.

    `scratchpad` is either None or an all-zero buffer of size >= shape[-1];
    carrying it alongside the bitmatrix lets `sum()` skip launching a
    separate memset kernel.
    """
    scratchpad: torch.Tensor = None

    def __init__(self, storage, shape, shape_max=None, scratchpad=None):
        super().__init__(storage, dtype=BIT, shape=shape, shape_max=shape_max)
        self.scratchpad = scratchpad

    def sum(self, partials_block_size):
        """Return per-column sums of the bitmatrix; single use only."""
        _, n_cols = self.shape
        device = self.device
        if self.scratchpad is None:
            self.scratchpad = clear_sums(n_cols, device)
        out = self.scratchpad[:n_cols]
        # consume the scratchpad so a second sum() fails loudly
        self.scratchpad = None
        return sum_bitmatrix_rows(self, out, partials_block_size)
def get_layout(tensor: torch.Tensor | Tensor | None):
    """Return the layout associated with `tensor`.

    Returns None for a None input and the storage's layout for a `Tensor`.
    NOTE(review): for a bare torch.Tensor this returns the `StridedLayout`
    *class*, not an instance — unlike the `Tensor` branch, which returns a
    layout instance; confirm callers only compare/identify the result rather
    than calling instance methods on it.
    """
    if tensor is None:
        return None
    if isinstance(tensor, Tensor):
        return tensor.storage.layout
    return StridedLayout
def wrap_torch_tensor(torch_tensor, dtype=None):
    """Wrap a torch.Tensor into a `Tensor`, optionally reinterpreting its
    element type; the contiguous dimension is scaled by the packing ratio
    (e.g. uint8 storage viewed as FP4 doubles that dimension)."""
    if dtype is None:
        dtype = torch_tensor.dtype
    logical_shape = list(torch_tensor.shape)
    # the packed dimension is the contiguous (unit-stride) one
    packed_dim = torch_tensor.stride().index(1)
    ratio = bitwidth(torch_tensor.dtype) // bitwidth(dtype)
    logical_shape[packed_dim] *= ratio
    return Tensor(Storage(torch_tensor), dtype=dtype, shape=logical_shape)
def convert_layout(tensor: Tensor, layout_cls: Type[Layout], **layout_kwargs):
    """Re-layout `tensor`'s storage: unswizzle with the current layout, then
    swizzle with a freshly constructed `layout_cls`; every other Tensor
    attribute is carried over unchanged."""
    assert isinstance(tensor, Tensor)
    storage = tensor.storage
    plain_data = storage.layout.unswizzle_data(storage.data)
    new_layout = layout_cls(plain_data.shape, **layout_kwargs)
    swizzled = new_layout.swizzle_data(plain_data)
    carried = {f.name: getattr(tensor, f.name) for f in fields(tensor) if f.name != "storage"}
    return Tensor(Storage(swizzled, new_layout), **carried)
|