|
|
#ifdef THRUST_DEVICE_LOWER_BOUND_WORKS |
|
|
#include <thrust/binary_search.h> |
|
|
#include <thrust/device_vector.h> |
|
|
#include <thrust/execution_policy.h> |
|
|
#include <thrust/functional.h> |
|
|
#endif |
|
|
namespace c10 { |
|
|
namespace cuda { |
|
|
#ifdef THRUST_DEVICE_LOWER_BOUND_WORKS |
|
|
// Device-side lower_bound that simply forwards to thrust::lower_bound using the
// thrust::device execution policy. This variant is only compiled when
// THRUST_DEVICE_LOWER_BOUND_WORKS is defined.
//
// Yields the first iterator in [start, end) whose element is not ordered before
// `value` (default operator< comparison), or `end` when no such element exists.
// Ordering/sortedness requirements are those of thrust::lower_bound.
template <typename Iter, typename Scalar>
__forceinline__ __device__ Iter lower_bound(
    Iter start,
    Iter end,
    Scalar value) {
  Iter first_not_less =
      thrust::lower_bound(thrust::device, start, end, value);
  return first_not_less;
}
|
|
#else |
|
|
|
|
|
|
|
|
|
|
|
// Device-side lower_bound used when THRUST_DEVICE_LOWER_BOUND_WORKS is NOT
// defined. NOTE(review): presumably this hand-written search exists because
// thrust::lower_bound with thrust::device cannot be relied on from device code
// in this build configuration (cf. NVIDIA/thrust issue #1734) — confirm.
//
// Binary search over the half-open range [start, end) for the first position
// whose element is not less than `value`, i.e. the first `it` for which
// `*it < value` is false. Returns `end` if every element is less than `value`,
// and returns `start` (== `end`) for an empty range. Performs O(log n)
// comparisons, since the search interval halves on every iteration.
//
// Preconditions:
//   - Iter must be a random-access iterator (the loop compares, subtracts and
//     offsets iterators).
//   - [start, end) must be partitioned by `*it < value`, e.g. sorted ascending
//     with respect to operator<; otherwise the result is unspecified.
//
// Declared __forceinline__ to match the thrust-backed variant, so callers get
// the same inlining behaviour regardless of which branch is compiled.
template <typename Iter, typename Scalar>
__forceinline__ __device__ Iter lower_bound(Iter start, Iter end, Scalar value) {
  while (start < end) {
    // Midpoint expressed as an offset from `start`; never forms start + end.
    auto mid = start + ((end - start) >> 1);
    if (*mid < value) {
      // Every element up to and including `mid` is < value: discard them.
      start = mid + 1;
    } else {
      // `mid` is a candidate answer; keep searching in [start, mid).
      end = mid;
    }
  }
  // The loop exits with start == end, the first not-less position.
  return end;
}
|
|
#endif |
|
|
} |
|
|
} |
|
|
|