"""
Test data generators for DeepGEMM kernel tests.
Adapted from the original DeepGEMM test suite to work with the ported
kernels-community package.
"""
import enum
import random
import torch
from typing import Generator, List, Optional, Tuple
from deep_gemm.testing import get_arch_major
from deep_gemm.utils import (
align, ceil_div,
per_token_cast_to_fp8, per_channel_cast_to_fp8, per_block_cast_to_fp8,
per_token_cast_to_fp4, transpose_packed_fp4,
get_mk_alignment_for_contiguous_layout
)
class KernelType(enum.Enum):
    """Scale-factor (SF) layout variant a kernel expects."""
    Kernel1D1D = 0
    Kernel1D2D = 1
    KernelNoSF = 2

    def is_1d1d(self):
        """True for the 1D1D scale-factor layout."""
        return self is KernelType.Kernel1D1D

    def is_1d2d(self):
        """True for the 1D2D scale-factor layout."""
        return self is KernelType.Kernel1D2D

    def is_nosf(self):
        """True when no scale factors are used (plain BF16 path)."""
        return self is KernelType.KernelNoSF
class MajorTypeAB(enum.Enum):
    """Memory layout of a GEMM operand: contiguous along K, or along M/N."""
    KMajor = 0
    MNMajor = 1

    def is_k_major(self):
        """True when the operand is K-major."""
        return self is MajorTypeAB.KMajor

    def is_mn_major(self):
        """True when the operand is MN-major."""
        return self is MajorTypeAB.MNMajor
class QuantConfig:
    """Quantization settings packed as (gran_k_a, gran_k_b, is_fp4_a, is_fp4_b).

    `gran_k_*` is the K-granularity of the scale factors for each operand;
    `is_fp4_*` selects FP4 (otherwise FP8) for each operand.
    """
    # Default/legacy recipe: 128x128 FP8 blocks for both operands.
    _legacy_quant_config = (128, 128, False, False)

    def __init__(self, value: Tuple[int, int, bool, bool] = _legacy_quant_config):
        self.gran_k_a, self.gran_k_b, self.is_fp4_a, self.is_fp4_b = value

    def is_legacy(self) -> bool:
        """True when this config equals the legacy 128/128 FP8 recipe."""
        current = (self.gran_k_a, self.gran_k_b, self.is_fp4_a, self.is_fp4_b)
        return current == self._legacy_quant_config

    def get_recipes(self, is_wgrad: bool = False) -> Tuple[Tuple, Tuple, Tuple]:
        """Return (recipe, recipe_a, recipe_b); unused slots are None.

        The legacy config uses a single shared recipe (only set for wgrad);
        otherwise per-operand recipes are produced.
        """
        if self.is_legacy():
            shared = (1, 1, 128) if is_wgrad else None
            return shared, None, None
        recipe_a = (1, self.gran_k_a)
        if self.is_fp4_b or is_wgrad:
            recipe_b = (1, self.gran_k_b)
        else:
            recipe_b = (self.gran_k_b, self.gran_k_b)
        return None, recipe_a, recipe_b

    def max_diff(self) -> float:
        """Tolerance for result comparison; looser as more operands are FP4."""
        num_fp4_operands = int(self.is_fp4_a) + int(self.is_fp4_b)
        return (0.001, 0.01, 0.02)[num_fp4_operands]

    @staticmethod
    def get_list_from_dtype(dtype: torch.dtype) -> List:
        """Configs to sweep for a dtype: [None] for BF16, else FP8 (+ FP4-B on SM100)."""
        if dtype == torch.bfloat16:
            return [None]
        configs = [QuantConfig()]
        if get_arch_major() == 10:
            configs.append(QuantConfig((128, 32, False, True)))
        return configs
