| namespace deep_gemm { | |
| // TODO: Use `torch::kFloat4_e2m1fn_x2` | |
| constexpr auto kPackedFP4 = torch::kUInt8; | |
| template <typename T> | |
| static T ceil_div(const T& a, const T& b) { | |
| return (a + b - 1) / b; | |
| } | |
| template <typename T> | |
| static constexpr T align(const T& a, const T& b) { | |
| return ceil_div(a, b) * b; | |
| } | |
| static int get_tma_aligned_size(const int& x, const int& element_size) { | |
| constexpr int kNumTMAAlignmentBytes = 16; | |
| DG_HOST_ASSERT(kNumTMAAlignmentBytes % element_size == 0); | |
| return align(x, kNumTMAAlignmentBytes / element_size); | |
| } | |
| } // namespace deep_gemm | |