File size: 13,327 Bytes
a6dd040 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 |
/* Copyright (c) 2020 NVIDIA CORPORATION.
* Copyright (c) 2018-2020 Chris Choy (chrischoy@ai.stanford.edu)
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*
* Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural
* Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part
* of the code.
*/
#ifndef COORDINATE_MAP_GPU_CUH
#define COORDINATE_MAP_GPU_CUH
#include "3rdparty/concurrent_unordered_map.cuh"
#include "3rdparty/hash/helper_functions.cuh"
#include "allocators.cuh"
#include "coordinate_map.hpp"
#include "coordinate_map_functors.cuh"
#include "kernel_map.cuh"
#include "storage.cuh"
#include <torch/extension.h>
namespace minkowski {
/*
 * Coordinate map backed by a GPU concurrent hash table.
 *
 * Inherits the coordinate storage of CoordinateMap and adds a
 * concurrent_unordered_map keyed by coordinate<coordinate_type> with
 * index_type values. Most operations (insert, find, stride, kernel_map, ...)
 * are only declared here and are defined out of line.
 */
template <typename coordinate_type,
          template <typename T> class TemplatedAllocator =
              detail::c10_allocator>
class CoordinateMapGPU
    : public CoordinateMap<coordinate_type, TemplatedAllocator> {
public:
  // clang-format off
  using base_type           = CoordinateMap<coordinate_type, TemplatedAllocator>;
  using self_type           = CoordinateMapGPU<coordinate_type, TemplatedAllocator>;
  using size_type           = typename base_type::size_type;
  using index_type          = typename base_type::index_type;
  using stride_type         = typename base_type::stride_type;
  using map_allocator_type  = TemplatedAllocator<thrust::pair<coordinate<coordinate_type>, index_type>>;
  using byte_allocator_type = TemplatedAllocator<char>;

  // Map types
  using key_type       = coordinate<coordinate_type>;
  using mapped_type    = index_type;
  using hasher_type    = detail::coordinate_murmur3<coordinate_type>;
  using key_equal_type = detail::coordinate_equal_to<coordinate_type>;
  using map_type       = concurrent_unordered_map<key_type,            // key
                                                  mapped_type,         // mapped_type
                                                  hasher_type,         // hasher
                                                  key_equal_type,      // equality
                                                  map_allocator_type>; // allocator
  using value_type     = typename map_type::value_type;

  // return types
  using kernel_map_type = gpu_kernel_map<index_type, byte_allocator_type>;

  // iterator
  using iterator       = typename map_type::iterator;
  using const_iterator = typename map_type::const_iterator;

  // index vectors
  using index_storage_type      = gpu_storage<default_types::index_type, byte_allocator_type>;
  using coordinate_storage_type = gpu_storage<coordinate_type, byte_allocator_type>;
  // Legacy misspelled name, kept so existing users continue to compile.
  using coordnate_storage_type  = coordinate_storage_type;
  // clang-format on

  // NOTE: using the QueryResultAllocator for return types gives segfault!

public:
  CoordinateMapGPU() = delete;

  /*
   * @brief Allocate coordinate storage and a hash table for
   * `number_of_coordinates` entries.
   *
   * @param number_of_coordinates  expected number of coordinates; forwarded to
   *                               the base class and to reserve().
   * @param coordinate_size        number of coordinate_type values per
   *                               coordinate; also configures the hasher and
   *                               the equality functor.
   * @param hashtable_occupancy    passed with the requested size to
   *                               compute_hash_table_size() (presumably a
   *                               target occupancy in percent — TODO confirm).
   * @param stride                 tensor stride; also copied into
   *                               m_device_tensor_stride.
   * @param map_alloc              allocator used for the hash table.
   * @param byte_alloc             allocator used for coordinate storage.
   */
  CoordinateMapGPU(size_type const number_of_coordinates,
                   size_type const coordinate_size,
                   size_type const hashtable_occupancy = 50,
                   stride_type const stride = {1},
                   map_allocator_type map_alloc = map_allocator_type(),
                   byte_allocator_type byte_alloc = byte_allocator_type())
      : base_type(number_of_coordinates, coordinate_size, stride, byte_alloc),
        m_hashtable_occupancy{hashtable_occupancy},
        m_capacity(0), // updated by reserve()
        m_size(0),     // nothing inserted yet; size() must not read garbage
        m_hasher(hasher_type{coordinate_size}),
        m_equal(key_equal_type{coordinate_size}),
        m_unused_key(coordinate<coordinate_type>{nullptr}),
        m_unused_element(std::numeric_limits<coordinate_type>::max()),
        m_map_allocator(map_alloc) {
    // The kernel_map code reinterprets shared memory between these types, so
    // their sizes must agree.
    static_assert(
        sizeof(index_type) == sizeof(size_type),
        "kernel_map shared memory requires the type sizes to be the same");
    static_assert(
        sizeof(coordinate_type) == sizeof(size_type),
        "kernel_map shared memory requires the type sizes to be the same");

    // reserve coordinates and the hash table
    reserve(number_of_coordinates);

    // mirror the tensor stride into index storage for device-side use
    LOG_DEBUG("tensor_stride", base_type::m_tensor_stride);
    m_device_tensor_stride.from_vector(base_type::m_tensor_stride);
    LOG_DEBUG("device tensor_stride set");
  }

  template <bool remap>
  void insert(coordinate_iterator<coordinate_type> key_first,
              coordinate_iterator<coordinate_type> key_last);

  template <bool remap>
  std::pair<index_storage_type, index_storage_type>
  insert_and_map(coordinate_iterator<coordinate_type> key_first,
                 coordinate_iterator<coordinate_type> key_last);

  std::pair<index_storage_type, index_storage_type>
  find(coordinate_iterator<coordinate_type> key_first,
       coordinate_iterator<coordinate_type> key_last) const;

  /*
   * @brief Grow coordinate storage and the hash table to hold `size` entries.
   *
   * Does nothing when `size` does not exceed the current capacity. Otherwise
   * it does four things:
   * - reserves the base coordinate storage;
   * - creates a new hash table and replaces m_map with it;
   * - synchronizes the default stream;
   * - records the new capacity.
   * Entries from a previously created table are NOT copied into the new one.
   */
  inline void reserve(size_type size) {
    if (size <= m_capacity)
      return;

    // reserve coordinate
    base_type::reserve(size);

    // reserve map
    auto const table_size =
        compute_hash_table_size(size, m_hashtable_occupancy);
    LOG_DEBUG("Reserve map of", table_size,
              "for concurrent_unordered_map of size", size);
    m_map = map_type::create(table_size, m_unused_element, m_unused_key,
                             m_hasher, m_equal, m_map_allocator);
    LOG_DEBUG("Done concurrent_unordered_map creation");

    CUDA_TRY(cudaStreamSynchronize(0));
    m_capacity = size;
    LOG_DEBUG("Reserved concurrent_unordered_map");
  }

  // Network specific functions.

  /*
   * @brief strided coordinate map.
   */
  self_type stride(stride_type const &stride) const;
  self_type stride_region(cpu_kernel_region<coordinate_type> &kernel,
                          stride_type const &out_tensor_stride) const;
  self_type origin() const;
  self_type prune(bool const *keep_begin, bool const *keep_end) const;
  self_type
  merge(std::vector<std::reference_wrapper<self_type>> const &maps) const;

  kernel_map_type kernel_map(self_type const &out_coordinate_map,
                             gpu_kernel_region<coordinate_type> const &kernel,
                             CUDAKernelMapMode::Mode kernel_map_mode,
                             uint32_t thread_dim = CUDA_NUM_THREADS) const;
  kernel_map_type stride_map(self_type const &out_coordinate_map,
                             stride_type const &out_tensor_stride,
                             uint32_t thread_dim = CUDA_NUM_THREADS) const;
  kernel_map_type origin_map(self_type const &origin_coordinate_map,
                             uint32_t thread_dim = CUDA_NUM_THREADS) const;

  std::vector<at::Tensor>
  interpolation_map_weight(at::Tensor const &tfield) const;

  template <typename coordinate_field_type>
  std::pair<at::Tensor, at::Tensor>
  field_map(coordinate_field_type const *p_tfield,
            size_type const num_tfield) const;

  std::vector<at::Tensor>
  union_map(std::vector<std::reference_wrapper<self_type>> const &maps,
            uint32_t thread_dim = CUDA_NUM_THREADS) const;

  // Returns the number of elements in the coordinate map (0 until something
  // updates m_size).
  inline size_type size() const { return m_size; }

  void copy_coordinates(coordinate_type *dst_coordinate) const;

  std::string to_string() const {
    Formatter o;
    o << "CoordinateMapGPU:" << size() << "x" << m_coordinate_size;
    return o.str();
  }

  // Returns a copy of the hash-table object held by m_map.
  // NOTE(review): m_map is only created by reserve() with a positive size, so
  // after construction with 0 coordinates this dereferences a null pointer.
  inline map_type const const_hash_map() const { return *m_map; }

  // Insert indices given initialized coordinates
  void initialize_valid_indices(size_t const N_unique);

  // access the coordinate data pointer
  using base_type::const_coordinate_data;
  using base_type::coordinate_data;
  using base_type::coordinate_size;

  // Find GPU values in the map. key_iterator must be a GPU iterator.
  // template <typename key_iterator>
  // std::pair<device_index_vector_type, device_index_vector_type>
  // find(key_iterator key_first, key_iterator key_last);

private:
  using base_type::m_byte_allocator;
  using base_type::m_coordinate_size;

  // Passed to compute_hash_table_size() when the table is (re)created.
  size_type m_hashtable_occupancy;
  // Largest size passed to reserve() so far; 0 before the first reserve.
  size_type m_capacity;
  // Value returned by size(); initialized to 0 by the constructor.
  size_type m_size;
  hasher_type const m_hasher;
  key_equal_type const m_equal;
  // Sentinels handed to map_type::create for empty slots.
  key_type const m_unused_key;
  mapped_type const m_unused_element;
  index_storage_type m_valid_row_index;
  index_storage_type m_valid_map_index;
  index_storage_type m_inverse_row_index;
  // Copy of base_type::m_tensor_stride made in the constructor.
  index_storage_type m_device_tensor_stride;
  map_allocator_type m_map_allocator;
  // Null until reserve() is called with a positive size.
  std::shared_ptr<map_type> m_map;
};
/*
 * Coordinate field map on the GPU.
 *
 * Stores coordinates of type coordinate_field_type in the CoordinateMap
 * storage without a hash table of its own. coordinate_int_type is the
 * coordinate type of the CoordinateMapGPU produced by origin() and written by
 * quantize_coordinates() (presumably field = floating point, int = integer
 * coordinates — TODO confirm).
 */
template <typename coordinate_field_type, typename coordinate_int_type,
          template <typename T> class TemplatedAllocator =
              detail::c10_allocator>
class CoordinateFieldMapGPU
    : public CoordinateMap<coordinate_field_type, TemplatedAllocator> {
  // Coordinate wrapper
public:
  using base_type = CoordinateMap<coordinate_field_type, TemplatedAllocator>;
  using coordinate_map_type =
      CoordinateMapGPU<coordinate_int_type, TemplatedAllocator>;
  using self_type =
      CoordinateFieldMapGPU<coordinate_field_type, coordinate_int_type,
                            TemplatedAllocator>;
  using size_type = typename base_type::size_type;
  using index_type = typename base_type::index_type;
  using stride_type = typename base_type::stride_type;
  using byte_allocator_type = TemplatedAllocator<char>;
  using map_allocator_type = TemplatedAllocator<
      thrust::pair<coordinate<coordinate_int_type>, index_type>>;
  using kernel_map_type = gpu_kernel_map<index_type, byte_allocator_type>;

  // Map types
  using key_type = coordinate<coordinate_int_type>;
  using mapped_type = index_type;
  using hasher_type = detail::coordinate_murmur3<coordinate_int_type>;
  using key_equal_type = detail::coordinate_equal_to<coordinate_int_type>;
  using int_hash_map_type =
      concurrent_unordered_map<key_type,            // key
                               mapped_type,         // mapped_type
                               hasher_type,         // hasher
                               key_equal_type,      // equality
                               map_allocator_type>; // allocator

public:
  CoordinateFieldMapGPU() = delete;

  /*
   * @brief Reserve storage for `number_of_coordinates` coordinates of
   * `coordinate_size` values each; size() starts at number_of_coordinates.
   */
  CoordinateFieldMapGPU(size_type const number_of_coordinates,
                        size_type const coordinate_size,
                        stride_type const &stride = {1},
                        byte_allocator_type alloc = byte_allocator_type())
      : base_type(number_of_coordinates, coordinate_size, stride, alloc),
        m_size(number_of_coordinates) {
    base_type::reserve(number_of_coordinates);
  }

  /*
   * @brief Copy the flat coordinate range [coordinate_begin, coordinate_end)
   * into this map's storage.
   *
   * Both pointers must refer to device memory (the copy uses
   * cudaMemcpyDeviceToDevice). The number of coordinates is
   * (coordinate_end - coordinate_begin) / coordinate size; a trailing partial
   * coordinate is dropped by the integer division. After a successful copy,
   * size() reports that number of coordinates.
   *
   * @return none
   */
  void insert(coordinate_field_type const *coordinate_begin,
              coordinate_field_type const *coordinate_end) {
    size_type const N =
        (coordinate_end - coordinate_begin) / m_coordinate_size;
    base_type::allocate(N);
    // copy data directly to the ptr
    CUDA_CHECK(cudaMemcpy(base_type::coordinate_data(), coordinate_begin,
                          N * m_coordinate_size * sizeof(coordinate_field_type),
                          cudaMemcpyDeviceToDevice));
    // Keep size() in sync with the stored data; copy_coordinates() and
    // to_string() rely on it.
    m_size = N;
  }

  /*
   * @brief Copy size() coordinates into `dst_coordinate`, which must be a
   * device pointer with room for size() * coordinate-size values.
   */
  void copy_coordinates(coordinate_field_type *dst_coordinate) const {
    CUDA_CHECK(
        cudaMemcpy(dst_coordinate, base_type::const_coordinate_data(),
                   size() * m_coordinate_size * sizeof(coordinate_field_type),
                   cudaMemcpyDeviceToDevice));
  }

  void quantize_coordinates(coordinate_int_type *p_dst_coordinates,
                            stride_type const &tensor_stride) const;

  using base_type::const_coordinate_data;
  using base_type::coordinate_data;

  coordinate_map_type origin() const;
  kernel_map_type origin_map(coordinate_map_type const &origin_coordinate_map,
                             uint32_t thread_dim = CUDA_NUM_THREADS) const;

  inline size_type size() const noexcept { return m_size; }

  std::string to_string() const {
    Formatter o;
    o << "CoordinateFieldMapGPU:" << size() << "x" << m_coordinate_size;
    return o.str();
  }

private:
  using base_type::m_byte_allocator;
  using base_type::m_coordinate_size;

  // Number of coordinates currently stored: set by the constructor and
  // updated by insert().
  size_type m_size;
};
} // namespace minkowski
#endif // COORDINATE_MAP_GPU_CUH
|