| |
| |
| |
|
|
| #pragma once |
|
|
| #include "cute/tensor.hpp" |
|
|
| #include <cutlass/cutlass.h> |
| #include <cutlass/array.h> |
| #include <cutlass/numeric_types.h> |
| #include <cutlass/kernel_hardware_info.h> |
|
|
| #include "utils.h" |
|
|
| namespace flash { |
|
|
| using namespace cute; |
|
|
// FlashAttention backward-pass kernel for SM80 (Ampere-class) GPUs.
//
// The kernel is assembled from three collaborating policies:
//   - CollectiveMainloop_:  computes the dK/dV accumulators for one (n_block, head, batch) tile.
//   - CollectiveEpilogue_:  writes the accumulated dK/dV back to global memory.
//   - TileScheduler_:       hands out tile coordinates (persistent-kernel style; the grid is
//                           sized from the SM count, and each block loops over work tiles).
template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
class FlashAttnBwdSm80 {

public:

    // Masking configuration is owned by the mainloop and re-exported here.
    static constexpr bool Is_causal = CollectiveMainloop_::Is_causal;
    static constexpr bool Is_local = CollectiveMainloop_::Is_local;
    // Mainloop and epilogue must agree on whether sequences are variable-length.
    static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen);
    static constexpr bool Varlen = CollectiveMainloop_::Varlen;

    // Type aliases re-exported from the mainloop policy.
    using CollectiveMainloop = CollectiveMainloop_;
    using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;
    using TiledMmaSdP = typename CollectiveMainloop::TiledMmaSdP;
    using TiledMmadKV = typename CollectiveMainloop::TiledMmadKV;
    using ArchTag = typename CollectiveMainloop::ArchTag;
    using MainloopArguments = typename CollectiveMainloop::Arguments;
    using MainloopParams = typename CollectiveMainloop::Params;
    // Whether the dK/dV GEMM swaps its A/B operands; also swaps the M/N extents of the
    // accumulator fragments created in operator() below.
    static constexpr bool dKV_swapAB = CollectiveMainloop::dKV_swapAB;

    // Type aliases re-exported from the epilogue policy.
    using CollectiveEpilogue = CollectiveEpilogue_;
    using EpilogueArguments = typename CollectiveEpilogue::Arguments;
    using EpilogueParams = typename CollectiveEpilogue::Params;

    // This kernel targets SM80 and newer only.
    static_assert(ArchTag::kMinComputeCapability >= 80);

    using TileScheduler = TileScheduler_;
    using TileSchedulerArguments = typename flash::TileSchedulerArguments;
    using TileSchedulerParams = typename TileScheduler::Params;

    // Thread-block size is dictated by the S/dP tiled MMA.
    static constexpr uint32_t NumThreads = CUTE_STATIC_V(size(TiledMmaSdP{}));
    static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaSdP{}));
    static constexpr uint32_t MinBlocksPerMultiprocessor = 1;

    // Shared-memory layout for one thread block.
    struct SharedStorage {
        struct TensorStorage : cute::aligned_struct<128> {
            // Mainloop and epilogue smem are placed in a union: the epilogue only runs
            // after the mainloop has finished with its tiles, so the two regions can
            // safely alias to reduce total shared-memory use.
            union {
                typename CollectiveMainloop::TensorStorage mainloop;
                typename CollectiveEpilogue::TensorStorage epilogue;
            };
        } tensors;

        alignas(16) typename TileScheduler::SharedStorage smem_scheduler;

    };

    static constexpr int SharedStorageSize = sizeof(SharedStorage);

    // Host-facing argument bundle (device pointers, problem sizes, hw info).
    struct Arguments {
        MainloopArguments mainloop{};
        EpilogueArguments epilogue{};
        cutlass::KernelHardwareInfo hw_info{};
        TileSchedulerArguments scheduler{};
    };

    // Kernel-entry parameters: the Arguments after host-side preprocessing.
    struct Params {
        MainloopParams mainloop{};
        EpilogueParams epilogue{};
        cutlass::KernelHardwareInfo hw_info{};
        TileSchedulerParams scheduler{};
    };

    // Host-side conversion of user Arguments into kernel Params.
    // If the caller did not supply a valid SM count, query it from the device so the
    // persistent grid can be sized correctly (see get_grid_shape).
    static
    Params
    to_underlying_arguments(Arguments const& args) {
        CUTLASS_TRACE_HOST("to_underlying_arguments():");

        int sm_count = args.hw_info.sm_count;
        if (sm_count <= 0) {
            CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
            sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
        }

        CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);

        cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
        return {
            CollectiveMainloop::to_underlying_arguments(args.mainloop),
            CollectiveEpilogue::to_underlying_arguments(args.epilogue),
            hw_info,
            TileScheduler::to_underlying_arguments(args.scheduler)
        };
    }

    // Grid shape is delegated to the scheduler, which may size it from the SM count.
    static dim3
    get_grid_shape(Params const& params) {
        return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);
    }

    static dim3
    get_block_shape() {
        return dim3(MaxThreadsPerBlock, 1, 1);
    }

    // Device-side kernel body: loop over scheduler-assigned tiles, accumulate dK/dV
    // for each, then write them out (or write zeros for tiles with no valid work).
    CUTLASS_DEVICE
    void
    operator()(Params const& params, char* smem_buf) {

        static constexpr int kBlockM = get<0>(TileShape_MNK{});
        static constexpr int kBlockN = get<1>(TileShape_MNK{});

        SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);

        CollectiveMainloop mainloop;
        CollectiveEpilogue epilogue;

        TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.smem_scheduler));

        TiledMmadKV tiled_mma_dKV;

        scheduler.init_consumer();

        // Persistent work loop. The boolean template argument to get_initial_work /
        // get_next_work differs only for warp 0 — presumably warp 0 performs the
        // scheduler's work-acquisition bookkeeping on behalf of the block (TODO:
        // confirm against the TileScheduler implementation).
        int warp_idx = cutlass::canonical_warp_idx_sync();
        CUTLASS_PRAGMA_NO_UNROLL
        for (auto work_tile_info = warp_idx == 0 ? scheduler.template get_initial_work<true>(params.scheduler) : scheduler.template get_initial_work<false>(params.scheduler);
             work_tile_info.is_valid(params.scheduler);
             work_tile_info = warp_idx == 0 ? scheduler.template get_next_work<true>(params.scheduler, work_tile_info) : scheduler.template get_next_work<false>(params.scheduler, work_tile_info)) {

            // The scheduler yields a 4-tuple; the 4th coordinate is unused here, so the
            // mainloop/epilogue receive only (n_block, head, batch).
            auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
            auto [n_block, bidh, bidb, _ ] = block_coord_;
            cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};

            // Register-resident accumulators for dK and dV. If dKV_swapAB, the fragment
            // extents are (headdim, kBlockN) instead of (kBlockN, headdim).
            Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
            Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
            // mma() returns whether this tile had any valid work to accumulate.
            bool tile_valid = mainloop.mma(params.mainloop, tdKrdK, tdVrdV, threadIdx.x,
                                           block_coord, shared_storage);
            // Kick off fetching the next tile before the (long) epilogue store.
            scheduler.prefetch_next_work(params.scheduler, work_tile_info);
            if (tile_valid) {
                epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV,
                               threadIdx.x, block_coord);
            } else {
                // No valid work (e.g. fully-masked tile): still zero-fill the outputs.
                epilogue.store_zero(params.epilogue, threadIdx.x, block_coord);
            }
        }

    }

};
|
|
| } |
|
|