Spaces:
Build error
Build error
| /* | |
| * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| * SPDX-License-Identifier: Apache-2.0 | |
| * | |
| * Licensed under the Apache License, Version 2.0 (the "License"); | |
| * you may not use this file except in compliance with the License. | |
| * You may obtain a copy of the License at | |
| * | |
| * http://www.apache.org/licenses/LICENSE-2.0 | |
| * | |
| * Unless required by applicable law or agreed to in writing, software | |
| * distributed under the License is distributed on an "AS IS" BASIS, | |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| * See the License for the specific language governing permissions and | |
| * limitations under the License. | |
| */ | |
| /** @file common_host.cu | |
| * @author Thomas Müller and Nikolaus Binder, NVIDIA | |
| * @brief Common utilities that are needed by pretty much every component of this framework. | |
| */ | |
| namespace tcnn { | |
| static_assert( | |
| __CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2), | |
| "tiny-cuda-nn requires at least CUDA 10.2" | |
| ); | |
| std::function<void(LogSeverity, const std::string&)> g_log_callback = [](LogSeverity severity, const std::string& msg) { | |
| switch (severity) { | |
| case LogSeverity::Warning: std::cerr << fmt::format("tiny-cuda-nn warning: {}\n", msg); break; | |
| case LogSeverity::Error: std::cerr << fmt::format("tiny-cuda-nn error: {}\n", msg); break; | |
| default: break; | |
| } | |
| if (verbose()) { | |
| switch (severity) { | |
| case LogSeverity::Debug: std::cerr << fmt::format("tiny-cuda-nn debug: {}\n", msg); break; | |
| case LogSeverity::Info: std::cerr << fmt::format("tiny-cuda-nn info: {}\n", msg); break; | |
| case LogSeverity::Success: std::cerr << fmt::format("tiny-cuda-nn success: {}\n", msg); break; | |
| default: break; | |
| } | |
| } | |
| }; | |
| const std::function<void(LogSeverity, const std::string&)>& log_callback() { return g_log_callback; } | |
| void set_log_callback(const std::function<void(LogSeverity, const std::string&)>& cb) { g_log_callback = cb; } | |
// Global verbosity switch consulted by the default log sink; disabled by default.
bool g_verbose{false};

// Returns whether verbose logging is currently enabled.
bool verbose() {
	const bool enabled = g_verbose;
	return enabled;
}

