#include <thrust/device_vector.h>
#include <thrust/scan.h>
#include <thrust/sequence.h>
#include <thrust/copy.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/counting_iterator.h>

#include <assert.h>
#include <cstdio>
#include <cstdlib>
#include <vector>
| |
|
| | |
| | |
| |
|
// Naive row-wise inclusive scan of an n x m row-major matrix held in `u`,
// performed in place with one thrust::inclusive_scan call per row.
//
//   u : device storage; row r occupies the half-open range [r*m, (r+1)*m).
//   n : number of rows.
//   m : number of elements per row.
//
// Requires u.size() >= n * m. Elements at positions >= n * m are not touched.
//
// NOTE(review): this issues n separate Thrust algorithm calls, so its cost
// grows with the number of rows. scan_matrix_by_rows1 computes the same
// result with a single segmented scan.
__host__
void scan_matrix_by_rows0(thrust::device_vector<int>& u, int n, int m) {
  typedef thrust::device_vector<int>::iterator iter_t;

  iter_t row_begin = u.begin();
  for (int row = 0; row < n; ++row) {
    iter_t row_end = row_begin + m;
    // Output aliases input: the scan overwrites the row in place.
    thrust::inclusive_scan(row_begin, row_end, row_begin);
    row_begin = row_end;  // the next row starts where this one ended
  }
}
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
// Unary functor that maps a flat, row-major element index to the index of the
// row containing it: idx / row_length.
//
// scan_matrix_by_rows1 uses it as a key generator. Composed with a
// counting_iterator over 0, 1, 2, ..., it yields row_length copies of 0, then
// row_length copies of 1, and so on. Elements of the same row therefore share
// a key.
//
// Precondition: row_length > 0 whenever operator() is invoked. Integer
// division by zero is undefined behavior.
//
// This functor deliberately does NOT derive from thrust::unary_function.
// That base is deprecated in CCCL 2.x and removed in CCCL 3.0.
// thrust::transform_iterator deduces the result type from operator() directly.
// The argument_type / result_type typedefs that the base used to supply are
// declared explicitly so existing code that names them still compiles.
struct which_row {
  typedef int argument_type;
  typedef int result_type;

  int row_length;  // number of elements per row (the matrix's m)

  __host__ __device__
  which_row(int row_length_) : row_length(row_length_) {}

  // Returns the row index of flat element index `idx`.
  __host__ __device__
  int operator()(int idx) const {
    return idx / row_length;
  }
};
| |
|
// Row-wise inclusive scan of an n x m row-major matrix held in `u`, performed
// in place with a single segmented scan (thrust::inclusive_scan_by_key).
//
// The keys are produced lazily rather than stored. A counting iterator
// 0, 1, 2, ... is passed through which_row(m), so flat element k receives key
// k / m, which is its row index. inclusive_scan_by_key starts a new running
// sum each time consecutive keys differ, which happens exactly at row
// boundaries.
//
//   u : device storage; row r occupies [r*m, (r+1)*m).
//   n : number of rows.
//   m : number of elements per row.
//
// Requires u.size() >= n * m. Elements at positions >= n * m are not touched.
__host__
void scan_matrix_by_rows1(thrust::device_vector<int>& u, int n, int m) {
  typedef thrust::transform_iterator<which_row, thrust::counting_iterator<int> >
      row_key_iter;

  int const element_count = n * m;

  row_key_iter keys_begin =
      thrust::make_transform_iterator(thrust::make_counting_iterator(0),
                                      which_row(m));
  row_key_iter keys_end = keys_begin + element_count;

  // Values are read from and written back to the same range (in place).
  thrust::inclusive_scan_by_key(keys_begin, keys_end, u.begin(), u.begin());
}
| |
|
// Self-check driver. It fills two n x m matrices with 0, 1, ..., n*m-1, scans
// each row with both implementations, and verifies the results.
//
// The verification does not rely on assert(). assert() is removed under
// NDEBUG, which would silently turn this check into a no-op. Instead:
//   * each device result is copied to the host once, rather than reading
//     element by element through device_vector::operator[], which performs
//     one device->host transfer per access;
//   * both results are compared against the closed-form expected values, so
//     a bug shared by both implementations is also caught;
//   * every mismatch is reported on stderr, and the exit status is nonzero
//     on failure.
int main() {
  int const n = 4;  // rows
  int const m = 5;  // elements per row

  thrust::device_vector<int> u0(n * m);
  thrust::sequence(u0.begin(), u0.end());
  scan_matrix_by_rows0(u0, n, m);

  thrust::device_vector<int> u1(n * m);
  thrust::sequence(u1.begin(), u1.end());
  scan_matrix_by_rows1(u1, n, m);

  // One bulk transfer per result.
  std::vector<int> h0(u0.size());
  std::vector<int> h1(u1.size());
  thrust::copy(u0.begin(), u0.end(), h0.begin());
  thrust::copy(u1.begin(), u1.end(), h1.begin());

  int mismatches = 0;
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j < m; ++j) {
      int const idx = j + m * i;
      // Row i initially holds m*i, m*i + 1, ..., m*i + (m-1). Its inclusive
      // prefix sum at column j is therefore
      //   sum_{k=0..j} (m*i + k) = (j+1)*m*i + j*(j+1)/2.
      int const expected = (j + 1) * m * i + j * (j + 1) / 2;
      if (h0[idx] != expected || h1[idx] != expected) {
        std::fprintf(stderr,
                     "mismatch at row %d, col %d: rows0=%d rows1=%d expected=%d\n",
                     i, j, h0[idx], h1[idx], expected);
        ++mismatches;
      }
    }
  }

  return mismatches == 0 ? EXIT_SUCCESS : EXIT_FAILURE;
}
| |
|
| |
|