File size: 429 Bytes
c1af2fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
#pragma once
#include <ATen/core/TensorBase.h>
#include <optional>

namespace at::cuda::detail {
/// Declaration of the grouped matrix-multiply entry point; the definition
/// lives in another translation unit and is not visible from this header.
///
/// Per the per-parameter annotations below, the two matrix operands are FP8,
/// their scales are FP32, and the optional bias is BF16. The "bf16" in the
/// function name suggests the result is BF16 as well -- NOTE(review): confirm
/// against the definition.
///
/// The function returns nothing. `out` is the only parameter taken by
/// non-const reference, so it is presumably where the result is delivered --
/// TODO confirm whether the caller must pre-allocate it.
///
/// All other tensor arguments are taken by value; do not change this to
/// `const&` here without changing the definition in lockstep, since the
/// declaration and definition signatures must match.
TORCH_API void f8f8bf16_grouped_mm(

    at::Tensor mat_a, // FP8 operand A

    at::Tensor mat_b, // FP8 operand B

    at::Tensor scale_a, // FP32; by name, the scale paired with mat_a

    at::Tensor scale_b, // FP32; by name, the scale paired with mat_b

    // Optional. Presumably group offsets delimiting each sub-problem --
    // exact meaning and dtype are not shown here; verify in the definition.
    std::optional<at::Tensor> offs,

    std::optional<at::Tensor> bias, // BF16; optional

    // By name, selects a faster accumulation mode -- NOTE(review): precision
    // trade-off, if any, is defined by the implementation.
    bool use_fast_accum,

    // Only non-const reference parameter; presumably receives the result.
    at::Tensor& out);
} // namespace at::cuda::detail