def reset_seed(seed: int = 0):
    """Seed Python's, torch's CPU, and torch's CUDA RNGs for reproducibility."""
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
def get_ue8m0_usage(kernel_type: KernelType) -> bool:
    """UE8M0 scale factors are only used off SM90 and only for 1D1D kernels."""
    return get_arch_major() != 9 and kernel_type.is_1d1d()
def get_kernel_types(dtype: torch.dtype) -> tuple:
    """Select the kernel type(s) to test for the given output dtype and arch."""
    if dtype == torch.bfloat16:
        return (KernelType.KernelNoSF, )
    if get_arch_major() == 9:
        return (KernelType.Kernel1D2D, )
    return (KernelType.Kernel1D1D, )
def get_major_ab(allow_a_mn_major: bool, allow_b_mn_major: bool) -> Generator:
    """Yield every permitted (major_a, major_b) layout pair, K-major first."""
    choices_a = (MajorTypeAB.KMajor, MajorTypeAB.MNMajor) if allow_a_mn_major else (MajorTypeAB.KMajor, )
    choices_b = (MajorTypeAB.KMajor, MajorTypeAB.MNMajor) if allow_b_mn_major else (MajorTypeAB.KMajor, )
    for major_a in choices_a:
        for major_b in choices_b:
            yield major_a, major_b
def get_psum_layout_usage() -> tuple:
    """PSUM layout is exercised (in addition to the default) only on SM100."""
    if get_arch_major() == 10:
        return (False, True)
    return (False, )
def cast_fp8_fp4_with_major(x: torch.Tensor, major: MajorTypeAB, gran_k: int, is_fp4: bool,
                            use_ue8m0: bool, use_block_cast_for_fp8: bool = False):
    """Quantize a 2D tensor to FP8 or FP4 and return (data, scale_factors).

    K-major results are returned as produced by the cast helpers; MN-major
    results are re-laid-out while keeping the original logical shape.
    """
    if is_fp4:
        packed, sf = per_token_cast_to_fp4(x, use_ue8m0=use_ue8m0, gran_k=gran_k)
        if major.is_k_major():
            return packed, sf
        # MN-major: repack nibbles transposed, then view back to (mn, k//2).
        return transpose_packed_fp4(packed).T, sf
    if use_block_cast_for_fp8:
        data, sf = per_block_cast_to_fp8(x, use_ue8m0=use_ue8m0, gran_k=gran_k)
    else:
        data, sf = per_token_cast_to_fp8(x, use_ue8m0=use_ue8m0, gran_k=gran_k)
    if major.is_k_major():
        return data, sf
    # `.T.contiguous().T` materializes a column-major copy, shape unchanged.
    return data.T.contiguous().T, sf
