/******************************************************************************
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
/**
* \file
* Define helper math functions.
*/
#pragma once
#include <type_traits>
#include "util_namespace.cuh"
#include "util_macro.cuh"
CUB_NAMESPACE_BEGIN
namespace detail
{

/**
 * Compile-time predicate over a type \p T.
 *
 * `is_integral_or_enum<T>::value` is `true` when \p T is either an
 * enumeration type or an integral type, and `false` for everything else.
 */
template <typename T>
using is_integral_or_enum =
  std::integral_constant<bool,
                         std::is_enum<T>::value || std::is_integral<T>::value>;

/**
 * Selects between zero and the product `shmem_per_block * num_blocks`.
 *
 * NOTE(review): the name suggests this sizes a "virtual shared memory"
 * fallback buffer for blocks whose request does not fit on-chip -- confirm
 * against the callers.
 *
 * \param max_shmem        Threshold that \p shmem_per_block is compared with
 * \param shmem_per_block  Per-block amount being tested against the threshold
 * \param num_blocks       Multiplier applied when the threshold is exceeded
 *
 * \return `shmem_per_block * num_blocks` when \p shmem_per_block is strictly
 *         greater than \p max_shmem, and 0 otherwise. The product is plain
 *         `std::size_t` arithmetic with no overflow check.
 */
__host__ __device__ __forceinline__ constexpr std::size_t
VshmemSize(std::size_t max_shmem,
           std::size_t shmem_per_block,
           std::size_t num_blocks)
{
  return shmem_per_block <= max_shmem ? std::size_t{0}
                                      : num_blocks * shmem_per_block;
}

} // namespace detail
/**
 * Integer division of \p n by \p d in which any non-zero remainder bumps the
 * quotient up by one.
 *
 * This is the usual `(n + d - 1) / d` idiom, but it is built from the quotient
 * and the remainder instead, so the intermediate sum `(n + d - 1)` is never
 * formed and therefore cannot overflow.
 *
 * NOTE(review): the `+ 1` is applied for every non-zero remainder regardless
 * of sign, so the result is a mathematical ceiling only when the quotient is
 * non-negative -- presumably callers only pass non-negative values; confirm.
 *
 * \param n  Dividend; its type is also the return type
 * \param d  Divisor; it is not checked against zero here
 *
 * \return `n / d`, plus one when `n % d` is non-zero, converted back to
 *         `NumeratorT`
 */
template <typename NumeratorT, typename DenominatorT>
__host__ __device__ __forceinline__ constexpr NumeratorT
DivideAndRoundUp(NumeratorT n, DenominatorT d)
{
  static_assert(cub::detail::is_integral_or_enum<DenominatorT>::value &&
                  cub::detail::is_integral_or_enum<NumeratorT>::value,
                "DivideAndRoundUp is only intended for integral types.");

  // The arithmetic may have been carried out in a promoted type; the cast
  // brings the result back to the caller's numerator type.
  return static_cast<NumeratorT>((n / d) + ((n % d) == 0 ? 0 : 1));
}
/**
 * Rescales an items-per-thread count by `8 / combined_bytes` and clamps the
 * result to the range `[1, nominal_4b_items_per_thread]`.
 *
 * NOTE(review): the factor 8 presumably stands for a nominal 4-byte key plus
 * a 4-byte value -- confirm against the tuning policies that call this.
 *
 * \param nominal_4b_items_per_thread  Count to rescale; also the upper bound
 *                                     of the result
 * \param combined_bytes               Divisor of the scaled count; it is not
 *                                     checked against zero here
 *
 * \return `nominal_4b_items_per_thread * 8 / combined_bytes` (integer
 *         division), raised to at least 1 and then capped at
 *         \p nominal_4b_items_per_thread
 */
constexpr __device__ __host__ int
Nominal4BItemsToItemsCombined(int nominal_4b_items_per_thread, int combined_bytes)
{
  return (cub::min)(
    (cub::max)((nominal_4b_items_per_thread * 8) / combined_bytes, 1),
    nominal_4b_items_per_thread);
}
/**
 * Rescales an items-per-thread count by `4 / sizeof(T)` and clamps the result
 * to the range `[1, nominal_4b_items_per_thread]`.
 *
 * NOTE(review): presumably the input is a tuning value chosen for 4-byte
 * items and the result is the equivalent count for items of type \p T --
 * confirm against the callers.
 *
 * \param nominal_4b_items_per_thread  Count to rescale; also the upper bound
 *                                     of the result
 *
 * \return `nominal_4b_items_per_thread * 4 / sizeof(T)` (integer division),
 *         raised to at least 1 and then capped at
 *         \p nominal_4b_items_per_thread
 */
template <typename T>
constexpr __device__ __host__ int
Nominal4BItemsToItems(int nominal_4b_items_per_thread)
{
  return (cub::min)(
    (cub::max)((nominal_4b_items_per_thread * 4) / static_cast<int>(sizeof(T)),
               1),
    nominal_4b_items_per_thread);
}
/**
 * Rescales an items-per-thread count for item types wider than 8 bytes.
 *
 * When `sizeof(ItemT)` is at most 8 the input is returned unchanged.
 * Otherwise the count is multiplied by 8, divided by `sizeof(ItemT)` with the
 * quotient rounded up, raised to at least 1 and capped at the input value.
 *
 * NOTE(review): presumably the input is a tuning value chosen for 8-byte
 * items -- confirm against the callers.
 *
 * \param nominal_8b_items_per_thread  Count to rescale; also the upper bound
 *                                     of the result
 *
 * \return The unchanged input for items of 8 bytes or fewer, otherwise the
 *         clamped, rounded-up rescaled count described above
 */
template <typename ItemT>
constexpr __device__ __host__ int
Nominal8BItemsToItems(int nominal_8b_items_per_thread)
{
  return sizeof(ItemT) > 8u
           ? (cub::min)(
               (cub::max)(((nominal_8b_items_per_thread * 8) +
                           static_cast<int>(sizeof(ItemT)) - 1) /
                            static_cast<int>(sizeof(ItemT)),
                          1),
               nominal_8b_items_per_thread)
           : nominal_8b_items_per_thread;
}
/**
 * \brief Returns the value halfway between \p begin and \p end
 *
 * The result is formed as an offset from \p begin, i.e. half of the distance
 * `end - begin` is added to \p begin. The sum `begin + end` is never
 * evaluated, which is what keeps the computation from overflowing.
 *
 * \param begin  First value of the pair
 * \param end    Second value of the pair
 *
 * \return `begin + (end - begin) / 2`
 */
template <typename T>
constexpr __device__ __host__ T
MidPoint(T begin, T end)
{
  return begin + ((end - begin) / 2);
}
CUB_NAMESPACE_END