// Enables or disables verbose logging.
void set_verbose(bool verbose) {
	g_verbose = verbose;
}
| Activation string_to_activation(const std::string& activation_name) { | |
| if (equals_case_insensitive(activation_name, "None")) { | |
| return Activation::None; | |
| } else if (equals_case_insensitive(activation_name, "ReLU")) { | |
| return Activation::ReLU; | |
| } else if (equals_case_insensitive(activation_name, "LeakyReLU")) { | |
| return Activation::LeakyReLU; | |
| } else if (equals_case_insensitive(activation_name, "Exponential")) { | |
| return Activation::Exponential; | |
| } else if (equals_case_insensitive(activation_name, "Sigmoid")) { | |
| return Activation::Sigmoid; | |
| } else if (equals_case_insensitive(activation_name, "Sine")) { | |
| return Activation::Sine; | |
| } else if (equals_case_insensitive(activation_name, "Squareplus")) { | |
| return Activation::Squareplus; | |
| } else if (equals_case_insensitive(activation_name, "Softplus")) { | |
| return Activation::Softplus; | |
| } else if (equals_case_insensitive(activation_name, "Tanh")) { | |
| return Activation::Tanh; | |
| } | |
| throw std::runtime_error{fmt::format("Invalid activation name: {}", activation_name)}; | |
| } | |
| std::string to_string(Activation activation) { | |
| switch (activation) { | |
| case Activation::None: return "None"; | |
| case Activation::ReLU: return "ReLU"; | |
| case Activation::LeakyReLU: return "LeakyReLU"; | |
| case Activation::Exponential: return "Exponential"; | |
| case Activation::Sigmoid: return "Sigmoid"; | |
| case Activation::Sine: return "Sine"; | |
| case Activation::Squareplus: return "Squareplus"; | |
| case Activation::Softplus: return "Softplus"; | |
| case Activation::Tanh: return "Tanh"; | |
| default: throw std::runtime_error{"Invalid activation."}; | |
| } | |
| } | |
| GridType string_to_grid_type(const std::string& grid_type) { | |
| if (equals_case_insensitive(grid_type, "Hash")) { | |
| return GridType::Hash; | |
| } else if (equals_case_insensitive(grid_type, "Dense")) { | |
| return GridType::Dense; | |
| } else if (equals_case_insensitive(grid_type, "Tiled") || equals_case_insensitive(grid_type, "Tile")) { | |
| return GridType::Tiled; | |
| } | |
| throw std::runtime_error{fmt::format("Invalid grid type: {}", grid_type)}; | |
| } | |
| std::string to_string(GridType grid_type) { | |
| switch (grid_type) { | |
| case GridType::Hash: return "Hash"; | |
| case GridType::Dense: return "Dense"; | |
| case GridType::Tiled: return "Tiled"; | |
| default: throw std::runtime_error{"Invalid grid type."}; | |
| } | |
| } | |
| HashType string_to_hash_type(const std::string& hash_type) { | |
| if (equals_case_insensitive(hash_type, "Prime")) { | |
| return HashType::Prime; | |
| } else if (equals_case_insensitive(hash_type, "CoherentPrime")) { | |
| return HashType::CoherentPrime; | |
| } else if (equals_case_insensitive(hash_type, "ReversedPrime")) { | |
| return HashType::ReversedPrime; | |
| } else if (equals_case_insensitive(hash_type, "Rng")) { | |
| return HashType::Rng; | |
| } else if (equals_case_insensitive(hash_type, "BaseConvert")) { | |
| return HashType::BaseConvert; | |
| } | |
| throw std::runtime_error{fmt::format("Invalid hash type: {}", hash_type)}; | |
| } | |
| std::string to_string(HashType hash_type) { | |
| switch (hash_type) { | |
| case HashType::Prime: return "Prime"; | |
| case HashType::CoherentPrime: return "CoherentPrime"; | |
| case HashType::ReversedPrime: return "ReversedPrime"; | |
| case HashType::Rng: return "Rng"; | |
| case HashType::BaseConvert: return "BaseConvert"; | |
| default: throw std::runtime_error{"Invalid hash type."}; | |
| } | |
| } | |
| InterpolationType string_to_interpolation_type(const std::string& interpolation_type) { | |
| if (equals_case_insensitive(interpolation_type, "Nearest")) { | |
| return InterpolationType::Nearest; | |
| } else if (equals_case_insensitive(interpolation_type, "Linear")) { | |
| return InterpolationType::Linear; | |
| } else if (equals_case_insensitive(interpolation_type, "Smoothstep")) { | |
| return InterpolationType::Smoothstep; | |
| } | |
| throw std::runtime_error{fmt::format("Invalid interpolation type: {}", interpolation_type)}; | |
| } | |
| std::string to_string(InterpolationType interpolation_type) { | |
| switch (interpolation_type) { | |
| case InterpolationType::Nearest: return "Nearest"; | |
| case InterpolationType::Linear: return "Linear"; | |
| case InterpolationType::Smoothstep: return "Smoothstep"; | |
| default: throw std::runtime_error{"Invalid interpolation type."}; | |
| } | |
| } | |
| ReductionType string_to_reduction_type(const std::string& reduction_type) { | |
| if (equals_case_insensitive(reduction_type, "Concatenation")) { | |
| return ReductionType::Concatenation; | |
| } else if (equals_case_insensitive(reduction_type, "Sum")) { | |
| return ReductionType::Sum; | |
| } else if (equals_case_insensitive(reduction_type, "Product")) { | |
| return ReductionType::Product; | |
| } | |
| throw std::runtime_error{fmt::format("Invalid reduction type: {}", reduction_type)}; | |
| } | |
| std::string to_string(ReductionType reduction_type) { | |
| switch (reduction_type) { | |
| case ReductionType::Concatenation: return "Concatenation"; | |
| case ReductionType::Sum: return "Sum"; | |
| case ReductionType::Product: return "Product"; | |
| default: throw std::runtime_error{"Invalid reduction type."}; | |
| } | |
| } | |
| int cuda_runtime_version() { | |
| int version; | |
| CUDA_CHECK_THROW(cudaRuntimeGetVersion(&version)); | |
| return version; | |
| } | |
| int cuda_device() { | |
| int device; | |
| CUDA_CHECK_THROW(cudaGetDevice(&device)); | |
| return device; | |
| } | |
| void set_cuda_device(int device) { | |
| CUDA_CHECK_THROW(cudaSetDevice(device)); | |
| } | |
| int cuda_device_count() { | |
| int device_count; | |
| CUDA_CHECK_THROW(cudaGetDeviceCount(&device_count)); | |
| return device_count; | |
| } | |
| bool cuda_supports_virtual_memory(int device) { | |
| int supports_vmm; | |
| CU_CHECK_THROW(cuDeviceGetAttribute(&supports_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED, device)); | |
| return supports_vmm != 0; | |
| } | |
| std::unordered_map<int, cudaDeviceProp>& cuda_device_properties() { | |
| static auto* cuda_device_props = new std::unordered_map<int, cudaDeviceProp>{}; | |
| return *cuda_device_props; | |
| } | |
| const cudaDeviceProp& cuda_get_device_properties(int device) { | |
| if (cuda_device_properties().count(device) == 0) { | |
| auto& props = cuda_device_properties()[device]; | |
| CUDA_CHECK_THROW(cudaGetDeviceProperties(&props, device)); | |
| } | |
| return cuda_device_properties().at(device); | |
| } | |
| std::string cuda_device_name(int device) { | |
| return cuda_get_device_properties(device).name; | |
| } | |
| uint32_t cuda_compute_capability(int device) { | |
| const auto& props = cuda_get_device_properties(device); | |
| return props.major * 10 + props.minor; | |
| } | |
| uint32_t cuda_max_supported_compute_capability() { | |
| int cuda_version = cuda_runtime_version(); | |
| if (cuda_version < 11000) { | |
| return 75; | |
| } else if (cuda_version < 11010) { | |
| return 80; | |
| } else if (cuda_version < 11080) { | |
| return 86; | |
| } else { | |
| return 90; | |
| } | |
| } | |
| uint32_t cuda_supported_compute_capability(int device) { | |
| return std::min(cuda_compute_capability(device), cuda_max_supported_compute_capability()); | |
| } | |
| size_t cuda_max_shmem(int device) { | |
| return cuda_get_device_properties(device).sharedMemPerBlockOptin; | |
| } | |
| uint32_t cuda_max_registers(int device) { | |
| return (uint32_t)cuda_get_device_properties(device).regsPerBlock; | |
| } | |
| size_t cuda_memory_granularity(int device) { | |
| size_t granularity; | |
| CUmemAllocationProp prop = {}; | |
| prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; | |
| prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; | |
| prop.location.id = 0; | |
| CUresult granularity_result = cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM); | |
| if (granularity_result == CUDA_ERROR_NOT_SUPPORTED) { | |
| return 1; | |
| } | |
| CU_CHECK_THROW(granularity_result); | |
| return granularity; | |
| } | |
| MemoryInfo cuda_memory_info() { | |
| MemoryInfo info; | |
| CUDA_CHECK_THROW(cudaMemGetInfo(&info.free, &info.total)); | |
| info.used = info.total - info.free; | |
| return info; | |
| } | |
// Returns a source-code preamble that includes <tiny-cuda-nn/common_device.h> and
// <tiny-cuda-nn/mma.h> and then declares `using namespace tcnn;`. Presumably it is
// prepended to device code that is generated and compiled at runtime — TODO confirm
// against callers.
// NOTE(review): the raw string goes through dfmt with indent 0. Presumably dfmt trims or
// dedents the literal's surrounding whitespace — confirm before reformatting the literal,
// since its exact whitespace is part of the returned text.
std::string generate_device_code_preamble() {
	return dfmt(0, R"(
#include <tiny-cuda-nn/common_device.h>
#include <tiny-cuda-nn/mma.h>
using namespace tcnn;
)");
}
// Converts a CamelCase / camelCase identifier to snake_case.
// The first character is lower-cased. Every later upper-case character becomes '_'
// followed by its lower-case form, e.g. "HashGrid" -> "hash_grid". Other characters are
// copied unchanged.
//
// Fixes relative to the previous version:
//  * an empty input now yields "" (it used to read str[0], the terminating NUL, and
//    return a one-character "\0" string);
//  * characters are passed to <cctype> as unsigned char, because passing a negative
//    plain char (e.g. a UTF-8 byte) is undefined behavior. This matches to_lower/to_upper.
std::string to_snake_case(const std::string& str) {
	std::string result;
	result.reserve(str.size());

	for (size_t i = 0; i < str.size(); ++i) {
		const unsigned char c = static_cast<unsigned char>(str[i]);
		if (i == 0) {
			result += static_cast<char>(std::tolower(c));
		} else if (std::isupper(c)) {
			result += '_';
			result += static_cast<char>(std::tolower(c));
		} else {
			result += str[i];
		}
	}

	return result;
}
// Splits `text` at every occurrence of ANY character contained in `delim`. `delim` is a
// set of single-character separators, not a multi-character substring.
// Adjacent separators produce empty pieces, and the result always contains at least one
// element: an empty `text`, or an empty `delim`, yields { text }.
std::vector<std::string> split(const std::string& text, const std::string& delim) {
	std::vector<std::string> pieces;
	size_t piece_start = 0;

	size_t hit = text.find_first_of(delim);
	while (hit != std::string::npos) {
		pieces.emplace_back(text, piece_start, hit - piece_start);
		piece_start = hit + 1;
		hit = text.find_first_of(delim, piece_start);
	}

	// Whatever follows the last separator (possibly nothing) forms the final piece.
	pieces.emplace_back(text, piece_start);
	return pieces;
}
// Returns a copy of `str` with each character mapped through std::tolower (current C
// locale). Characters are converted via unsigned char to keep <cctype> calls well-defined.
std::string to_lower(std::string str) {
	for (char& ch : str) {
		ch = static_cast<char>(std::tolower(static_cast<unsigned char>(ch)));
	}
	return str;
}
// Returns a copy of `str` with each character mapped through std::toupper (current C
// locale). Characters are converted via unsigned char to keep <cctype> calls well-defined.
std::string to_upper(std::string str) {
	for (char& ch : str) {
		ch = static_cast<char>(std::toupper(static_cast<unsigned char>(ch)));
	}
	return str;
}
| template <> std::string type_to_string<bool>() { return "bool"; } | |
| template <> std::string type_to_string<int>() { return "int"; } | |
| template <> std::string type_to_string<char>() { return "char"; } | |
| template <> std::string type_to_string<uint8_t>() { return "uint8_t"; } | |
| template <> std::string type_to_string<uint16_t>() { return "uint16_t"; } | |
| template <> std::string type_to_string<uint32_t>() { return "uint32_t"; } | |
| template <> std::string type_to_string<double>() { return "double"; } | |
| template <> std::string type_to_string<float>() { return "float"; } | |
| template <> std::string type_to_string<__half>() { return "__half"; } | |
| } | |