def grouped_cast_fp8_fp4_with_major(x: torch.Tensor, major: MajorTypeAB, gran_k: int, is_fp4: bool,
                                    use_ue8m0: bool, use_block_cast_for_fp8: bool = False):
    """Quantize a batched (num_groups, mn, k) tensor to FP8 or FP4, one group at a time.

    Returns a (data, scale_factors) tuple. For MN-major, the data tensor is
    returned as a transposed (non-contiguous) view whose strides match the
    MN-major layout while the logical shape stays (num_groups, mn, k)-like.
    """
    num_groups, mn, k = x.size()
    if is_fp4:
        # FP4 packs two values per uint8 byte, hence the `// 2` on the minor dim.
        # K-major stores (mn, k//2) per group; MN-major pre-allocates (k, mn//2).
        x_fp4 = (torch.empty((num_groups, mn, k // 2), device='cuda', dtype=torch.uint8) if major.is_k_major() else
                 torch.empty((num_groups, k, mn // 2), device='cuda', dtype=torch.uint8),
                 torch.empty((num_groups, mn, ceil_div(k, gran_k)), device='cuda', dtype=torch.float))
        for i in range(num_groups):
            x_i_fp4 = per_token_cast_to_fp4(x[i], use_ue8m0=use_ue8m0, gran_k=gran_k)
            x_fp4[0][i], x_fp4[1][i] = x_i_fp4 if major.is_k_major() else (transpose_packed_fp4(x_i_fp4[0]), x_i_fp4[1])
        # `.mT` swaps the last two dims as a view (no copy) for the MN-major case.
        x = x_fp4 if major.is_k_major() else (x_fp4[0].mT, x_fp4[1])
    else:
        # Scale-factor shape differs by cast: per-block scales tile grids of
        # (mn/gran_k, k/gran_k); per-token scales rows as (mn, k/gran_k).
        x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn),
                 torch.empty((num_groups, ceil_div(mn, gran_k), ceil_div(k, gran_k)), device='cuda', dtype=torch.float) if use_block_cast_for_fp8
                 else torch.empty((num_groups, mn, ceil_div(k, gran_k)), device='cuda', dtype=torch.float))
        for i in range(num_groups):
            x_fp8[0][i], x_fp8[1][i] = per_block_cast_to_fp8(x[i], use_ue8m0=use_ue8m0, gran_k=gran_k) if use_block_cast_for_fp8 \
                else per_token_cast_to_fp8(x[i], use_ue8m0=use_ue8m0, gran_k=gran_k)
        # `.mT.contiguous().mT`: same logical shape, but MN-major strides.
        x = x_fp8 if major.is_k_major() else (x_fp8[0].mT.contiguous().mT, x_fp8[1])
    return x
def generate_normal(m: int, n: int, k: int,
                    major_a: MajorTypeAB, major_b: MajorTypeAB,
                    accumulate: bool, out_dtype: torch.dtype,
                    kernel_type: KernelType,
                    use_ue8m0: bool = False, use_bf16: bool = False,
                    quant_config: Optional[QuantConfig] = None):
    """Generate operands and a float32 reference for a plain (dense) GEMM test.

    Returns (a, b, c, d, ref_d): quantized (or BF16) operands a/b, optional
    accumulator c (aliases d when `accumulate` is set, else None), output
    buffer d, and ref_d = a @ b.T (+ c) computed via float32.
    """
    a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
    b = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
    # Accumulate mode pre-fills d with randn * 32 (presumably to make the
    # accumulated values dominate quantization noise — TODO confirm);
    # otherwise d is only an uninitialized output buffer.
    d = torch.randn((m, n), device='cuda', dtype=out_dtype) * 32 if accumulate else \
        torch.empty((m, n), device='cuda', dtype=out_dtype)
    c = d if accumulate else None
    # Reference is computed before quantizing a/b, so it measures quant error.
    ref_d = (a.float() @ b.float().t() + (c if accumulate else 0)).to(out_dtype)
    if use_bf16:
        # `.T.contiguous().T` materializes an MN-major copy, shape unchanged.
        a = a if major_a.is_k_major() else a.T.contiguous().T
        b = b if major_b.is_k_major() else b.T.contiguous().T
        return a, b, c, d, ref_d
    quant_config = QuantConfig() if quant_config is None else quant_config
    a = cast_fp8_fp4_with_major(a, major_a, quant_config.gran_k_a, quant_config.is_fp4_a, use_ue8m0)
    # B uses the block cast except for 1D1D kernels with accumulation.
    b = cast_fp8_fp4_with_major(b, major_b, quant_config.gran_k_b, quant_config.is_fp4_b, use_ue8m0,
                                use_block_cast_for_fp8=not (kernel_type.is_1d1d() and accumulate))
    return a, b, c, d, ref_d
def generate_m_grouped_contiguous(num_groups: int, expected_m_per_group: int, n: int, k: int,
                                  major_a: MajorTypeAB, major_b: MajorTypeAB,
                                  use_ue8m0: bool = False, use_bf16: bool = False,
                                  use_psum_layout: bool = False,
                                  quant_config: Optional[QuantConfig] = None):
    """Generate a contiguous m-grouped GEMM test case.

    Each group's row count is drawn from 70%-130% of `expected_m_per_group`
    and then aligned up to the contiguous-layout MK alignment; padding rows of
    `a` are zeroed. `grouped_layout` is either per-group cumulative end rows
    (PSUM layout) or a per-row group id with -1 marking padding rows.

    Returns (m, a, b, grouped_layout, d, ref_d).
    """
    actual_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)]
    aligned_ms = [align(actual_m, get_mk_alignment_for_contiguous_layout()) for actual_m in actual_ms]
    m = sum(aligned_ms)
    a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
    b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
    grouped_layout = torch.empty(num_groups, device='cuda', dtype=torch.int32) if use_psum_layout \
        else torch.empty(m, device='cuda', dtype=torch.int32)
    d = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
    # ref_d starts as randn only to allocate; all rows are overwritten below.
    ref_d = torch.randn((m, n), device='cuda', dtype=torch.bfloat16)
    start = 0
    for i, (actual_m, aligned_m) in enumerate(zip(actual_ms, aligned_ms)):
        actual_end = start + actual_m
        aligned_end = start + aligned_m
        if use_psum_layout:
            # PSUM layout records each group's (unaligned) end row index.
            grouped_layout[i] = actual_end
        else:
            grouped_layout[start: actual_end] = i
            grouped_layout[actual_end: aligned_end] = -1
        # Zero padding rows so they contribute nothing to the product.
        a[actual_end: aligned_end] = 0
        # Per-group reference computed in bfloat16 (no float32 upcast here).
        ref_d[start: aligned_end] = a[start: aligned_end] @ b[i].t()
        start = aligned_end
    if use_bf16:
        b = b if major_b.is_k_major() else b.mT.contiguous().mT
        return m, a, b, grouped_layout, d, ref_d
    # The quantized path only supports a K-major A operand.
    assert major_a.is_k_major()
    quant_config = QuantConfig() if quant_config is None else quant_config
    a = cast_fp8_fp4_with_major(a, major_a, quant_config.gran_k_a, quant_config.is_fp4_a, use_ue8m0)
    b = grouped_cast_fp8_fp4_with_major(b, major_b, quant_config.gran_k_b, quant_config.is_fp4_b, use_ue8m0,
                                        use_block_cast_for_fp8=True)
    return m, a, b, grouped_layout, d, ref_d
