File size: 17,970 Bytes
c67ae40 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 | #include <torch/torch.h>
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <string>
#include <tuple>
#include <vector>

#include "apis/attention.hpp"
#include "apis/einsum.hpp"
#include "apis/gemm.hpp"
#include "apis/hyperconnection.hpp"
#include "apis/layout.hpp"
#if DG_TENSORMAP_COMPATIBLE
#include "jit/compiler.hpp"
#endif
#include "jit/device_runtime.hpp"
// ---- Type conversion helpers ----
// Converts an optional list of int64 values into an optional (int, int, int)
// tuple. Returns std::nullopt when `v` is empty (no list given).
// Throws std::invalid_argument unless the list has exactly three entries;
// each entry is narrowed with static_cast<int> (no range check).
static std::optional<std::tuple<int, int, int>>
vec_to_tuple3(const std::optional<std::vector<int64_t>>& v) {
    if (!v.has_value()) return std::nullopt;
    const auto& vec = v.value();
    // operator[] is unchecked: a shorter list would read out of bounds (UB),
    // and a longer one would have its extra entries silently ignored.
    if (vec.size() != 3)
        throw std::invalid_argument(
            "expected a list of 3 integers, got " + std::to_string(vec.size()) + " element(s)");
    return std::make_tuple(static_cast<int>(vec[0]),
                           static_cast<int>(vec[1]),
                           static_cast<int>(vec[2]));
}
// Converts a required list of int64 values into an (int, int, int) tuple.
// Throws std::invalid_argument unless the list has exactly three entries;
// each entry is narrowed with static_cast<int> (no range check).
static std::tuple<int, int, int>
vec_to_tuple3_req(const std::vector<int64_t>& v) {
    // Guard the unchecked operator[] accesses below against short/long lists.
    if (v.size() != 3)
        throw std::invalid_argument(
            "expected a list of 3 integers, got " + std::to_string(v.size()) + " element(s)");
    return std::make_tuple(static_cast<int>(v[0]),
                           static_cast<int>(v[1]),
                           static_cast<int>(v[2]));
}
// Converts an optional list of int64 values into an optional (int, int)
// tuple. Returns std::nullopt when `v` is empty (no list given).
// Throws std::invalid_argument unless the list has exactly two entries;
// each entry is narrowed with static_cast<int> (no range check).
static std::optional<std::tuple<int, int>>
vec_to_tuple2(const std::optional<std::vector<int64_t>>& v) {
    if (!v.has_value()) return std::nullopt;
    const auto& vec = v.value();
    // Guard the unchecked operator[] accesses below against short/long lists.
    if (vec.size() != 2)
        throw std::invalid_argument(
            "expected a list of 2 integers, got " + std::to_string(vec.size()) + " element(s)");
    return std::make_tuple(static_cast<int>(vec[0]),
                           static_cast<int>(vec[1]));
}
// Narrows every int64 entry of `v` to int, keeping the original order.
// No range checking: out-of-range values convert exactly as static_cast<int> would.
static std::vector<int>
vec64_to_vec32(const std::vector<int64_t>& v) {
    std::vector<int> narrowed;
    narrowed.reserve(v.size());
    for (const int64_t value : v)
        narrowed.push_back(static_cast<int>(value));
    return narrowed;
}
// ---- Runtime APIs ----
// Hands the library root, CUDA home and CUTLASS include paths to
// deep_gemm::Compiler, and the CUDA home path to deep_gemm::KernelRuntime.
// When DG_TENSORMAP_COMPATIBLE is off, jit/compiler.hpp is not included and
// this function does nothing.
void dg_init(const std::string& library_root_path,
             const std::string& cuda_home_path,
             const std::string& cutlass_include_path) {
#if DG_TENSORMAP_COMPATIBLE
    // The compiler consumes all three paths; the kernel runtime only needs CUDA home.
    deep_gemm::Compiler::prepare_init(library_root_path, cuda_home_path, cutlass_include_path);
    deep_gemm::KernelRuntime::prepare_init(cuda_home_path);
#else
    // Nothing to configure in this build; mark the paths as intentionally unused.
    static_cast<void>(library_root_path);
    static_cast<void>(cuda_home_path);
    static_cast<void>(cutlass_include_path);
#endif
}
// Narrows `new_num_sms` to int and passes it to device_runtime->set_num_sms.
void dg_set_num_sms(int64_t new_num_sms) {
    const int num_sms = static_cast<int>(new_num_sms);
    deep_gemm::device_runtime->set_num_sms(num_sms);
}
// Returns device_runtime->get_num_sms(), widened to int64_t.
int64_t dg_get_num_sms() {
    const auto num_sms = deep_gemm::device_runtime->get_num_sms();
    return static_cast<int64_t>(num_sms);
}
// Narrows `new_tc_util` to int and passes it to device_runtime->set_tc_util.
void dg_set_tc_util(int64_t new_tc_util) {
    const int tc_util = static_cast<int>(new_tc_util);
    deep_gemm::device_runtime->set_tc_util(tc_util);
}
// Returns device_runtime->get_tc_util(), widened to int64_t.
int64_t dg_get_tc_util() {
    const auto tc_util = deep_gemm::device_runtime->get_tc_util();
    return static_cast<int64_t>(tc_util);
}
// ---- cuBLASLt GEMMs ----
// Thin wrapper: forwards every argument unchanged to
// deep_gemm::gemm::cublaslt_gemm_nt. `c` may be absent.
void dg_cublaslt_gemm_nt(const at::Tensor& a,
                         const at::Tensor& b,
                         const at::Tensor& d,
                         const std::optional<at::Tensor>& c) {
    deep_gemm::gemm::cublaslt_gemm_nt(a, b, d, c);
}
// Thin wrapper: forwards every argument unchanged to
// deep_gemm::gemm::cublaslt_gemm_nn. `c` may be absent.
void dg_cublaslt_gemm_nn(const at::Tensor& a,
                         const at::Tensor& b,
                         const at::Tensor& d,
                         const std::optional<at::Tensor>& c) {
    deep_gemm::gemm::cublaslt_gemm_nn(a, b, d, c);
}
// Thin wrapper: forwards every argument unchanged to
// deep_gemm::gemm::cublaslt_gemm_tn. `c` may be absent.
void dg_cublaslt_gemm_tn(const at::Tensor& a,
                         const at::Tensor& b,
                         const at::Tensor& d,
                         const std::optional<at::Tensor>& c) {
    deep_gemm::gemm::cublaslt_gemm_tn(a, b, d, c);
}
// Thin wrapper: forwards every argument unchanged to
// deep_gemm::gemm::cublaslt_gemm_tt. `c` may be absent.
void dg_cublaslt_gemm_tt(const at::Tensor& a,
                         const at::Tensor& b,
                         const at::Tensor& d,
                         const std::optional<at::Tensor>& c) {
    deep_gemm::gemm::cublaslt_gemm_tt(a, b, d, c);
}
// ---- FP8/FP4 GEMMs ----
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
// Forwards to deep_gemm::gemm::fp8_fp4_gemm_nt. Operands are passed as
// brace-initialized {tensor, sf} pairs; the optional int64 recipe lists are
// converted to int tuples with vec_to_tuple3 / vec_to_tuple2 first.
void dg_fp8_fp4_gemm_nt(const at::Tensor& a, const at::Tensor& sfa,
                        const at::Tensor& b, const at::Tensor& sfb,
                        const at::Tensor& d,
                        const std::optional<at::Tensor>& c,
                        const std::optional<std::vector<int64_t>>& recipe,
                        const std::optional<std::vector<int64_t>>& recipe_a,
                        const std::optional<std::vector<int64_t>>& recipe_b,
                        const std::string& compiled_dims,
                        bool disable_ue8m0_cast) {
    const auto recipe_tuple = vec_to_tuple3(recipe);
    const auto recipe_a_tuple = vec_to_tuple2(recipe_a);
    const auto recipe_b_tuple = vec_to_tuple2(recipe_b);
    deep_gemm::gemm::fp8_fp4_gemm_nt({a, sfa}, {b, sfb}, d, c,
                                     recipe_tuple, recipe_a_tuple, recipe_b_tuple,
                                     compiled_dims, disable_ue8m0_cast);
}
// Forwards to deep_gemm::gemm::fp8_fp4_gemm_nn. Operands are passed as
// brace-initialized {tensor, sf} pairs; the optional int64 recipe lists are
// converted to int tuples with vec_to_tuple3 / vec_to_tuple2 first.
void dg_fp8_fp4_gemm_nn(const at::Tensor& a, const at::Tensor& sfa,
                        const at::Tensor& b, const at::Tensor& sfb,
                        const at::Tensor& d,
                        const std::optional<at::Tensor>& c,
                        const std::optional<std::vector<int64_t>>& recipe,
                        const std::optional<std::vector<int64_t>>& recipe_a,
                        const std::optional<std::vector<int64_t>>& recipe_b,
                        const std::string& compiled_dims,
                        bool disable_ue8m0_cast) {
    const auto recipe_tuple = vec_to_tuple3(recipe);
    const auto recipe_a_tuple = vec_to_tuple2(recipe_a);
    const auto recipe_b_tuple = vec_to_tuple2(recipe_b);
    deep_gemm::gemm::fp8_fp4_gemm_nn({a, sfa}, {b, sfb}, d, c,
                                     recipe_tuple, recipe_a_tuple, recipe_b_tuple,
                                     compiled_dims, disable_ue8m0_cast);
}
// Forwards to deep_gemm::gemm::fp8_fp4_gemm_tn. Operands are passed as
// brace-initialized {tensor, sf} pairs; the optional int64 recipe lists are
// converted to int tuples with vec_to_tuple3 / vec_to_tuple2 first.
void dg_fp8_fp4_gemm_tn(const at::Tensor& a, const at::Tensor& sfa,
                        const at::Tensor& b, const at::Tensor& sfb,
                        const at::Tensor& d,
                        const std::optional<at::Tensor>& c,
                        const std::optional<std::vector<int64_t>>& recipe,
                        const std::optional<std::vector<int64_t>>& recipe_a,
                        const std::optional<std::vector<int64_t>>& recipe_b,
                        const std::string& compiled_dims,
                        bool disable_ue8m0_cast) {
    const auto recipe_tuple = vec_to_tuple3(recipe);
    const auto recipe_a_tuple = vec_to_tuple2(recipe_a);
    const auto recipe_b_tuple = vec_to_tuple2(recipe_b);
    deep_gemm::gemm::fp8_fp4_gemm_tn({a, sfa}, {b, sfb}, d, c,
                                     recipe_tuple, recipe_a_tuple, recipe_b_tuple,
                                     compiled_dims, disable_ue8m0_cast);
}
// Forwards to deep_gemm::gemm::fp8_fp4_gemm_tt. Operands are passed as
// brace-initialized {tensor, sf} pairs; the optional int64 recipe lists are
// converted to int tuples with vec_to_tuple3 / vec_to_tuple2 first.
void dg_fp8_fp4_gemm_tt(const at::Tensor& a, const at::Tensor& sfa,
                        const at::Tensor& b, const at::Tensor& sfb,
                        const at::Tensor& d,
                        const std::optional<at::Tensor>& c,
                        const std::optional<std::vector<int64_t>>& recipe,
                        const std::optional<std::vector<int64_t>>& recipe_a,
                        const std::optional<std::vector<int64_t>>& recipe_b,
                        const std::string& compiled_dims,
                        bool disable_ue8m0_cast) {
    const auto recipe_tuple = vec_to_tuple3(recipe);
    const auto recipe_a_tuple = vec_to_tuple2(recipe_a);
    const auto recipe_b_tuple = vec_to_tuple2(recipe_b);
    deep_gemm::gemm::fp8_fp4_gemm_tt({a, sfa}, {b, sfb}, d, c,
                                     recipe_tuple, recipe_a_tuple, recipe_b_tuple,
                                     compiled_dims, disable_ue8m0_cast);
}
// Forwards to deep_gemm::gemm::m_grouped_fp8_fp4_gemm_nt_contiguous.
// The recipe lists become int tuples, and expected_m_for_psum_layout (if
// present) is narrowed to int; an absent value stays std::nullopt.
void dg_m_grouped_fp8_fp4_gemm_nt_contiguous(
    const at::Tensor& a, const at::Tensor& sfa,
    const at::Tensor& b, const at::Tensor& sfb,
    const at::Tensor& d, const at::Tensor& grouped_layout,
    const std::optional<std::vector<int64_t>>& recipe,
    const std::optional<std::vector<int64_t>>& recipe_a,
    const std::optional<std::vector<int64_t>>& recipe_b,
    const std::string& compiled_dims,
    bool disable_ue8m0_cast,
    bool use_psum_layout,
    const std::optional<int64_t>& expected_m_for_psum_layout) {
    const std::optional<int> expected_m =
        expected_m_for_psum_layout.has_value()
            ? std::optional<int>(static_cast<int>(*expected_m_for_psum_layout))
            : std::nullopt;
    const auto recipe_tuple = vec_to_tuple3(recipe);
    const auto recipe_a_tuple = vec_to_tuple2(recipe_a);
    const auto recipe_b_tuple = vec_to_tuple2(recipe_b);
    deep_gemm::gemm::m_grouped_fp8_fp4_gemm_nt_contiguous(
        {a, sfa}, {b, sfb}, d, grouped_layout,
        recipe_tuple, recipe_a_tuple, recipe_b_tuple,
        compiled_dims, disable_ue8m0_cast, use_psum_layout, expected_m);
}
// Forwards to deep_gemm::gemm::m_grouped_fp8_fp4_gemm_nn_contiguous after
// converting the optional int64 recipe lists to int tuples.
void dg_m_grouped_fp8_fp4_gemm_nn_contiguous(
    const at::Tensor& a, const at::Tensor& sfa,
    const at::Tensor& b, const at::Tensor& sfb,
    const at::Tensor& d, const at::Tensor& grouped_layout,
    const std::optional<std::vector<int64_t>>& recipe,
    const std::optional<std::vector<int64_t>>& recipe_a,
    const std::optional<std::vector<int64_t>>& recipe_b,
    const std::string& compiled_dims,
    bool disable_ue8m0_cast,
    bool use_psum_layout) {
    const auto recipe_tuple = vec_to_tuple3(recipe);
    const auto recipe_a_tuple = vec_to_tuple2(recipe_a);
    const auto recipe_b_tuple = vec_to_tuple2(recipe_b);
    deep_gemm::gemm::m_grouped_fp8_fp4_gemm_nn_contiguous(
        {a, sfa}, {b, sfb}, d, grouped_layout,
        recipe_tuple, recipe_a_tuple, recipe_b_tuple,
        compiled_dims, disable_ue8m0_cast, use_psum_layout);
}
// Forwards to deep_gemm::gemm::m_grouped_fp8_fp4_gemm_nt_masked with
// expected_m narrowed to int and the recipe lists converted to int tuples.
void dg_m_grouped_fp8_fp4_gemm_nt_masked(
    const at::Tensor& a, const at::Tensor& sfa,
    const at::Tensor& b, const at::Tensor& sfb,
    const at::Tensor& d, const at::Tensor& masked_m,
    int64_t expected_m,
    const std::optional<std::vector<int64_t>>& recipe,
    const std::optional<std::vector<int64_t>>& recipe_a,
    const std::optional<std::vector<int64_t>>& recipe_b,
    const std::string& compiled_dims,
    bool disable_ue8m0_cast) {
    const int expected_m_i32 = static_cast<int>(expected_m);
    const auto recipe_tuple = vec_to_tuple3(recipe);
    const auto recipe_a_tuple = vec_to_tuple2(recipe_a);
    const auto recipe_b_tuple = vec_to_tuple2(recipe_b);
    deep_gemm::gemm::m_grouped_fp8_fp4_gemm_nt_masked(
        {a, sfa}, {b, sfb}, d, masked_m, expected_m_i32,
        recipe_tuple, recipe_a_tuple, recipe_b_tuple,
        compiled_dims, disable_ue8m0_cast);
}
// Forwards to deep_gemm::gemm::k_grouped_fp8_gemm_nt_contiguous. `ks` is
// narrowed element-wise to int and the required recipe list becomes an int
// 3-tuple; `ks_tensor` and `c` pass through unchanged.
void dg_k_grouped_fp8_gemm_nt_contiguous(
    const at::Tensor& a, const at::Tensor& sfa,
    const at::Tensor& b, const at::Tensor& sfb,
    const at::Tensor& d,
    const std::vector<int64_t>& ks,
    const at::Tensor& ks_tensor,
    const std::optional<at::Tensor>& c,
    const std::vector<int64_t>& recipe,
    const std::string& compiled_dims) {
    const auto recipe_tuple = vec_to_tuple3_req(recipe);
    deep_gemm::gemm::k_grouped_fp8_gemm_nt_contiguous(
        {a, sfa}, {b, sfb}, d,
        vec64_to_vec32(ks), ks_tensor, c,
        recipe_tuple, compiled_dims);
}
// Forwards to deep_gemm::gemm::k_grouped_fp8_gemm_tn_contiguous. `ks` is
// narrowed element-wise to int and the required recipe list becomes an int
// 3-tuple; `ks_tensor` and `c` pass through unchanged.
void dg_k_grouped_fp8_gemm_tn_contiguous(
    const at::Tensor& a, const at::Tensor& sfa,
    const at::Tensor& b, const at::Tensor& sfb,
    const at::Tensor& d,
    const std::vector<int64_t>& ks,
    const at::Tensor& ks_tensor,
    const std::optional<at::Tensor>& c,
    const std::vector<int64_t>& recipe,
    const std::string& compiled_dims) {
    const auto recipe_tuple = vec_to_tuple3_req(recipe);
    deep_gemm::gemm::k_grouped_fp8_gemm_tn_contiguous(
        {a, sfa}, {b, sfb}, d,
        vec64_to_vec32(ks), ks_tensor, c,
        recipe_tuple, compiled_dims);
}
#endif // DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
// ---- BF16 GEMMs ----
#if DG_TENSORMAP_COMPATIBLE
// Thin wrapper: forwards every argument unchanged to
// deep_gemm::gemm::bf16_gemm_nt. `c` may be absent.
void dg_bf16_gemm_nt(const at::Tensor& a,
                     const at::Tensor& b,
                     const at::Tensor& d,
                     const std::optional<at::Tensor>& c,
                     const std::string& compiled_dims) {
    deep_gemm::gemm::bf16_gemm_nt(a, b, d, c, compiled_dims);
}
// Thin wrapper: forwards every argument unchanged to
// deep_gemm::gemm::bf16_gemm_nn. `c` may be absent.
void dg_bf16_gemm_nn(const at::Tensor& a,
                     const at::Tensor& b,
                     const at::Tensor& d,
                     const std::optional<at::Tensor>& c,
                     const std::string& compiled_dims) {
    deep_gemm::gemm::bf16_gemm_nn(a, b, d, c, compiled_dims);
}
// Thin wrapper: forwards every argument unchanged to
// deep_gemm::gemm::bf16_gemm_tn. `c` may be absent.
void dg_bf16_gemm_tn(const at::Tensor& a,
                     const at::Tensor& b,
                     const at::Tensor& d,
                     const std::optional<at::Tensor>& c,
                     const std::string& compiled_dims) {
    deep_gemm::gemm::bf16_gemm_tn(a, b, d, c, compiled_dims);
}
// Thin wrapper: forwards every argument unchanged to
// deep_gemm::gemm::bf16_gemm_tt. `c` may be absent.
void dg_bf16_gemm_tt(const at::Tensor& a,
                     const at::Tensor& b,
                     const at::Tensor& d,
                     const std::optional<at::Tensor>& c,
                     const std::string& compiled_dims) {
    deep_gemm::gemm::bf16_gemm_tt(a, b, d, c, compiled_dims);
}
// Forwards to deep_gemm::gemm::m_grouped_bf16_gemm_nt_contiguous, narrowing
// expected_m_for_psum_layout to int when present (absent stays std::nullopt).
void dg_m_grouped_bf16_gemm_nt_contiguous(
    const at::Tensor& a, const at::Tensor& b,
    const at::Tensor& d, const at::Tensor& grouped_layout,
    const std::string& compiled_dims,
    bool use_psum_layout,
    const std::optional<int64_t>& expected_m_for_psum_layout) {
    const std::optional<int> expected_m =
        expected_m_for_psum_layout.has_value()
            ? std::optional<int>(static_cast<int>(*expected_m_for_psum_layout))
            : std::nullopt;
    deep_gemm::gemm::m_grouped_bf16_gemm_nt_contiguous(
        a, b, d, grouped_layout,
        compiled_dims, use_psum_layout, expected_m);
}
// Thin wrapper: forwards every argument unchanged to
// deep_gemm::gemm::m_grouped_bf16_gemm_nn_contiguous.
void dg_m_grouped_bf16_gemm_nn_contiguous(const at::Tensor& a,
                                          const at::Tensor& b,
                                          const at::Tensor& d,
                                          const at::Tensor& grouped_layout,
                                          const std::string& compiled_dims,
                                          bool use_psum_layout) {
    deep_gemm::gemm::m_grouped_bf16_gemm_nn_contiguous(
        a, b, d, grouped_layout, compiled_dims, use_psum_layout);
}
// Forwards to deep_gemm::gemm::m_grouped_bf16_gemm_nt_masked with
// expected_m narrowed to int.
void dg_m_grouped_bf16_gemm_nt_masked(const at::Tensor& a,
                                      const at::Tensor& b,
                                      const at::Tensor& d,
                                      const at::Tensor& masked_m,
                                      int64_t expected_m,
                                      const std::string& compiled_dims) {
    const int expected_m_i32 = static_cast<int>(expected_m);
    deep_gemm::gemm::m_grouped_bf16_gemm_nt_masked(
        a, b, d, masked_m, expected_m_i32, compiled_dims);
}
// Forwards to deep_gemm::gemm::k_grouped_bf16_gemm_tn_contiguous with `ks`
// narrowed element-wise to int; all other arguments pass through unchanged.
void dg_k_grouped_bf16_gemm_tn_contiguous(const at::Tensor& a,
                                          const at::Tensor& b,
                                          const at::Tensor& d,
                                          const std::vector<int64_t>& ks,
                                          const at::Tensor& ks_tensor,
                                          const std::optional<at::Tensor>& c,
                                          const std::string& compiled_dims) {
    deep_gemm::gemm::k_grouped_bf16_gemm_tn_contiguous(
        a, b, d,
        vec64_to_vec32(ks), ks_tensor, c,
        compiled_dims);
}
#endif // DG_TENSORMAP_COMPATIBLE
// ---- Einsum ----
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
// Thin wrapper: forwards the einsum expression and tensors unchanged to
// deep_gemm::einsum::einsum; `use_cublaslt` is passed straight through.
// NOTE(review): this non-sf (no scale-factor) variant is compiled only under
// DG_FP8_COMPATIBLE as well — confirm that guard is intended.
void dg_einsum(const std::string& expr,
               const at::Tensor& a,
               const at::Tensor& b,
               const at::Tensor& d,
               const std::optional<at::Tensor>& c,
               bool use_cublaslt) {
    deep_gemm::einsum::einsum(expr, a, b, d, c, use_cublaslt);
}
// Forwards to deep_gemm::einsum::fp8_einsum with operands as {tensor, sf}
// pairs and the required recipe list converted to an int 3-tuple.
void dg_fp8_einsum(const std::string& expr,
                   const at::Tensor& a, const at::Tensor& sfa,
                   const at::Tensor& b, const at::Tensor& sfb,
                   const at::Tensor& d,
                   const std::optional<at::Tensor>& c,
                   const std::vector<int64_t>& recipe) {
    const auto recipe_tuple = vec_to_tuple3_req(recipe);
    deep_gemm::einsum::fp8_einsum(expr, {a, sfa}, {b, sfb}, d, c, recipe_tuple);
}
#endif // DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
// ---- Attention ----
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
// Forwards to deep_gemm::attention::fp8_gemm_nt_skip_head_mid. The required
// head_splits list becomes an int 3-tuple and the optional recipe list an
// optional int 3-tuple.
void dg_fp8_gemm_nt_skip_head_mid(
    const at::Tensor& a, const at::Tensor& sfa,
    const at::Tensor& b, const at::Tensor& sfb,
    const at::Tensor& d,
    const std::vector<int64_t>& head_splits,
    const std::optional<std::vector<int64_t>>& recipe,
    const std::string& compiled_dims,
    bool disable_ue8m0_cast) {
    const auto head_splits_tuple = vec_to_tuple3_req(head_splits);
    const auto recipe_tuple = vec_to_tuple3(recipe);
    deep_gemm::attention::fp8_gemm_nt_skip_head_mid(
        {a, sfa}, {b, sfb}, d,
        head_splits_tuple, recipe_tuple,
        compiled_dims, disable_ue8m0_cast);
}
// Returns deep_gemm::attention::fp8_mqa_logits(...) with kv passed as a
// {kv, kv_sf} pair and max_seqlen_k narrowed to int.
at::Tensor dg_fp8_mqa_logits(
    const at::Tensor& q,
    const at::Tensor& kv, const at::Tensor& kv_sf,
    const at::Tensor& weights,
    const at::Tensor& cu_seq_len_k_start,
    const at::Tensor& cu_seq_len_k_end,
    bool clean_logits,
    int64_t max_seqlen_k) {
    const int max_seqlen_k_i32 = static_cast<int>(max_seqlen_k);
    return deep_gemm::attention::fp8_mqa_logits(
        q, {kv, kv_sf}, weights,
        cu_seq_len_k_start, cu_seq_len_k_end,
        clean_logits, max_seqlen_k_i32);
}
// Returns deep_gemm::attention::get_paged_mqa_logits_metadata(...) with
// block_kv and num_sms narrowed to int.
at::Tensor dg_get_paged_mqa_logits_metadata(const at::Tensor& context_lens,
                                            int64_t block_kv,
                                            int64_t num_sms) {
    const int block_kv_i32 = static_cast<int>(block_kv);
    const int num_sms_i32 = static_cast<int>(num_sms);
    return deep_gemm::attention::get_paged_mqa_logits_metadata(
        context_lens, block_kv_i32, num_sms_i32);
}
// Returns deep_gemm::attention::fp8_paged_mqa_logits(...) with
// max_context_len narrowed to int; all tensors pass through unchanged.
at::Tensor dg_fp8_paged_mqa_logits(
    const at::Tensor& q,
    const at::Tensor& fused_kv_cache,
    const at::Tensor& weights,
    const at::Tensor& context_lens,
    const at::Tensor& block_table,
    const at::Tensor& schedule_meta,
    int64_t max_context_len,
    bool clean_logits) {
    const int max_context_len_i32 = static_cast<int>(max_context_len);
    return deep_gemm::attention::fp8_paged_mqa_logits(
        q, fused_kv_cache, weights, context_lens,
        block_table, schedule_meta,
        max_context_len_i32, clean_logits);
}
#endif // DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
// ---- Hyperconnection ----
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
// Forwards to deep_gemm::hyperconnection::tf32_hc_prenorm_gemm, narrowing
// num_splits to int when present (absent stays std::nullopt).
void dg_tf32_hc_prenorm_gemm(const at::Tensor& a,
                             const at::Tensor& b,
                             const at::Tensor& d,
                             const at::Tensor& sqr_sum,
                             const std::optional<int64_t>& num_splits) {
    const std::optional<int> num_splits_i32 =
        num_splits.has_value()
            ? std::optional<int>(static_cast<int>(*num_splits))
            : std::nullopt;
    deep_gemm::hyperconnection::tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits_i32);
}
#endif // DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
// ---- Layout ----
#if DG_TENSORMAP_COMPATIBLE
// Returns deep_gemm::layout::transform_sf_into_required_layout(...) after
// narrowing mn/k to int, converting the optional recipe lists to int tuples
// (3 entries for `recipe`, 2 for `recipe_ab`), and narrowing num_groups to
// int when present.
at::Tensor dg_transform_sf_into_required_layout(
    const at::Tensor& sf,
    int64_t mn, int64_t k,
    const std::optional<std::vector<int64_t>>& recipe,
    const std::optional<std::vector<int64_t>>& recipe_ab,
    const std::optional<int64_t>& num_groups,
    bool is_sfa, bool disable_ue8m0_cast) {
    const std::optional<int> num_groups_i32 =
        num_groups.has_value()
            ? std::optional<int>(static_cast<int>(*num_groups))
            : std::nullopt;
    const int mn_i32 = static_cast<int>(mn);
    const int k_i32 = static_cast<int>(k);
    const auto recipe_tuple = vec_to_tuple3(recipe);
    const auto recipe_ab_tuple = vec_to_tuple2(recipe_ab);
    return deep_gemm::layout::transform_sf_into_required_layout(
        sf, mn_i32, k_i32,
        recipe_tuple, recipe_ab_tuple,
        num_groups_i32, is_sfa, disable_ue8m0_cast);
}
// Narrows both arguments to int, calls deep_gemm::get_tma_aligned_size and
// widens its result back to int64_t.
int64_t dg_get_tma_aligned_size(int64_t x, int64_t element_size) {
    const int x_i32 = static_cast<int>(x);
    const int element_size_i32 = static_cast<int>(element_size);
    const auto aligned = deep_gemm::get_tma_aligned_size(x_i32, element_size_i32);
    return static_cast<int64_t>(aligned);
}
// Thin wrapper returning deep_gemm::get_mn_major_tma_aligned_tensor(sf).
at::Tensor dg_get_mn_major_tma_aligned_tensor(
    const at::Tensor& sf) {
    return deep_gemm::get_mn_major_tma_aligned_tensor(sf);
}
// Thin wrapper returning
// deep_gemm::get_mn_major_tma_aligned_packed_ue8m0_tensor(sf).
at::Tensor dg_get_mn_major_tma_aligned_packed_ue8m0_tensor(
    const at::Tensor& sf) {
    return deep_gemm::get_mn_major_tma_aligned_packed_ue8m0_tensor(sf);
}
// Returns deep_gemm::get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(...)
// with `ks` narrowed element-wise to int; `sf` and `ks_tensor` pass through.
at::Tensor dg_get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(
    const at::Tensor& sf,
    const at::Tensor& ks_tensor,
    const std::vector<int64_t>& ks) {
    return deep_gemm::get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(
        sf,
        ks_tensor,
        vec64_to_vec32(ks));
}
#endif // DG_TENSORMAP_COMPATIBLE
// Returns deep_gemm::get_mk_alignment_for_contiguous_layout(), widened to
// int64_t. Unlike the layout helpers above, this is not behind the
// DG_TENSORMAP_COMPATIBLE guard.
int64_t dg_get_mk_alignment_for_contiguous_layout() {
    const auto alignment = deep_gemm::get_mk_alignment_for_contiguous_layout();
    return static_cast<int64_t>(alignment);
}
|