def layout_masked_to_psum(x: torch.Tensor, psum_m: torch.Tensor):
    """Repack a masked-layout tensor (num_groups, max_m, ...) into PSUM layout.

    In the PSUM layout, group i's rows start at the previous group's end
    aligned up to 128, and `psum_m[i]` is group i's cumulative end row.

    NOTE(review): rows between a group's end and the next 128-aligned start
    are left uninitialized (torch.empty_like) — confirm callers never read them.
    """
    num_groups, max_m, _ = x.size()
    x_psum = torch.empty_like(x).view(num_groups * max_m, -1)
    last_psum_m = 0
    for i in range(num_groups):
        # Copy group i's valid rows into [aligned start, psum_m[i]).
        x_psum[last_psum_m: psum_m[i]] = x[i, :psum_m[i] - last_psum_m]
        last_psum_m = align(psum_m[i], 128)
    return x_psum
def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group: int, n: int, k: int,
                              use_ue8m0: bool = False, use_bf16: bool = False,
                              use_psum_layout: bool = False,
                              quant_config: Optional[QuantConfig] = None):
    """Generate a masked m-grouped GEMM test case.

    `masked_m[j]` holds the number of valid rows in group j (70%-130% of
    `expected_m_per_group`); `psum_m[j]` is the cumulative end row when groups
    are packed with each group starting at the previous 128-aligned end.

    Returns (a, b, masked_m, psum_m, d, ref_d); ref_d is computed in bfloat16
    (unlike generate_normal, which upcasts to float32).
    """
    a = torch.randn((num_groups, max_m, k), device='cuda', dtype=torch.bfloat16)
    b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
    d = torch.empty((num_groups, max_m, n), device='cuda', dtype=torch.bfloat16)
    # Batched reference: per-group a[g] @ b[g].T over all rows (mask ignored).
    ref_d = torch.einsum('gmk,gnk->gmn', a, b)
    masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
    psum_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
    for j in range(num_groups):
        masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3))
        # Each group starts at the previous group's end aligned up to 128.
        psum_m[j] = (0 if j == 0 else align(psum_m[j - 1], 128)) + masked_m[j]
    assert masked_m.amax().item() <= max_m
    if use_bf16:
        return a, b, masked_m, psum_m, d, ref_d
    quant_config = QuantConfig() if quant_config is None else quant_config
    a = grouped_cast_fp8_fp4_with_major(a, MajorTypeAB.KMajor, quant_config.gran_k_a, quant_config.is_fp4_a, use_ue8m0)
    b = grouped_cast_fp8_fp4_with_major(b, MajorTypeAB.KMajor, quant_config.gran_k_b, quant_config.is_fp4_b, use_ue8m0,
                                        use_block_cast_for_fp8=True)
    return a, b, masked_m, psum_m, d, ref_d
def generate_k_grouped_contiguous(num_groups: int, m: int, n: int,
                                  major_a: MajorTypeAB, major_b: MajorTypeAB,
                                  ks: List[int],
                                  use_ue8m0: bool = False, use_bf16: bool = False):
    """Generate a k-grouped GEMM test case with accumulation into c.

    A and B are stored (k, m) and (k, n); group i multiplies the k-slice of
    length ks[i] and accumulates into c[i]. Returns (k, a, b, c, d, ref_d);
    note d aliases c.
    """
    assert get_mk_alignment_for_contiguous_layout() % 128 == 0
    k = sum(ks)
    a = torch.randn((k, m), device='cuda', dtype=torch.bfloat16)
    b = torch.randn((k, n), device='cuda', dtype=torch.bfloat16)
    c = torch.randn((num_groups, m, n), device='cuda', dtype=torch.float) * 32
    # d is the in-place accumulation target — it is the same tensor as c.
    d = c
    ref_d = torch.empty_like(c)
    start = 0
    for i, group_k in enumerate(ks):
        end = start + group_k
        # Reference per group: c[i] + slice product (bfloat16 matmul, float add).
        ref_d[i] = c[i] + (a[start:end].T @ b[start:end])
        start = end
    if use_bf16:
        # The BF16 path only supports MN-major operands.
        assert (major_a, major_b) == (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor)
        return k, a, b, c, d, ref_d
    a_fp8 = per_channel_cast_to_fp8(a, use_ue8m0=use_ue8m0)
    b_fp8 = per_channel_cast_to_fp8(b, use_ue8m0=use_ue8m0)
    if (major_a, major_b) == (MajorTypeAB.KMajor, MajorTypeAB.KMajor):
        # Repack each group's slice K-major into flat 1D buffers: group with
        # k-extent K occupies elements [prefix*m, (prefix+K)*m) of new_a
        # (likewise with n for new_b); slice lengths match (K*m == m*K).
        a, sfa = a_fp8
        b, sfb = b_fp8
        new_a = torch.empty((sum(ks) * m, ), dtype=a.dtype, device=a.device)
        new_b = torch.empty((sum(ks) * n, ), dtype=b.dtype, device=b.device)
        prefix = 0
        for K in ks:
            new_a[prefix * m : (prefix + K) * m] = a[prefix : prefix + K, ].T.flatten()
            new_b[prefix * n : (prefix + K) * n] = b[prefix : prefix + K, ].T.flatten()
            prefix += K
        # Scale factors are transposed to match the K-major data layout.
        a_fp8, b_fp8 = (new_a, sfa.T), (new_b, sfb.T)
    else:
        assert (major_a, major_b) == (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor)
    return k, a_fp8, b_fp8, c, d, ref_d
|