Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_graph_sycl.h +99 -0
- phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_graph_sycl.hpp +131 -0
- phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_graph_types.h +475 -0
- phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_ocl.h +276 -0
- phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_ocl.hpp +445 -0
- phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_ocl_types.h +51 -0
- phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_sycl.h +199 -0
- phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_sycl.hpp +384 -0
- phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_sycl_types.h +51 -0
- phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_threadpool.h +118 -0
- phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_threadpool.hpp +113 -0
- phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_threadpool_iface.hpp +73 -0
- phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_types.h +0 -0
- phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_ukernel.h +337 -0
- phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_ukernel.hpp +465 -0
- phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_ukernel_types.h +93 -0
- phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_version.h +33 -0
- phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_version_hash.h +31 -0
- phivenv/Lib/site-packages/torch/include/pybind11/attr.h +690 -0
- phivenv/Lib/site-packages/torch/include/pybind11/buffer_info.h +208 -0
- phivenv/Lib/site-packages/torch/include/pybind11/cast.h +1855 -0
- phivenv/Lib/site-packages/torch/include/pybind11/chrono.h +225 -0
- phivenv/Lib/site-packages/torch/include/pybind11/common.h +2 -0
- phivenv/Lib/site-packages/torch/include/pybind11/complex.h +74 -0
- phivenv/Lib/site-packages/torch/include/pybind11/detail/class.h +767 -0
- phivenv/Lib/site-packages/torch/include/pybind11/detail/common.h +1287 -0
- phivenv/Lib/site-packages/torch/include/pybind11/detail/cpp_conduit.h +77 -0
- phivenv/Lib/site-packages/torch/include/pybind11/detail/descr.h +172 -0
- phivenv/Lib/site-packages/torch/include/pybind11/detail/exception_translation.h +71 -0
- phivenv/Lib/site-packages/torch/include/pybind11/detail/init.h +436 -0
- phivenv/Lib/site-packages/torch/include/pybind11/detail/internals.h +766 -0
- phivenv/Lib/site-packages/torch/include/pybind11/detail/type_caster_base.h +1195 -0
- phivenv/Lib/site-packages/torch/include/pybind11/detail/typeid.h +65 -0
- phivenv/Lib/site-packages/torch/include/pybind11/detail/value_and_holder.h +77 -0
- phivenv/Lib/site-packages/torch/include/pybind11/eigen.h +12 -0
- phivenv/Lib/site-packages/torch/include/pybind11/eigen/common.h +9 -0
- phivenv/Lib/site-packages/torch/include/pybind11/eigen/matrix.h +715 -0
- phivenv/Lib/site-packages/torch/include/pybind11/eigen/tensor.h +515 -0
- phivenv/Lib/site-packages/torch/include/pybind11/embed.h +313 -0
- phivenv/Lib/site-packages/torch/include/pybind11/eval.h +156 -0
- phivenv/Lib/site-packages/torch/include/pybind11/functional.h +149 -0
- phivenv/Lib/site-packages/torch/include/pybind11/gil.h +219 -0
- phivenv/Lib/site-packages/torch/include/pybind11/gil_safe_call_once.h +100 -0
- phivenv/Lib/site-packages/torch/include/pybind11/iostream.h +265 -0
- phivenv/Lib/site-packages/torch/include/pybind11/numpy.h +2139 -0
- phivenv/Lib/site-packages/torch/include/pybind11/operators.h +202 -0
- phivenv/Lib/site-packages/torch/include/pybind11/options.h +92 -0
- phivenv/Lib/site-packages/torch/include/pybind11/pybind11.h +0 -0
- phivenv/Lib/site-packages/torch/include/pybind11/pytypes.h +0 -0
- phivenv/Lib/site-packages/torch/include/pybind11/stl.h +448 -0
phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_graph_sycl.h
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*******************************************************************************
|
| 2 |
+
* Copyright 2020-2024 Intel Corporation
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*******************************************************************************/
|
| 16 |
+
|
| 17 |
+
#ifndef ONEAPI_DNNL_DNNL_GRAPH_SYCL_H
|
| 18 |
+
#define ONEAPI_DNNL_DNNL_GRAPH_SYCL_H
|
| 19 |
+
|
| 20 |
+
#include "oneapi/dnnl/dnnl_graph.h"
|
| 21 |
+
|
| 22 |
+
#ifdef __cplusplus
|
| 23 |
+
extern "C" {
|
| 24 |
+
#endif
|
| 25 |
+
|
| 26 |
+
/// @addtogroup dnnl_api
|
| 27 |
+
/// @{
|
| 28 |
+
|
| 29 |
+
/// @addtogroup dnnl_graph_api
|
| 30 |
+
/// @{
|
| 31 |
+
|
| 32 |
+
/// @addtogroup dnnl_graph_api_interop
|
| 33 |
+
/// @{
|
| 34 |
+
|
| 35 |
+
/// @addtogroup dnnl_graph_api_sycl_interop
|
| 36 |
+
/// @{
|
| 37 |
+
|
| 38 |
+
/// Allocation call-back function interface for SYCL. SYCL allocator should be
|
| 39 |
+
/// used for SYCL runtime and host allocator should be used for non-SYCL. The
|
| 40 |
+
/// call-back should return a USM device memory pointer.
|
| 41 |
+
typedef void *(*dnnl_graph_sycl_allocate_f)(
|
| 42 |
+
size_t size, size_t alignment, const void *dev, const void *context);
|
| 43 |
+
|
| 44 |
+
/// Deallocation call-back function interface for SYCL. SYCL allocator should be
|
| 45 |
+
/// used for SYCL runtime and host allocator should be used for non-SYCL. The
|
| 46 |
+
/// call-back should deallocate a USM device memory returned by
|
| 47 |
+
/// #dnnl_graph_sycl_allocate_f.
|
| 48 |
+
typedef void (*dnnl_graph_sycl_deallocate_f)(
|
| 49 |
+
void *buf, const void *dev, const void *context, void *event);
|
| 50 |
+
|
| 51 |
+
/// Creates an allocator with the given allocation and deallocation call-back
|
| 52 |
+
/// function pointers.
|
| 53 |
+
///
|
| 54 |
+
/// @param allocator Output allocator
|
| 55 |
+
/// @param sycl_malloc A pointer to SYCL malloc function
|
| 56 |
+
/// @param sycl_free A pointer to SYCL free function
|
| 57 |
+
/// @returns #dnnl_success on success and a status describing the
|
| 58 |
+
/// error otherwise.
|
| 59 |
+
dnnl_status_t DNNL_API dnnl_graph_sycl_interop_allocator_create(
|
| 60 |
+
dnnl_graph_allocator_t *allocator,
|
| 61 |
+
dnnl_graph_sycl_allocate_f sycl_malloc,
|
| 62 |
+
dnnl_graph_sycl_deallocate_f sycl_free);
|
| 63 |
+
|
| 64 |
+
/// This API is a supplement for existing onednn engine API.
|
| 65 |
+
dnnl_status_t DNNL_API dnnl_graph_sycl_interop_make_engine_with_allocator(
|
| 66 |
+
dnnl_engine_t *engine, const void *device, const void *context,
|
| 67 |
+
const_dnnl_graph_allocator_t alloc);
|
| 68 |
+
|
| 69 |
+
/// Execute a compiled partition with sycl runtime.
|
| 70 |
+
///
|
| 71 |
+
/// @param compiled_partition The handle of target compiled_partition.
|
| 72 |
+
/// @param stream The stream used for execution
|
| 73 |
+
/// @param num_inputs The number of input tensors
|
| 74 |
+
/// @param inputs A list of input tensors
|
| 75 |
+
/// @param num_outputs The number of output tensors
|
| 76 |
+
/// @param outputs A non-empty list of output tensors
|
| 77 |
+
/// @param deps Optional handle of list with `sycl::event` dependencies.
|
| 78 |
+
/// @param sycl_event The handle of sycl event.
|
| 79 |
+
/// @returns #dnnl_success on success and a status describing the
|
| 80 |
+
/// error otherwise.
|
| 81 |
+
dnnl_status_t DNNL_API dnnl_graph_sycl_interop_compiled_partition_execute(
|
| 82 |
+
const_dnnl_graph_compiled_partition_t compiled_partition,
|
| 83 |
+
dnnl_stream_t stream, size_t num_inputs,
|
| 84 |
+
const_dnnl_graph_tensor_t *inputs, size_t num_outputs,
|
| 85 |
+
const_dnnl_graph_tensor_t *outputs, const void *deps, void *sycl_event);
|
| 86 |
+
|
| 87 |
+
/// @} dnnl_graph_api_sycl_interop
|
| 88 |
+
|
| 89 |
+
/// @} dnnl_graph_api_interop
|
| 90 |
+
|
| 91 |
+
/// @} dnnl_graph_api
|
| 92 |
+
|
| 93 |
+
/// @} dnnl_api
|
| 94 |
+
|
| 95 |
+
#ifdef __cplusplus
|
| 96 |
+
}
|
| 97 |
+
#endif
|
| 98 |
+
|
| 99 |
+
#endif
|
phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_graph_sycl.hpp
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*******************************************************************************
|
| 2 |
+
* Copyright 2020-2025 Intel Corporation
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*******************************************************************************/
|
| 16 |
+
|
| 17 |
+
/// @file
|
| 18 |
+
/// Graph SYCL interop API
|
| 19 |
+
|
| 20 |
+
#ifndef ONEAPI_DNNL_DNNL_GRAPH_SYCL_HPP
|
| 21 |
+
#define ONEAPI_DNNL_DNNL_GRAPH_SYCL_HPP
|
| 22 |
+
|
| 23 |
+
/// @cond DO_NOT_DOCUMENT_THIS
|
| 24 |
+
#include <vector>
|
| 25 |
+
|
| 26 |
+
#if __has_include(<sycl/sycl.hpp>)
|
| 27 |
+
#include <sycl/sycl.hpp>
|
| 28 |
+
#else
|
| 29 |
+
#error "Unsupported compiler"
|
| 30 |
+
#endif
|
| 31 |
+
|
| 32 |
+
#include "oneapi/dnnl/dnnl_graph.hpp"
|
| 33 |
+
#include "oneapi/dnnl/dnnl_graph_sycl.h"
|
| 34 |
+
/// @endcond
|
| 35 |
+
|
| 36 |
+
/// @addtogroup dnnl_api
|
| 37 |
+
/// @{
|
| 38 |
+
|
| 39 |
+
namespace dnnl {
|
| 40 |
+
|
| 41 |
+
/// @addtogroup dnnl_graph_api
|
| 42 |
+
/// @{
|
| 43 |
+
|
| 44 |
+
namespace graph {
|
| 45 |
+
|
| 46 |
+
/// @addtogroup dnnl_graph_api_interop Runtime interoperability API
|
| 47 |
+
/// API extensions to interact with the underlying run-time.
|
| 48 |
+
/// @{
|
| 49 |
+
|
| 50 |
+
/// @addtogroup dnnl_graph_api_sycl_interop SYCL interoperability API
|
| 51 |
+
/// API extensions to interact with the underlying SYCL run-time.
|
| 52 |
+
/// @{
|
| 53 |
+
|
| 54 |
+
/// SYCL interoperability namespace
|
| 55 |
+
namespace sycl_interop {
|
| 56 |
+
|
| 57 |
+
/// Constructs an allocator from SYCL malloc and free function pointer. SYCL
|
| 58 |
+
/// allocator should be used for SYCL runtime and host allocator should be used
|
| 59 |
+
/// for non-SYCL. Currently, only device USM allocator is supported.
|
| 60 |
+
///
|
| 61 |
+
/// @param sycl_malloc The pointer to SYCL malloc function
|
| 62 |
+
/// @param sycl_free The pointer to SYCL free function
|
| 63 |
+
/// @returns Created allocator
|
| 64 |
+
inline allocator make_allocator(dnnl_graph_sycl_allocate_f sycl_malloc,
|
| 65 |
+
dnnl_graph_sycl_deallocate_f sycl_free) {
|
| 66 |
+
dnnl_graph_allocator_t c_allocator = nullptr;
|
| 67 |
+
error::wrap_c_api(dnnl_graph_sycl_interop_allocator_create(
|
| 68 |
+
&c_allocator, sycl_malloc, sycl_free),
|
| 69 |
+
"could not create allocator for sycl device");
|
| 70 |
+
return allocator(c_allocator);
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
inline engine make_engine_with_allocator(const sycl::device &adevice,
|
| 74 |
+
const sycl::context &acontext, const allocator &alloc) {
|
| 75 |
+
dnnl_engine_t c_engine;
|
| 76 |
+
error::wrap_c_api(
|
| 77 |
+
dnnl_graph_sycl_interop_make_engine_with_allocator(&c_engine,
|
| 78 |
+
static_cast<const void *>(&adevice),
|
| 79 |
+
static_cast<const void *>(&acontext), alloc.get()),
|
| 80 |
+
"could not make an engine with allocator");
|
| 81 |
+
return engine(c_engine);
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
/// Executes a compiled partition in a specified stream and returns a SYCL
|
| 85 |
+
/// event.
|
| 86 |
+
///
|
| 87 |
+
/// @param c_partition Compiled partition to execute.
|
| 88 |
+
/// @param astream Stream object to run over
|
| 89 |
+
/// @param inputs Arguments map.
|
| 90 |
+
/// @param outputs Arguments map.
|
| 91 |
+
/// @param deps Optional vector with `sycl::event` dependencies.
|
| 92 |
+
/// @returns Output event.
|
| 93 |
+
inline sycl::event execute(compiled_partition &c_partition, stream &astream,
|
| 94 |
+
const std::vector<tensor> &inputs, std::vector<tensor> &outputs,
|
| 95 |
+
const std::vector<sycl::event> &deps = {}) {
|
| 96 |
+
std::vector<const_dnnl_graph_tensor_t> c_inputs;
|
| 97 |
+
c_inputs.reserve(inputs.size());
|
| 98 |
+
for (auto &in : inputs) {
|
| 99 |
+
c_inputs.push_back(in.get());
|
| 100 |
+
}
|
| 101 |
+
std::vector<const_dnnl_graph_tensor_t> c_outputs;
|
| 102 |
+
c_outputs.reserve(outputs.size());
|
| 103 |
+
for (auto &out : outputs) {
|
| 104 |
+
c_outputs.push_back(out.get());
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
sycl::event sycl_event;
|
| 108 |
+
error::wrap_c_api(dnnl_graph_sycl_interop_compiled_partition_execute(
|
| 109 |
+
c_partition.get(), astream.get(), c_inputs.size(),
|
| 110 |
+
c_inputs.data(), c_outputs.size(),
|
| 111 |
+
c_outputs.data(), &deps, &sycl_event),
|
| 112 |
+
"could not execute the compiled_partition on a specified sycl "
|
| 113 |
+
"stream");
|
| 114 |
+
return sycl_event;
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
} // namespace sycl_interop
|
| 118 |
+
|
| 119 |
+
/// @} dnnl_graph_api_sycl_interop
|
| 120 |
+
|
| 121 |
+
/// @} dnnl_graph_api_interop
|
| 122 |
+
|
| 123 |
+
} // namespace graph
|
| 124 |
+
|
| 125 |
+
/// @} dnnl_graph_api
|
| 126 |
+
|
| 127 |
+
} // namespace dnnl
|
| 128 |
+
|
| 129 |
+
/// @} dnnl_api
|
| 130 |
+
|
| 131 |
+
#endif
|
phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_graph_types.h
ADDED
|
@@ -0,0 +1,475 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*******************************************************************************
|
| 2 |
+
* Copyright 2020-2025 Intel Corporation
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*******************************************************************************/
|
| 16 |
+
|
| 17 |
+
/// @file
|
| 18 |
+
/// C API definitions
|
| 19 |
+
|
| 20 |
+
#ifndef ONEAPI_DNNL_DNNL_GRAPH_TYPES_H
|
| 21 |
+
#define ONEAPI_DNNL_DNNL_GRAPH_TYPES_H
|
| 22 |
+
|
| 23 |
+
#ifdef __cplusplus
|
| 24 |
+
extern "C" {
|
| 25 |
+
#endif
|
| 26 |
+
|
| 27 |
+
/// @cond DO_NOT_DOCUMENT_THIS
|
| 28 |
+
#include <limits.h>
|
| 29 |
+
#include <stddef.h>
|
| 30 |
+
|
| 31 |
+
#include "oneapi/dnnl/dnnl_common_types.h"
|
| 32 |
+
/// @endcond
|
| 33 |
+
|
| 34 |
+
/// @addtogroup dnnl_api
|
| 35 |
+
/// @{
|
| 36 |
+
|
| 37 |
+
/// @addtogroup dnnl_graph_api
|
| 38 |
+
/// @{
|
| 39 |
+
|
| 40 |
+
/// @addtogroup dnnl_graph_api_logical_tensor
|
| 41 |
+
/// @{
|
| 42 |
+
|
| 43 |
+
/// A wildcard value for number of dimensions which is unknown at a tensor or
|
| 44 |
+
/// operation creation time.
|
| 45 |
+
#define DNNL_GRAPH_UNKNOWN_NDIMS -1
|
| 46 |
+
|
| 47 |
+
/// A wildcard value for dimensions that are unknown at a tensor or operation
|
| 48 |
+
/// creation time.
|
| 49 |
+
#define DNNL_GRAPH_UNKNOWN_DIM INT64_MIN
|
| 50 |
+
|
| 51 |
+
/// Layout type specification
|
| 52 |
+
typedef enum {
|
| 53 |
+
/// Undefined layout type
|
| 54 |
+
dnnl_graph_layout_type_undef = 0,
|
| 55 |
+
/// Any means to let the library to decide the layout for a tensor during
|
| 56 |
+
/// partition compilation.
|
| 57 |
+
dnnl_graph_layout_type_any = 1,
|
| 58 |
+
/// Strided means that the layout of a tensor is determined by the strides
|
| 59 |
+
/// field in the logical tensor.
|
| 60 |
+
dnnl_graph_layout_type_strided = 2,
|
| 61 |
+
/// Opaque means that the layout of a tensor is the library specific.
|
| 62 |
+
/// Usually, an opaque layout is generated by a partition which is compiled
|
| 63 |
+
/// with layout type any.
|
| 64 |
+
dnnl_graph_layout_type_opaque = 3,
|
| 65 |
+
} dnnl_graph_layout_type_t;
|
| 66 |
+
|
| 67 |
+
/// Logical tensor property
|
| 68 |
+
typedef enum {
|
| 69 |
+
/// Undefined tensor property
|
| 70 |
+
dnnl_graph_tensor_property_undef = 0,
|
| 71 |
+
/// Variable means the tensor may be changed during computation or between
|
| 72 |
+
/// different iterations.
|
| 73 |
+
dnnl_graph_tensor_property_variable = 1,
|
| 74 |
+
/// Constant means the tensor will keep unchanged during computation and
|
| 75 |
+
/// between different iterations. It's useful for the library to apply
|
| 76 |
+
/// optimizations for constant tensors or cache constant tensors inside the
|
| 77 |
+
/// library. For example, constant weight tensors in inference scenarios.
|
| 78 |
+
dnnl_graph_tensor_property_constant = 2,
|
| 79 |
+
} dnnl_graph_tensor_property_t;
|
| 80 |
+
|
| 81 |
+
/// Logical tensor. It is based on an ID, a number of dimensions, dimensions
|
| 82 |
+
/// themselves, element data type, tensor property and tensor memory layout.
|
| 83 |
+
typedef struct {
|
| 84 |
+
/// Unique id of each logical tensor. The library uses logical tensor IDs to
|
| 85 |
+
/// build up the connections between operations if the output of one
|
| 86 |
+
/// operation has the same ID as the input of another operation.
|
| 87 |
+
size_t id;
|
| 88 |
+
|
| 89 |
+
/// Number of dimensions. -1 means unknown (DNNL_GRAPH_UNKNOWN_NDIMS). 0 is
|
| 90 |
+
/// used to define scalar tensor.
|
| 91 |
+
int ndims;
|
| 92 |
+
|
| 93 |
+
/// Size of each dimension. #DNNL_GRAPH_UNKNOWN_DIM means the size of that
|
| 94 |
+
/// dimension is unknown. 0 is used to define zero-dimension tensor. The
|
| 95 |
+
/// library supports to deduce output shapes according to input shapes
|
| 96 |
+
/// during compilation. Unlike memory descriptor in oneDNN primitive API,
|
| 97 |
+
/// the order of dimensions is not defined in logical tensor. It is defined
|
| 98 |
+
/// by the operations which respect the order through the attributes
|
| 99 |
+
/// #dnnl_graph_op_attr_data_format or #dnnl_graph_op_attr_weights_format.
|
| 100 |
+
/// For example, for a Convolution with `data_format=NXC`, it means the
|
| 101 |
+
/// first element of dims of activation tensor is mini-batch size, the last
|
| 102 |
+
/// effective element of dims is channel size, and other elements between
|
| 103 |
+
/// them are spatial dimensions.
|
| 104 |
+
dnnl_dims_t dims;
|
| 105 |
+
|
| 106 |
+
/// Data type of the tensor elements.
|
| 107 |
+
dnnl_data_type_t data_type;
|
| 108 |
+
|
| 109 |
+
/// Property type of the tensor.
|
| 110 |
+
dnnl_graph_tensor_property_t property;
|
| 111 |
+
|
| 112 |
+
/// Layout type of the tensor.
|
| 113 |
+
dnnl_graph_layout_type_t layout_type;
|
| 114 |
+
union {
|
| 115 |
+
/// The field is valid when `layout_type` is
|
| 116 |
+
/// #dnnl_graph_layout_type_strided. #DNNL_GRAPH_UNKNOWN_DIM means the
|
| 117 |
+
/// stride of the dimension is unknown. The library currently doesn't
|
| 118 |
+
/// support other negative stride values.
|
| 119 |
+
dnnl_dims_t strides;
|
| 120 |
+
|
| 121 |
+
/// The field is valid when `layout_type` is
|
| 122 |
+
/// #dnnl_graph_layout_type_opaque. An opaque layout ID is usually
|
| 123 |
+
/// generated by a partition which is compiled with layout type any.
|
| 124 |
+
size_t layout_id;
|
| 125 |
+
} layout;
|
| 126 |
+
} dnnl_graph_logical_tensor_t;
|
| 127 |
+
|
| 128 |
+
/// @} dnnl_graph_api_logical_tensor
|
| 129 |
+
|
| 130 |
+
/// @addtogroup dnnl_graph_api_partition
|
| 131 |
+
/// @{
|
| 132 |
+
|
| 133 |
+
/// Policy specifications for partitioning
|
| 134 |
+
typedef enum {
|
| 135 |
+
/// Fusion policy returns partitions with typical post-op fusions, eg.
|
| 136 |
+
/// Convolution + ReLU or other element-wise operations or a chian of
|
| 137 |
+
/// post-ops.
|
| 138 |
+
dnnl_graph_partition_policy_fusion = 1,
|
| 139 |
+
/// Debug policy doesn't not apply any fusions. It returns partitions with
|
| 140 |
+
/// single operation in each partition. The policy is useful when users
|
| 141 |
+
/// notice any bug or correctness issue in fusion policy.
|
| 142 |
+
dnnl_graph_partition_policy_debug = 2,
|
| 143 |
+
} dnnl_graph_partition_policy_t;
|
| 144 |
+
|
| 145 |
+
/// An opaque structure to describe a partition.
|
| 146 |
+
struct dnnl_graph_partition;
|
| 147 |
+
|
| 148 |
+
/// A partition handle.
|
| 149 |
+
typedef struct dnnl_graph_partition *dnnl_graph_partition_t;
|
| 150 |
+
|
| 151 |
+
/// A constant partition handle.
|
| 152 |
+
typedef const struct dnnl_graph_partition *const_dnnl_graph_partition_t;
|
| 153 |
+
|
| 154 |
+
/// @} dnnl_graph_api_partition
|
| 155 |
+
|
| 156 |
+
/// @addtogroup dnnl_graph_api_graph
|
| 157 |
+
/// @{
|
| 158 |
+
|
| 159 |
+
/// An opaque structure to describe a graph.
|
| 160 |
+
struct dnnl_graph_graph;
|
| 161 |
+
|
| 162 |
+
/// A graph handle.
|
| 163 |
+
typedef struct dnnl_graph_graph *dnnl_graph_graph_t;
|
| 164 |
+
|
| 165 |
+
/// A constant graph handle.
|
| 166 |
+
typedef const struct dnnl_graph_graph *const_dnnl_graph_graph_t;
|
| 167 |
+
|
| 168 |
+
/// @} dnnl_graph_api_graph
|
| 169 |
+
|
| 170 |
+
/// @addtogroup dnnl_graph_api_op
|
| 171 |
+
/// @{
|
| 172 |
+
|
| 173 |
+
/// Kinds of operations
|
| 174 |
+
typedef enum {
|
| 175 |
+
dnnl_graph_op_abs,
|
| 176 |
+
dnnl_graph_op_abs_backward,
|
| 177 |
+
dnnl_graph_op_add,
|
| 178 |
+
dnnl_graph_op_avg_pool,
|
| 179 |
+
dnnl_graph_op_avg_pool_backward,
|
| 180 |
+
dnnl_graph_op_batch_norm_backward,
|
| 181 |
+
dnnl_graph_op_batch_norm_forward_training,
|
| 182 |
+
dnnl_graph_op_batch_norm_inference,
|
| 183 |
+
dnnl_graph_op_bias_add,
|
| 184 |
+
dnnl_graph_op_bias_add_backward,
|
| 185 |
+
dnnl_graph_op_clamp,
|
| 186 |
+
dnnl_graph_op_clamp_backward,
|
| 187 |
+
dnnl_graph_op_concat,
|
| 188 |
+
dnnl_graph_op_convolution,
|
| 189 |
+
dnnl_graph_op_convolution_backward_data,
|
| 190 |
+
dnnl_graph_op_convolution_backward_weights,
|
| 191 |
+
dnnl_graph_op_conv_transpose,
|
| 192 |
+
dnnl_graph_op_conv_transpose_backward_data,
|
| 193 |
+
dnnl_graph_op_conv_transpose_backward_weights,
|
| 194 |
+
dnnl_graph_op_dequantize,
|
| 195 |
+
dnnl_graph_op_divide,
|
| 196 |
+
dnnl_graph_op_dynamic_dequantize,
|
| 197 |
+
dnnl_graph_op_dynamic_quantize,
|
| 198 |
+
dnnl_graph_op_elu,
|
| 199 |
+
dnnl_graph_op_elu_backward,
|
| 200 |
+
dnnl_graph_op_end,
|
| 201 |
+
dnnl_graph_op_exp,
|
| 202 |
+
dnnl_graph_op_gelu,
|
| 203 |
+
dnnl_graph_op_gelu_backward,
|
| 204 |
+
dnnl_graph_op_hard_swish,
|
| 205 |
+
dnnl_graph_op_hard_swish_backward,
|
| 206 |
+
dnnl_graph_op_interpolate,
|
| 207 |
+
dnnl_graph_op_interpolate_backward,
|
| 208 |
+
dnnl_graph_op_layer_norm,
|
| 209 |
+
dnnl_graph_op_layer_norm_backward,
|
| 210 |
+
dnnl_graph_op_leaky_relu,
|
| 211 |
+
dnnl_graph_op_log,
|
| 212 |
+
dnnl_graph_op_log_softmax,
|
| 213 |
+
dnnl_graph_op_log_softmax_backward,
|
| 214 |
+
dnnl_graph_op_matmul,
|
| 215 |
+
dnnl_graph_op_maximum,
|
| 216 |
+
dnnl_graph_op_max_pool,
|
| 217 |
+
dnnl_graph_op_max_pool_backward,
|
| 218 |
+
dnnl_graph_op_minimum,
|
| 219 |
+
dnnl_graph_op_mish,
|
| 220 |
+
dnnl_graph_op_mish_backward,
|
| 221 |
+
dnnl_graph_op_multiply,
|
| 222 |
+
dnnl_graph_op_prelu,
|
| 223 |
+
dnnl_graph_op_prelu_backward,
|
| 224 |
+
dnnl_graph_op_quantize,
|
| 225 |
+
dnnl_graph_op_reciprocal,
|
| 226 |
+
dnnl_graph_op_reduce_l1,
|
| 227 |
+
dnnl_graph_op_reduce_l2,
|
| 228 |
+
dnnl_graph_op_reduce_max,
|
| 229 |
+
dnnl_graph_op_reduce_mean,
|
| 230 |
+
dnnl_graph_op_reduce_min,
|
| 231 |
+
dnnl_graph_op_reduce_prod,
|
| 232 |
+
dnnl_graph_op_reduce_sum,
|
| 233 |
+
dnnl_graph_op_relu,
|
| 234 |
+
dnnl_graph_op_relu_backward,
|
| 235 |
+
dnnl_graph_op_reorder,
|
| 236 |
+
dnnl_graph_op_round,
|
| 237 |
+
dnnl_graph_op_sigmoid,
|
| 238 |
+
dnnl_graph_op_sigmoid_backward,
|
| 239 |
+
dnnl_graph_op_softmax,
|
| 240 |
+
dnnl_graph_op_softmax_backward,
|
| 241 |
+
dnnl_graph_op_softplus,
|
| 242 |
+
dnnl_graph_op_softplus_backward,
|
| 243 |
+
dnnl_graph_op_sqrt,
|
| 244 |
+
dnnl_graph_op_sqrt_backward,
|
| 245 |
+
dnnl_graph_op_square,
|
| 246 |
+
dnnl_graph_op_squared_difference,
|
| 247 |
+
dnnl_graph_op_static_reshape,
|
| 248 |
+
dnnl_graph_op_static_transpose,
|
| 249 |
+
dnnl_graph_op_subtract,
|
| 250 |
+
dnnl_graph_op_tanh,
|
| 251 |
+
dnnl_graph_op_tanh_backward,
|
| 252 |
+
dnnl_graph_op_type_cast,
|
| 253 |
+
dnnl_graph_op_wildcard,
|
| 254 |
+
dnnl_graph_op_hard_sigmoid,
|
| 255 |
+
dnnl_graph_op_hard_sigmoid_backward,
|
| 256 |
+
dnnl_graph_op_select,
|
| 257 |
+
dnnl_graph_op_pow,
|
| 258 |
+
dnnl_graph_op_group_norm,
|
| 259 |
+
dnnl_graph_op_gen_index,
|
| 260 |
+
dnnl_graph_op_greater_equal,
|
| 261 |
+
dnnl_graph_op_last_symbol,
|
| 262 |
+
} dnnl_graph_op_kind_t;
|
| 263 |
+
|
| 264 |
+
/// Attributes of operations
|
| 265 |
+
typedef enum {
|
| 266 |
+
/// Undefined op attribute.
|
| 267 |
+
dnnl_graph_op_attr_undef = 0,
|
| 268 |
+
|
| 269 |
+
// float32 attributes. The value of these attributes can be any single
|
| 270 |
+
// float32 number.
|
| 271 |
+
|
| 272 |
+
/// Specifies an alpha attribute to an op.
|
| 273 |
+
dnnl_graph_op_attr_alpha = 0x1,
|
| 274 |
+
/// Specifies an beta attribute to an op.
|
| 275 |
+
dnnl_graph_op_attr_beta,
|
| 276 |
+
/// Specifies an epsilon attribute to an op.
|
| 277 |
+
dnnl_graph_op_attr_epsilon,
|
| 278 |
+
/// Specifies a max attribute to an op.
|
| 279 |
+
dnnl_graph_op_attr_max,
|
| 280 |
+
///Specifies a min attribute to an op.
|
| 281 |
+
dnnl_graph_op_attr_min,
|
| 282 |
+
/// Specifies a momentum attribute to an op.
|
| 283 |
+
dnnl_graph_op_attr_momentum,
|
| 284 |
+
|
| 285 |
+
// float32 vector attributes. The value of these attributes can be a vector
|
| 286 |
+
// of float32 numbers.
|
| 287 |
+
|
| 288 |
+
/// Specifies a scales attribute to an op.
|
| 289 |
+
dnnl_graph_op_attr_scales = 0x20,
|
| 290 |
+
|
| 291 |
+
// int64_t attributes. The value of these attributes can be any single int64
|
| 292 |
+
// number.
|
| 293 |
+
|
| 294 |
+
/// Specifies an axis attribute to an op.
|
| 295 |
+
dnnl_graph_op_attr_axis = 0x30,
|
| 296 |
+
/// Specifies a begin_norm_axis attribute to an op.
|
| 297 |
+
dnnl_graph_op_attr_begin_norm_axis,
|
| 298 |
+
/// Specifies a groups attribute to an op.
|
| 299 |
+
dnnl_graph_op_attr_groups,
|
| 300 |
+
|
| 301 |
+
// int64_t vector attributes. The value of these attributes can be a vector
|
| 302 |
+
// of int64 numbers.
|
| 303 |
+
|
| 304 |
+
/// Specifies an axes attribute to an op.
|
| 305 |
+
dnnl_graph_op_attr_axes = 0x40,
|
| 306 |
+
/// Specifies a dilations attribute to an op.
|
| 307 |
+
dnnl_graph_op_attr_dilations,
|
| 308 |
+
/// Specifies an dst_shape attribute to an op.
|
| 309 |
+
dnnl_graph_op_attr_dst_shape,
|
| 310 |
+
/// Specifies a kernel attribute to an op.
|
| 311 |
+
dnnl_graph_op_attr_kernel,
|
| 312 |
+
/// Specifies an order attribute to an op.
|
| 313 |
+
dnnl_graph_op_attr_order,
|
| 314 |
+
/// Specifies an output_padding attribute to an op.
|
| 315 |
+
dnnl_graph_op_attr_output_padding,
|
| 316 |
+
/// Specifies a pads_begin attribute to an op.
|
| 317 |
+
dnnl_graph_op_attr_pads_begin,
|
| 318 |
+
/// Specifies a pads_end attribute to an op.
|
| 319 |
+
dnnl_graph_op_attr_pads_end,
|
| 320 |
+
/// Specifies a shape attribute to an op.
|
| 321 |
+
dnnl_graph_op_attr_shape,
|
| 322 |
+
/// Specifies a sizes attribute to an op.
|
| 323 |
+
dnnl_graph_op_attr_sizes,
|
| 324 |
+
/// Specifies a input_shape attribute to an op.
|
| 325 |
+
dnnl_graph_op_attr_src_shape,
|
| 326 |
+
/// Specifies a strides attribute to an op.
|
| 327 |
+
dnnl_graph_op_attr_strides,
|
| 328 |
+
/// Specifies a weight_shape attribute to an op.
|
| 329 |
+
dnnl_graph_op_attr_weights_shape,
|
| 330 |
+
/// Specifies a zps attribute to an op.
|
| 331 |
+
dnnl_graph_op_attr_zps,
|
| 332 |
+
/// Specifies a group shape attribute to an op.
|
| 333 |
+
dnnl_graph_op_attr_group_shape,
|
| 334 |
+
|
| 335 |
+
// bool attributes. The value of these attributes can be any single bool
|
| 336 |
+
// value.
|
| 337 |
+
|
| 338 |
+
/// Specifies an exclude_pad attribute to an op.
|
| 339 |
+
dnnl_graph_op_attr_exclude_pad = 0x60,
|
| 340 |
+
/// Specifies a keep_dims attribute to an op.
|
| 341 |
+
dnnl_graph_op_attr_keep_dims,
|
| 342 |
+
/// Specifies a keep_stats attribute to an op.
|
| 343 |
+
dnnl_graph_op_attr_keep_stats,
|
| 344 |
+
/// Specifies a per_channel_broadcast attribute to an op.
|
| 345 |
+
dnnl_graph_op_attr_per_channel_broadcast,
|
| 346 |
+
/// Specifies a special_zero attribute to an op.
|
| 347 |
+
dnnl_graph_op_attr_special_zero,
|
| 348 |
+
/// Specifies a transpose_a attribute to an op.
|
| 349 |
+
dnnl_graph_op_attr_transpose_a,
|
| 350 |
+
/// Specifies a transpose_b attribute to an op.
|
| 351 |
+
dnnl_graph_op_attr_transpose_b,
|
| 352 |
+
/// Specifies an use_affine attribute to an op.
|
| 353 |
+
dnnl_graph_op_attr_use_affine,
|
| 354 |
+
/// Specifies an use_dst attribute to an op.
|
| 355 |
+
dnnl_graph_op_attr_use_dst,
|
| 356 |
+
|
| 357 |
+
// string attributes. The value of these attributes can be a string.
|
| 358 |
+
|
| 359 |
+
/// Specifies an auto_broadcast attribute to an op. The value can be "none"
|
| 360 |
+
/// or "numpy".
|
| 361 |
+
dnnl_graph_op_attr_auto_broadcast = 0x80,
|
| 362 |
+
/// Specifies an auto_pad attribute to an op. The value can be "none",
|
| 363 |
+
/// "same_upper", "same_lower", or "valid".
|
| 364 |
+
dnnl_graph_op_attr_auto_pad,
|
| 365 |
+
/// Specifies an coordinate_transformation_mode attribute to an op. The
|
| 366 |
+
/// value can be "half_pixel" or "align_corners". The attribute is defined
|
| 367 |
+
/// for Interpolate operations.
|
| 368 |
+
dnnl_graph_op_attr_coordinate_transformation_mode,
|
| 369 |
+
/// Specifies a data_format of an op. The value can be "NCX" or "NXC".
|
| 370 |
+
dnnl_graph_op_attr_data_format,
|
| 371 |
+
/// Specifies a mode attribute of an op. The value can be "nearest",
|
| 372 |
+
/// "linear", "bilinear", or "trilinear". The attribute is defined for
|
| 373 |
+
/// Interpolate operations.
|
| 374 |
+
dnnl_graph_op_attr_mode,
|
| 375 |
+
/// Specifies a qtype attribute to an op. The value can be "per_channel" or
|
| 376 |
+
/// "per_tensor". The attribute is defined for quantization operations.
|
| 377 |
+
dnnl_graph_op_attr_qtype,
|
| 378 |
+
/// Specifies a rounding_type attribute to an op. The value can be "ceil" or
|
| 379 |
+
/// "floor".
|
| 380 |
+
dnnl_graph_op_attr_rounding_type,
|
| 381 |
+
/// Specifies a weights_format of an op. The value can be "OIX", "XIO",
|
| 382 |
+
/// "IOX", or "XOI". Different operations may support different values.
|
| 383 |
+
dnnl_graph_op_attr_weights_format,
|
| 384 |
+
|
| 385 |
+
/// Specifies the end of all above exteral attributes for check.
|
| 386 |
+
dnnl_graph_op_attr_end = 0xFF,
|
| 387 |
+
} dnnl_graph_op_attr_t;
|
| 388 |
+
|
| 389 |
+
/// An opaque structure to describe an operation.
|
| 390 |
+
struct dnnl_graph_op;
|
| 391 |
+
|
| 392 |
+
/// An operation handle.
|
| 393 |
+
typedef struct dnnl_graph_op *dnnl_graph_op_t;
|
| 394 |
+
|
| 395 |
+
/// A constant operation handle.
|
| 396 |
+
typedef const struct dnnl_graph_op *const_dnnl_graph_op_t;
|
| 397 |
+
|
| 398 |
+
/// @} dnnl_graph_api_op
|
| 399 |
+
|
| 400 |
+
/// @addtogroup dnnl_graph_api_allocator
|
| 401 |
+
/// @{
|
| 402 |
+
|
| 403 |
+
/// Allocation call-back function interface for host. For SYCL allocator, see
|
| 404 |
+
/// #dnnl_graph_sycl_allocate_f.
|
| 405 |
+
typedef void *(*dnnl_graph_host_allocate_f)(size_t size, size_t alignment);
|
| 406 |
+
|
| 407 |
+
/// Deallocation call-back function interface for host. For SYCL allocator, see
|
| 408 |
+
/// #dnnl_graph_sycl_deallocate_f.
|
| 409 |
+
typedef void (*dnnl_graph_host_deallocate_f)(void *);
|
| 410 |
+
|
| 411 |
+
/// An opaque structure to describe an allocator.
|
| 412 |
+
struct dnnl_graph_allocator;
|
| 413 |
+
|
| 414 |
+
/// An allocator handle.
|
| 415 |
+
typedef struct dnnl_graph_allocator *dnnl_graph_allocator_t;
|
| 416 |
+
|
| 417 |
+
/// A constant allocator handle.
|
| 418 |
+
typedef const struct dnnl_graph_allocator *const_dnnl_graph_allocator_t;
|
| 419 |
+
|
| 420 |
+
/// @} dnnl_graph_api_allocator
|
| 421 |
+
|
| 422 |
+
/// @addtogroup dnnl_graph_api_compiled_partition
|
| 423 |
+
/// @{
|
| 424 |
+
|
| 425 |
+
/// In-place pair definition. It can queried from a compiled partition
|
| 426 |
+
/// indicating that an input and an output of the partition can share the same
|
| 427 |
+
/// memory buffer for computation. In-place computation helps to reduce the
|
| 428 |
+
/// memory footprint and improves cache locality. But since the library may not
|
| 429 |
+
/// have a global view of user's application, it's possible that the tensor with
|
| 430 |
+
/// `input_id` is used at other places in user's computation graph. In this
|
| 431 |
+
/// case, the user should take the in-place pair as a hint and pass a different
|
| 432 |
+
/// memory buffer for output tensor to avoid overwriting the input memory buffer
|
| 433 |
+
/// which will probably cause unexpected incorrect results.
|
| 434 |
+
typedef struct {
|
| 435 |
+
/// The id of input tensor
|
| 436 |
+
size_t input_id;
|
| 437 |
+
|
| 438 |
+
/// The id of output tensor
|
| 439 |
+
size_t output_id;
|
| 440 |
+
} dnnl_graph_inplace_pair_t;
|
| 441 |
+
|
| 442 |
+
/// An opaque structure to describe a compiled partition.
|
| 443 |
+
struct dnnl_graph_compiled_partition;
|
| 444 |
+
|
| 445 |
+
/// A compiled partition handle.
|
| 446 |
+
typedef struct dnnl_graph_compiled_partition *dnnl_graph_compiled_partition_t;
|
| 447 |
+
|
| 448 |
+
/// A constant compiled partition handle.
|
| 449 |
+
typedef const struct dnnl_graph_compiled_partition
|
| 450 |
+
*const_dnnl_graph_compiled_partition_t;
|
| 451 |
+
|
| 452 |
+
/// @} dnnl_graph_api_compiled_partition
|
| 453 |
+
|
| 454 |
+
/// @addtogroup dnnl_graph_api_tensor
|
| 455 |
+
/// @{
|
| 456 |
+
|
| 457 |
+
/// An opaque structure to describe a tensor.
|
| 458 |
+
struct dnnl_graph_tensor;
|
| 459 |
+
|
| 460 |
+
/// A tensor handle.
|
| 461 |
+
typedef struct dnnl_graph_tensor *dnnl_graph_tensor_t;
|
| 462 |
+
|
| 463 |
+
/// A constant tensor handle.
|
| 464 |
+
typedef const struct dnnl_graph_tensor *const_dnnl_graph_tensor_t;
|
| 465 |
+
|
| 466 |
+
/// @} dnnl_graph_api_tensor
|
| 467 |
+
|
| 468 |
+
/// @} dnnl_graph_api
|
| 469 |
+
|
| 470 |
+
/// @} dnnl_api
|
| 471 |
+
|
| 472 |
+
#ifdef __cplusplus
|
| 473 |
+
}
|
| 474 |
+
#endif
|
| 475 |
+
#endif
|
phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_ocl.h
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*******************************************************************************
|
| 2 |
+
* Copyright 2020-2024 Intel Corporation
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*******************************************************************************/
|
| 16 |
+
|
| 17 |
+
#ifndef ONEAPI_DNNL_DNNL_OCL_H
|
| 18 |
+
#define ONEAPI_DNNL_DNNL_OCL_H
|
| 19 |
+
|
| 20 |
+
#include "oneapi/dnnl/dnnl.h"
|
| 21 |
+
|
| 22 |
+
#include "oneapi/dnnl/dnnl_ocl_types.h"
|
| 23 |
+
|
| 24 |
+
/// @cond DO_NOT_DOCUMENT_THIS
|
| 25 |
+
// Set target version for OpenCL explicitly to suppress a compiler warning.
|
| 26 |
+
#ifndef CL_TARGET_OPENCL_VERSION
|
| 27 |
+
#define CL_TARGET_OPENCL_VERSION 120
|
| 28 |
+
#endif
|
| 29 |
+
|
| 30 |
+
#include <CL/cl.h>
|
| 31 |
+
/// @endcond
|
| 32 |
+
|
| 33 |
+
#ifdef __cplusplus
|
| 34 |
+
extern "C" {
|
| 35 |
+
#endif
|
| 36 |
+
|
| 37 |
+
/// @addtogroup dnnl_api
|
| 38 |
+
/// @{
|
| 39 |
+
|
| 40 |
+
/// @addtogroup dnnl_api_interop
|
| 41 |
+
/// @{
|
| 42 |
+
|
| 43 |
+
/// @addtogroup dnnl_api_ocl_interop
|
| 44 |
+
/// @{
|
| 45 |
+
|
| 46 |
+
/// Creates a memory object.
|
| 47 |
+
///
|
| 48 |
+
/// Unless @p handle is equal to DNNL_MEMORY_NONE or DNNL_MEMORY_ALLOCATE, the
|
| 49 |
+
/// constructed memory object will have the underlying buffer set. In this
|
| 50 |
+
/// case, the buffer will be initialized as if:
|
| 51 |
+
/// - dnnl_memory_set_data_handle() has been called, if @p memory_kind is equal
|
| 52 |
+
/// to dnnl_ocl_interop_usm, or
|
| 53 |
+
/// - dnnl_ocl_interop_memory_set_mem_object() has been called, if @p memory_kind
|
| 54 |
+
/// is equal to dnnl_ocl_interop_buffer.
|
| 55 |
+
///
|
| 56 |
+
/// @param memory Output memory object.
|
| 57 |
+
/// @param memory_desc Memory descriptor.
|
| 58 |
+
/// @param engine Engine to use.
|
| 59 |
+
/// @param memory_kind Memory allocation kind to specify the type of handle.
|
| 60 |
+
/// @param handle Handle of the memory buffer to use as an underlying storage.
|
| 61 |
+
/// - A USM pointer to the user-allocated buffer. In this case the library
|
| 62 |
+
/// doesn't own the buffer. Requires @p memory_kind to be equal to
|
| 63 |
+
/// dnnl_ocl_interop_usm.
|
| 64 |
+
/// - An OpenCL buffer. In this case the library doesn't own the buffer.
|
| 65 |
+
/// Requires @p memory_kind be equal to be equal to dnnl_ocl_interop_buffer.
|
| 66 |
+
/// - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to
|
| 67 |
+
/// allocate the buffer that corresponds to the memory allocation kind
|
| 68 |
+
/// @p memory_kind for the memory object. In this case the library
|
| 69 |
+
/// owns the buffer.
|
| 70 |
+
/// - The DNNL_MEMORY_NONE specific value. Instructs the library to
|
| 71 |
+
/// create memory object without an underlying buffer.
|
| 72 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 73 |
+
/// otherwise.
|
| 74 |
+
dnnl_status_t DNNL_API dnnl_ocl_interop_memory_create(dnnl_memory_t *memory,
|
| 75 |
+
const_dnnl_memory_desc_t memory_desc, dnnl_engine_t engine,
|
| 76 |
+
dnnl_ocl_interop_memory_kind_t memory_kind, void *handle);
|
| 77 |
+
|
| 78 |
+
#ifdef DNNL_EXPERIMENTAL_SPARSE
|
| 79 |
+
/// Creates a memory object with multiple handles.
|
| 80 |
+
///
|
| 81 |
+
/// @param memory Output memory object.
|
| 82 |
+
/// @param memory_desc Memory descriptor.
|
| 83 |
+
/// @param engine Engine to use.
|
| 84 |
+
/// @param memory_kind Memory allocation kind to specify the type of handles.
|
| 85 |
+
/// @param nhandles Number of handles.
|
| 86 |
+
/// @param handles Handles of the memory buffers to use as underlying storages.
|
| 87 |
+
/// For each element of the @p handles array the following applies:
|
| 88 |
+
/// - A USM pointer to the user-allocated buffer. In this case the library
|
| 89 |
+
/// doesn't own the buffer. Requires @p memory_kind to be equal to
|
| 90 |
+
/// dnnl_ocl_interop_usm.
|
| 91 |
+
/// - An OpenCL buffer. In this case the library doesn't own the buffer.
|
| 92 |
+
/// Requires @p memory_kind be equal to be equal to dnnl_ocl_interop_buffer.
|
| 93 |
+
/// - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to
|
| 94 |
+
/// allocate the buffer that corresponds to the memory allocation kind
|
| 95 |
+
/// @p memory_kind for the memory object. In this case the library
|
| 96 |
+
/// owns the buffer.
|
| 97 |
+
/// - The DNNL_MEMORY_NONE specific value. Instructs the library to
|
| 98 |
+
/// create memory object without an underlying buffer.
|
| 99 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 100 |
+
/// otherwise.
|
| 101 |
+
dnnl_status_t DNNL_API dnnl_ocl_interop_memory_create_v2(dnnl_memory_t *memory,
|
| 102 |
+
const_dnnl_memory_desc_t memory_desc, dnnl_engine_t engine,
|
| 103 |
+
dnnl_ocl_interop_memory_kind_t memory_kind, int nhandles,
|
| 104 |
+
void **handles);
|
| 105 |
+
#endif
|
| 106 |
+
|
| 107 |
+
/// Returns the memory allocation kind associated with a memory object.
|
| 108 |
+
///
|
| 109 |
+
/// @param memory Memory to query.
|
| 110 |
+
/// @param memory_kind Output underlying memory allocation kind of the memory
|
| 111 |
+
/// object.
|
| 112 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 113 |
+
/// otherwise.
|
| 114 |
+
dnnl_status_t DNNL_API dnnl_ocl_interop_memory_get_memory_kind(
|
| 115 |
+
const_dnnl_memory_t memory,
|
| 116 |
+
dnnl_ocl_interop_memory_kind_t *memory_kind);
|
| 117 |
+
|
| 118 |
+
/// Returns an OpenCL memory object associated with a memory object.
|
| 119 |
+
///
|
| 120 |
+
/// @param memory Memory object.
|
| 121 |
+
/// @param mem_object Output OpenCL memory object.
|
| 122 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 123 |
+
/// otherwise.
|
| 124 |
+
dnnl_status_t DNNL_API dnnl_ocl_interop_memory_get_mem_object(
|
| 125 |
+
const_dnnl_memory_t memory, cl_mem *mem_object);
|
| 126 |
+
|
| 127 |
+
/// Sets OpenCL memory object associated with a memory object.
|
| 128 |
+
///
|
| 129 |
+
/// For behavioral details, see dnnl_memory_set_data_handle().
|
| 130 |
+
///
|
| 131 |
+
/// @param memory Memory object.
|
| 132 |
+
/// @param mem_object OpenCL memory object.
|
| 133 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 134 |
+
/// otherwise.
|
| 135 |
+
dnnl_status_t DNNL_API dnnl_ocl_interop_memory_set_mem_object(
|
| 136 |
+
dnnl_memory_t memory, cl_mem mem_object);
|
| 137 |
+
|
| 138 |
+
/// Retrieves a cache blob ID for the OpenCL device.
|
| 139 |
+
///
|
| 140 |
+
/// @warning
|
| 141 |
+
/// This API is intended to be used with
|
| 142 |
+
/// #dnnl_ocl_interop_engine_get_cache_blob() and
|
| 143 |
+
/// #dnnl_ocl_interop_engine_create_from_cache_blob(). The returned cache
|
| 144 |
+
/// blob ID can only be used as an ID of the cache blob returned by
|
| 145 |
+
/// #dnnl_ocl_interop_engine_get_cache_blob().
|
| 146 |
+
///
|
| 147 |
+
/// @note The cache blob ID can be empty (@p size will be 0 and
|
| 148 |
+
/// @p cache_blob_id will be nullptr) if oneDNN doesn't have anything to
|
| 149 |
+
/// put in the cache blob. (#dnnl_ocl_interop_engine_get_cache_blob will
|
| 150 |
+
/// return an empty cache blob).
|
| 151 |
+
///
|
| 152 |
+
/// @param device An OpenCL device.
|
| 153 |
+
/// @param size Size of the cache blob ID in bytes.
|
| 154 |
+
/// @param cache_blob_id Cache blob id of size @p size. If
|
| 155 |
+
/// the @p cache_blob_id is nullptr then the size of the cache blob ID is
|
| 156 |
+
/// returned in @p size.
|
| 157 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 158 |
+
/// otherwise.
|
| 159 |
+
dnnl_status_t DNNL_API dnnl_ocl_interop_engine_get_cache_blob_id(
|
| 160 |
+
cl_device_id device, size_t *size, uint8_t *cache_blob_id);
|
| 161 |
+
|
| 162 |
+
/// Retrieves a cache blob associated with the given engine.
|
| 163 |
+
///
|
| 164 |
+
/// @note The cache blob can be empty (@p size will be 0 and @p cache_blob
|
| 165 |
+
/// will be nullptr) if oneDNN doesn't have anything to put in the cache
|
| 166 |
+
/// blob. It's the user's responsibility to check whether it's empty
|
| 167 |
+
/// prior to passing it to
|
| 168 |
+
/// #dnnl_ocl_interop_engine_create_from_cache_blob().
|
| 169 |
+
///
|
| 170 |
+
/// @param engine Engine to query for the cache blob.
|
| 171 |
+
/// @param size Size of the cache blob in bytes.
|
| 172 |
+
/// @param cache_blob Cache blob of size @p size. If the @p cache_blob is
|
| 173 |
+
/// nullptr then the size of the cache blob is returned in @p size.
|
| 174 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 175 |
+
/// otherwise.
|
| 176 |
+
dnnl_status_t DNNL_API dnnl_ocl_interop_engine_get_cache_blob(
|
| 177 |
+
dnnl_engine_t engine, size_t *size, uint8_t *cache_blob);
|
| 178 |
+
|
| 179 |
+
/// Creates an engine from the given cache blob.
|
| 180 |
+
///
|
| 181 |
+
/// @param engine Output engine.
|
| 182 |
+
/// @param device The OpenCL device that this engine will encapsulate.
|
| 183 |
+
/// @param context The OpenCL context (containing the device) that this
|
| 184 |
+
/// engine will use for all operations.
|
| 185 |
+
/// @param size Size of the cache blob in bytes.
|
| 186 |
+
/// @param cache_blob Cache blob of size @p size.
|
| 187 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 188 |
+
/// otherwise.
|
| 189 |
+
dnnl_status_t DNNL_API dnnl_ocl_interop_engine_create_from_cache_blob(
|
| 190 |
+
dnnl_engine_t *engine, cl_device_id device, cl_context context,
|
| 191 |
+
size_t size, const uint8_t *cache_blob);
|
| 192 |
+
|
| 193 |
+
/// Creates an engine associated with an OpenCL device and an OpenCL context.
|
| 194 |
+
///
|
| 195 |
+
/// @param engine Output engine.
|
| 196 |
+
/// @param device Underlying OpenCL device to use for the engine.
|
| 197 |
+
/// @param context Underlying OpenCL context to use for the engine.
|
| 198 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 199 |
+
/// otherwise.
|
| 200 |
+
dnnl_status_t DNNL_API dnnl_ocl_interop_engine_create(
|
| 201 |
+
dnnl_engine_t *engine, cl_device_id device, cl_context context);
|
| 202 |
+
|
| 203 |
+
/// Returns the OpenCL context associated with an engine.
|
| 204 |
+
///
|
| 205 |
+
/// @param engine Engine to query.
|
| 206 |
+
/// @param context Output underlying OpenCL context of the engine.
|
| 207 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 208 |
+
/// otherwise.
|
| 209 |
+
dnnl_status_t DNNL_API dnnl_ocl_interop_engine_get_context(
|
| 210 |
+
dnnl_engine_t engine, cl_context *context);
|
| 211 |
+
|
| 212 |
+
/// Returns the OpenCL device associated with an engine.
|
| 213 |
+
///
|
| 214 |
+
/// @param engine Engine to query.
|
| 215 |
+
/// @param device Output underlying OpenCL device of the engine.
|
| 216 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 217 |
+
/// otherwise.
|
| 218 |
+
dnnl_status_t DNNL_API dnnl_ocl_interop_get_device(
|
| 219 |
+
dnnl_engine_t engine, cl_device_id *device);
|
| 220 |
+
|
| 221 |
+
/// Creates an execution stream for a given engine associated with
|
| 222 |
+
/// an OpenCL command queue.
|
| 223 |
+
///
|
| 224 |
+
/// @param stream Output execution stream.
|
| 225 |
+
/// @param engine Engine to create the execution stream on.
|
| 226 |
+
/// @param queue OpenCL command queue to use.
|
| 227 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 228 |
+
/// otherwise.
|
| 229 |
+
dnnl_status_t DNNL_API dnnl_ocl_interop_stream_create(
|
| 230 |
+
dnnl_stream_t *stream, dnnl_engine_t engine, cl_command_queue queue);
|
| 231 |
+
|
| 232 |
+
/// Returns the OpenCL command queue associated with an execution stream.
|
| 233 |
+
///
|
| 234 |
+
/// @param stream Execution stream to query.
|
| 235 |
+
/// @param queue Output OpenCL command queue.
|
| 236 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 237 |
+
/// otherwise.
|
| 238 |
+
dnnl_status_t DNNL_API dnnl_ocl_interop_stream_get_command_queue(
|
| 239 |
+
dnnl_stream_t stream, cl_command_queue *queue);
|
| 240 |
+
|
| 241 |
+
/// Executes computations specified by the primitive in a specified stream and
|
| 242 |
+
/// returns an OpenCL event.
|
| 243 |
+
///
|
| 244 |
+
/// @param primitive Primitive to execute.
|
| 245 |
+
/// @param stream Stream to use.
|
| 246 |
+
/// @param nargs Number of arguments.
|
| 247 |
+
/// @param args Array of arguments. Each argument is an
|
| 248 |
+
/// <index, #dnnl_memory_t> pair. The index is one of the `DNNL_ARG_*`
|
| 249 |
+
/// values such as `DNNL_ARG_SRC`. Unless runtime shapes are used (see
|
| 250 |
+
/// #DNNL_RUNTIME_DIM_VAL), the memory object must have the same memory
|
| 251 |
+
/// descriptor as that returned by
|
| 252 |
+
/// #dnnl_primitive_desc_query_md(#dnnl_query_exec_arg_md, index).
|
| 253 |
+
/// @param deps A pointer to a vector of size @p ndeps that contains
|
| 254 |
+
/// dependencies.
|
| 255 |
+
/// @param ndeps Number of dependencies.
|
| 256 |
+
/// @param return_event Output event. It's the user's responsibility to
|
| 257 |
+
/// manage lifetime of the event. Can be NULL. When @p stream is in-order
|
| 258 |
+
/// NULL will be returned.
|
| 259 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 260 |
+
/// otherwise.
|
| 261 |
+
dnnl_status_t DNNL_API dnnl_ocl_interop_primitive_execute(
|
| 262 |
+
const_dnnl_primitive_t primitive, dnnl_stream_t stream, int nargs,
|
| 263 |
+
const dnnl_exec_arg_t *args, const cl_event *deps, int ndeps,
|
| 264 |
+
cl_event *return_event);
|
| 265 |
+
|
| 266 |
+
/// @} dnnl_api_ocl_interop
|
| 267 |
+
|
| 268 |
+
/// @} dnnl_api_interop
|
| 269 |
+
|
| 270 |
+
/// @} dnnl_api
|
| 271 |
+
|
| 272 |
+
#ifdef __cplusplus
|
| 273 |
+
}
|
| 274 |
+
#endif
|
| 275 |
+
|
| 276 |
+
#endif
|
phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_ocl.hpp
ADDED
|
@@ -0,0 +1,445 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*******************************************************************************
|
| 2 |
+
* Copyright 2020-2025 Intel Corporation
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*******************************************************************************/
|
| 16 |
+
|
| 17 |
+
#ifndef ONEAPI_DNNL_DNNL_OCL_HPP
|
| 18 |
+
#define ONEAPI_DNNL_DNNL_OCL_HPP
|
| 19 |
+
|
| 20 |
+
#include "oneapi/dnnl/dnnl.hpp"
|
| 21 |
+
|
| 22 |
+
/// @cond DO_NOT_DOCUMENT_THIS
|
| 23 |
+
#include <algorithm>
|
| 24 |
+
#include <cstdlib>
|
| 25 |
+
#include <iterator>
|
| 26 |
+
#include <memory>
|
| 27 |
+
#include <string>
|
| 28 |
+
#include <vector>
|
| 29 |
+
#include <unordered_map>
|
| 30 |
+
|
| 31 |
+
#include "oneapi/dnnl/dnnl_ocl.h"
|
| 32 |
+
|
| 33 |
+
#include <CL/cl.h>
|
| 34 |
+
/// @endcond
|
| 35 |
+
|
| 36 |
+
/// @addtogroup dnnl_api
|
| 37 |
+
/// @{
|
| 38 |
+
|
| 39 |
+
namespace dnnl {
|
| 40 |
+
|
| 41 |
+
/// @addtogroup dnnl_api_interop Runtime interoperability API
|
| 42 |
+
/// API extensions to interact with the underlying run-time.
|
| 43 |
+
/// @{
|
| 44 |
+
|
| 45 |
+
/// @addtogroup dnnl_api_ocl_interop OpenCL interoperability API
|
| 46 |
+
/// API extensions to interact with the underlying OpenCL run-time.
|
| 47 |
+
///
|
| 48 |
+
/// @sa @ref dev_guide_opencl_interoperability in developer guide
|
| 49 |
+
/// @{
|
| 50 |
+
|
| 51 |
+
/// OpenCL interoperability namespace
|
| 52 |
+
namespace ocl_interop {
|
| 53 |
+
|
| 54 |
+
/// Memory allocation kind.
|
| 55 |
+
enum class memory_kind {
|
| 56 |
+
/// USM (device, shared, host, or unknown) memory allocation kind.
|
| 57 |
+
usm = dnnl_ocl_interop_usm,
|
| 58 |
+
/// Buffer memory allocation kind - default.
|
| 59 |
+
buffer = dnnl_ocl_interop_buffer,
|
| 60 |
+
};
|
| 61 |
+
|
| 62 |
+
/// Converts a memory allocation kind enum value from C++ API to C API type.
|
| 63 |
+
///
|
| 64 |
+
/// @param akind C++ API memory allocation kind enum value.
|
| 65 |
+
/// @returns Corresponding C API memory allocation kind enum value.
|
| 66 |
+
inline dnnl_ocl_interop_memory_kind_t convert_to_c(memory_kind akind) {
|
| 67 |
+
return static_cast<dnnl_ocl_interop_memory_kind_t>(akind);
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
/// Returns the cache blob ID of the OpenCL device.
|
| 71 |
+
///
|
| 72 |
+
/// @warning
|
| 73 |
+
/// This API is intended to be used with
|
| 74 |
+
/// #dnnl::ocl_interop::get_engine_cache_blob() and
|
| 75 |
+
/// #dnnl::ocl_interop::make_engine(cl_device_id, cl_context, const std::vector<uint8_t> &).
|
| 76 |
+
/// The returned cache blob ID can only be used as an ID of the cache blob
|
| 77 |
+
/// returned by #dnnl::ocl_interop::get_engine_cache_blob().
|
| 78 |
+
///
|
| 79 |
+
/// @note The cache blob ID can be empty (@p size will be 0 and
|
| 80 |
+
/// @p cache_blob_id will be nullptr) if oneDNN doesn't have anything to
|
| 81 |
+
/// put in the cache blob. (#dnnl_ocl_interop_engine_get_cache_blob will
|
| 82 |
+
/// return an empty cache blob).
|
| 83 |
+
///
|
| 84 |
+
/// @param device An OpenCL device.
|
| 85 |
+
/// @returns A vector containing the cache blob ID.
|
| 86 |
+
inline std::vector<uint8_t> get_engine_cache_blob_id(cl_device_id device) {
|
| 87 |
+
size_t size = 0;
|
| 88 |
+
error::wrap_c_api(
|
| 89 |
+
dnnl_ocl_interop_engine_get_cache_blob_id(device, &size, nullptr),
|
| 90 |
+
"could not get an engine cache blob id size");
|
| 91 |
+
|
| 92 |
+
std::vector<uint8_t> cache_blob_id(size);
|
| 93 |
+
error::wrap_c_api(dnnl_ocl_interop_engine_get_cache_blob_id(
|
| 94 |
+
device, &size, cache_blob_id.data()),
|
| 95 |
+
"could not get an engine cache blob id");
|
| 96 |
+
return cache_blob_id;
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
/// Returns a cache blob for the engine.
|
| 100 |
+
///
|
| 101 |
+
/// @note The cache blob vector can be empty if oneDNN doesn't have anything
|
| 102 |
+
/// to put in the cache blob. It's the user's responsibility to check
|
| 103 |
+
/// whether it's empty prior to passing it to
|
| 104 |
+
/// #dnnl::ocl_interop::make_engine(cl_device_id, cl_context, const std::vector<uint8_t> &)
|
| 105 |
+
///
|
| 106 |
+
/// @param aengine Engine to query for the cache blob.
|
| 107 |
+
/// @returns Vector containing the cache blob.
|
| 108 |
+
inline std::vector<uint8_t> get_engine_cache_blob(const engine &aengine) {
|
| 109 |
+
size_t size = 0;
|
| 110 |
+
error::wrap_c_api(dnnl_ocl_interop_engine_get_cache_blob(
|
| 111 |
+
aengine.get(), &size, nullptr),
|
| 112 |
+
"could not get an engine cache blob size");
|
| 113 |
+
|
| 114 |
+
std::vector<uint8_t> cache_blob(size);
|
| 115 |
+
error::wrap_c_api(dnnl_ocl_interop_engine_get_cache_blob(
|
| 116 |
+
aengine.get(), &size, cache_blob.data()),
|
| 117 |
+
"could not get an engine cache blob");
|
| 118 |
+
return cache_blob;
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
/// Constructs an engine from the given cache blob.
|
| 122 |
+
///
|
| 123 |
+
/// @param device The OpenCL device that this engine will encapsulate.
|
| 124 |
+
/// @param context The OpenCL context (containing the device) that this
|
| 125 |
+
/// engine will use for all operations.
|
| 126 |
+
/// @param cache_blob Cache blob.
|
| 127 |
+
/// @returns An engine.
|
| 128 |
+
inline engine make_engine(cl_device_id device, cl_context context,
|
| 129 |
+
const std::vector<uint8_t> &cache_blob) {
|
| 130 |
+
dnnl_engine_t c_engine;
|
| 131 |
+
error::wrap_c_api(
|
| 132 |
+
dnnl_ocl_interop_engine_create_from_cache_blob(&c_engine, device,
|
| 133 |
+
context, cache_blob.size(), cache_blob.data()),
|
| 134 |
+
"could not create an engine from cache blob");
|
| 135 |
+
return engine(c_engine);
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
/// Constructs an engine from OpenCL device and context objects.
|
| 139 |
+
///
|
| 140 |
+
/// @param device The OpenCL device that this engine will encapsulate.
|
| 141 |
+
/// @param context The OpenCL context (containing the device) that this
|
| 142 |
+
/// engine will use for all operations.
|
| 143 |
+
/// @returns An engine.
|
| 144 |
+
inline engine make_engine(cl_device_id device, cl_context context) {
|
| 145 |
+
dnnl_engine_t c_engine;
|
| 146 |
+
error::wrap_c_api(
|
| 147 |
+
dnnl_ocl_interop_engine_create(&c_engine, device, context),
|
| 148 |
+
"could not create an engine");
|
| 149 |
+
return engine(c_engine);
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
/// Returns OpenCL context associated with the engine.
|
| 153 |
+
///
|
| 154 |
+
/// @param aengine An engine.
|
| 155 |
+
/// @returns Underlying OpenCL context.
|
| 156 |
+
inline cl_context get_context(const engine &aengine) {
|
| 157 |
+
cl_context context = nullptr;
|
| 158 |
+
error::wrap_c_api(
|
| 159 |
+
dnnl_ocl_interop_engine_get_context(aengine.get(), &context),
|
| 160 |
+
"could not get an OpenCL context from an engine");
|
| 161 |
+
return context;
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
/// Returns OpenCL device associated with the engine.
|
| 165 |
+
///
|
| 166 |
+
/// @param aengine An engine.
|
| 167 |
+
/// @returns Underlying OpenCL device.
|
| 168 |
+
inline cl_device_id get_device(const engine &aengine) {
|
| 169 |
+
cl_device_id device = nullptr;
|
| 170 |
+
error::wrap_c_api(dnnl_ocl_interop_get_device(aengine.get(), &device),
|
| 171 |
+
"could not get an OpenCL device from an engine");
|
| 172 |
+
return device;
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
/// Constructs an execution stream for the specified engine and OpenCL queue.
|
| 176 |
+
///
|
| 177 |
+
/// @param aengine Engine to create the stream on.
|
| 178 |
+
/// @param queue OpenCL queue to use for the stream.
|
| 179 |
+
/// @returns An execution stream.
|
| 180 |
+
inline stream make_stream(const engine &aengine, cl_command_queue queue) {
|
| 181 |
+
dnnl_stream_t c_stream;
|
| 182 |
+
error::wrap_c_api(
|
| 183 |
+
dnnl_ocl_interop_stream_create(&c_stream, aengine.get(), queue),
|
| 184 |
+
"could not create a stream");
|
| 185 |
+
return stream(c_stream);
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
/// Returns OpenCL queue object associated with the execution stream.
|
| 189 |
+
///
|
| 190 |
+
/// @param astream An execution stream.
|
| 191 |
+
/// @returns Underlying OpenCL queue.
|
| 192 |
+
inline cl_command_queue get_command_queue(const stream &astream) {
|
| 193 |
+
cl_command_queue queue = nullptr;
|
| 194 |
+
error::wrap_c_api(
|
| 195 |
+
dnnl_ocl_interop_stream_get_command_queue(astream.get(), &queue),
|
| 196 |
+
"could not get an OpenCL command queue from a stream");
|
| 197 |
+
return queue;
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
/// Returns the OpenCL memory object associated with the memory object.
|
| 201 |
+
///
|
| 202 |
+
/// @param amemory A memory object.
|
| 203 |
+
/// @returns Underlying OpenCL memory object.
|
| 204 |
+
inline cl_mem get_mem_object(const memory &amemory) {
|
| 205 |
+
cl_mem mem_object;
|
| 206 |
+
error::wrap_c_api(
|
| 207 |
+
dnnl_ocl_interop_memory_get_mem_object(amemory.get(), &mem_object),
|
| 208 |
+
"could not get OpenCL buffer object from a memory object");
|
| 209 |
+
return mem_object;
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
/// Sets the OpenCL memory object associated with the memory object.
|
| 213 |
+
///
|
| 214 |
+
/// For behavioral details see memory::set_data_handle().
|
| 215 |
+
///
|
| 216 |
+
/// @param amemory A memory object.
|
| 217 |
+
/// @param mem_object OpenCL cl_mem object to use as the underlying
|
| 218 |
+
/// storage. It must have at least get_desc().get_size() bytes
|
| 219 |
+
/// allocated.
|
| 220 |
+
inline void set_mem_object(memory &amemory, cl_mem mem_object) {
|
| 221 |
+
error::wrap_c_api(
|
| 222 |
+
dnnl_ocl_interop_memory_set_mem_object(amemory.get(), mem_object),
|
| 223 |
+
"could not set OpenCL buffer object from a memory object");
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
/// Returns the memory allocation kind associated with a memory object.
|
| 227 |
+
///
|
| 228 |
+
/// @param amemory A memory object.
|
| 229 |
+
///
|
| 230 |
+
/// @returns The underlying memory allocation kind of the memory object.
|
| 231 |
+
inline memory_kind get_memory_kind(const memory &amemory) {
|
| 232 |
+
dnnl_ocl_interop_memory_kind_t ckind;
|
| 233 |
+
error::wrap_c_api(
|
| 234 |
+
dnnl_ocl_interop_memory_get_memory_kind(amemory.get(), &ckind),
|
| 235 |
+
"could not get memory kind");
|
| 236 |
+
return static_cast<memory_kind>(ckind);
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
#ifdef DNNL_EXPERIMENTAL_SPARSE
|
| 240 |
+
/// Creates a memory object with multiple handles.
|
| 241 |
+
///
|
| 242 |
+
/// @param memory_desc Memory descriptor.
|
| 243 |
+
/// @param aengine Engine to use.
|
| 244 |
+
/// @param kind Memory allocation kind to specify the type of handles.
|
| 245 |
+
/// @param handles Handles of the memory buffers to use as underlying storages.
|
| 246 |
+
/// For each element of the @p handles array the following applies:
|
| 247 |
+
/// - A USM pointer to the user-allocated buffer. In this case the library
|
| 248 |
+
/// doesn't own the buffer. Requires @p memory_kind to be equal to
|
| 249 |
+
/// dnnl_ocl_interop_usm.
|
| 250 |
+
/// - An OpenCL buffer. In this case the library doesn't own the buffer.
|
| 251 |
+
/// Requires @p memory_kind be equal to be equal to dnnl_ocl_interop_buffer.
|
| 252 |
+
/// - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to
|
| 253 |
+
/// allocate the buffer that corresponds to the memory allocation kind
|
| 254 |
+
/// @p memory_kind for the memory object. In this case the library
|
| 255 |
+
/// owns the buffer.
|
| 256 |
+
/// - The DNNL_MEMORY_NONE specific value. Instructs the library to
|
| 257 |
+
/// create memory object without an underlying buffer.
|
| 258 |
+
///
|
| 259 |
+
/// If the @p handles vector is not provided the library will allocate all
|
| 260 |
+
/// buffers as if all handles have the special value DNNL_MEMORY_ALLOCATE.
|
| 261 |
+
///
|
| 262 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 263 |
+
/// otherwise.
|
| 264 |
+
inline memory make_memory(const memory::desc &memory_desc,
|
| 265 |
+
const engine &aengine, memory_kind kind,
|
| 266 |
+
std::vector<void *> handles = {}) {
|
| 267 |
+
if (handles.empty()) {
|
| 268 |
+
const int nhandles = memory_desc.get_num_handles();
|
| 269 |
+
handles.resize(nhandles, DNNL_MEMORY_ALLOCATE);
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
dnnl_memory_t c_memory;
|
| 273 |
+
error::wrap_c_api(
|
| 274 |
+
dnnl_ocl_interop_memory_create_v2(&c_memory, memory_desc.get(),
|
| 275 |
+
aengine.get(), convert_to_c(kind), (int)handles.size(),
|
| 276 |
+
handles.data()),
|
| 277 |
+
"could not create a memory");
|
| 278 |
+
return memory(c_memory);
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
/// Constructs a memory object with multiple OpenCL buffers.
|
| 282 |
+
///
|
| 283 |
+
/// @param memory_desc Memory descriptor.
|
| 284 |
+
/// @param aengine Engine to use.
|
| 285 |
+
/// @param mem_objects A vector of OpenCL buffers to use.
|
| 286 |
+
///
|
| 287 |
+
/// @returns Created memory object.
|
| 288 |
+
inline memory make_memory(const memory::desc &memory_desc,
|
| 289 |
+
const engine &aengine, std::vector<cl_mem> mem_objects) {
|
| 290 |
+
const int nhandles = memory_desc.get_num_handles();
|
| 291 |
+
std::vector<void *> handles(nhandles, DNNL_MEMORY_NONE);
|
| 292 |
+
memory amemory(memory_desc, aengine, handles);
|
| 293 |
+
for (int i = 0; i < nhandles; i++)
|
| 294 |
+
amemory.set_data_handle(mem_objects[i], i);
|
| 295 |
+
return amemory;
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
/// Creates a memory object.
|
| 299 |
+
///
|
| 300 |
+
/// Unless @p handle is equal to DNNL_MEMORY_NONE or DNNL_MEMORY_ALLOCATE, the
|
| 301 |
+
/// constructed memory object will have the underlying buffer set. In this
|
| 302 |
+
/// case, the buffer will be initialized as if:
|
| 303 |
+
/// - dnnl::memory::set_data_handle() had been called, if @p memory_kind is
|
| 304 |
+
/// equal to dnnl::ocl_interop::memory_kind::usm, or
|
| 305 |
+
/// - dnnl::ocl_interop::set_mem_object() has been called, if @p memory_kind is
|
| 306 |
+
/// equal to dnnl::ocl_interop::memory_kind::buffer.
|
| 307 |
+
///
|
| 308 |
+
/// @param memory_desc Memory descriptor.
|
| 309 |
+
/// @param aengine Engine to use.
|
| 310 |
+
/// @param kind Memory allocation kind to specify the type of handle.
|
| 311 |
+
/// @param handle Handle of the memory buffer to use as an underlying storage.
|
| 312 |
+
/// - A USM pointer to the user-allocated buffer. In this case the library
|
| 313 |
+
/// doesn't own the buffer. Requires @p memory_kind to be equal to
|
| 314 |
+
/// dnnl::ocl_interop::memory_kind::usm.
|
| 315 |
+
/// - An OpenCL buffer. In this case the library doesn't own the buffer.
|
| 316 |
+
/// Requires @p memory_kind be equal to be equal to
|
| 317 |
+
/// dnnl::ocl_interop::memory_kind::buffer.
|
| 318 |
+
/// - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to
|
| 319 |
+
/// allocate the buffer that corresponds to the memory allocation kind
|
| 320 |
+
/// @p memory_kind for the memory object. In this case the library
|
| 321 |
+
/// owns the buffer.
|
| 322 |
+
/// - The DNNL_MEMORY_NONE specific value. Instructs the library to
|
| 323 |
+
/// create memory object without an underlying buffer.
|
| 324 |
+
///
|
| 325 |
+
/// @returns Created memory object.
|
| 326 |
+
inline memory make_memory(const memory::desc &memory_desc,
|
| 327 |
+
const engine &aengine, memory_kind kind, void *handle) {
|
| 328 |
+
return make_memory(
|
| 329 |
+
memory_desc, aengine, kind, std::vector<void *> {handle});
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
/// Constructs a memory object from an OpenCL buffer.
|
| 333 |
+
///
|
| 334 |
+
/// @param memory_desc Memory descriptor.
|
| 335 |
+
/// @param aengine Engine to use.
|
| 336 |
+
/// @param mem_object An OpenCL buffer to use.
|
| 337 |
+
///
|
| 338 |
+
/// @returns Created memory object.
|
| 339 |
+
inline memory make_memory(const memory::desc &memory_desc,
|
| 340 |
+
const engine &aengine, cl_mem mem_object) {
|
| 341 |
+
return make_memory(memory_desc, aengine, std::vector<cl_mem> {mem_object});
|
| 342 |
+
}
|
| 343 |
+
#else
|
| 344 |
+
|
| 345 |
+
/// Creates a memory object.
|
| 346 |
+
///
|
| 347 |
+
/// Unless @p handle is equal to DNNL_MEMORY_NONE or DNNL_MEMORY_ALLOCATE, the
|
| 348 |
+
/// constructed memory object will have the underlying buffer set. In this
|
| 349 |
+
/// case, the buffer will be initialized as if:
|
| 350 |
+
/// - dnnl::memory::set_data_handle() had been called, if @p memory_kind is
|
| 351 |
+
/// equal to dnnl::ocl_interop::memory_kind::usm, or
|
| 352 |
+
/// - dnnl::ocl_interop::set_mem_object() has been called, if @p memory_kind is
|
| 353 |
+
/// equal to dnnl::ocl_interop::memory_kind::buffer.
|
| 354 |
+
///
|
| 355 |
+
/// @param memory_desc Memory descriptor.
|
| 356 |
+
/// @param aengine Engine to use.
|
| 357 |
+
/// @param kind Memory allocation kind to specify the type of handle.
|
| 358 |
+
/// @param handle Handle of the memory buffer to use as an underlying storage.
|
| 359 |
+
/// - A USM pointer to the user-allocated buffer. In this case the library
|
| 360 |
+
/// doesn't own the buffer. Requires @p memory_kind to be equal to
|
| 361 |
+
/// dnnl::ocl_interop::memory_kind::usm.
|
| 362 |
+
/// - An OpenCL buffer. In this case the library doesn't own the buffer.
|
| 363 |
+
/// Requires @p memory_kind be equal to be equal to
|
| 364 |
+
/// dnnl::ocl_interop::memory_kind::buffer.
|
| 365 |
+
/// - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to
|
| 366 |
+
/// allocate the buffer that corresponds to the memory allocation kind
|
| 367 |
+
/// @p memory_kind for the memory object. In this case the library
|
| 368 |
+
/// owns the buffer.
|
| 369 |
+
/// - The DNNL_MEMORY_NONE specific value. Instructs the library to
|
| 370 |
+
/// create memory object without an underlying buffer.
|
| 371 |
+
///
|
| 372 |
+
/// @returns Created memory object.
|
| 373 |
+
inline memory make_memory(const memory::desc &memory_desc,
|
| 374 |
+
const engine &aengine, memory_kind kind,
|
| 375 |
+
void *handle = DNNL_MEMORY_ALLOCATE) {
|
| 376 |
+
dnnl_memory_t c_memory;
|
| 377 |
+
error::wrap_c_api(
|
| 378 |
+
dnnl_ocl_interop_memory_create(&c_memory, memory_desc.get(),
|
| 379 |
+
aengine.get(), convert_to_c(kind), handle),
|
| 380 |
+
"could not create a memory");
|
| 381 |
+
return memory(c_memory);
|
| 382 |
+
}
|
| 383 |
+
|
| 384 |
+
/// Constructs a memory object from an OpenCL buffer.
|
| 385 |
+
///
|
| 386 |
+
/// @param memory_desc Memory descriptor.
|
| 387 |
+
/// @param aengine Engine to use.
|
| 388 |
+
/// @param mem_object An OpenCL buffer to use.
|
| 389 |
+
///
|
| 390 |
+
/// @returns Created memory object.
|
| 391 |
+
inline memory make_memory(const memory::desc &memory_desc,
|
| 392 |
+
const engine &aengine, cl_mem mem_object) {
|
| 393 |
+
memory amemory(memory_desc, aengine, DNNL_MEMORY_NONE);
|
| 394 |
+
set_mem_object(amemory, mem_object);
|
| 395 |
+
return amemory;
|
| 396 |
+
}
|
| 397 |
+
#endif
|
| 398 |
+
|
| 399 |
+
/// Executes computations specified by the primitive in a specified stream and
|
| 400 |
+
/// returns a SYCL event.
|
| 401 |
+
///
|
| 402 |
+
/// Arguments are passed via an arguments map containing
|
| 403 |
+
/// <index, memory object> pairs. The index must be one of the `DNNL_ARG_*`
|
| 404 |
+
/// values such as `DNNL_ARG_SRC`, and the memory must have a memory descriptor
|
| 405 |
+
/// matching the one returned by
|
| 406 |
+
/// #dnnl::primitive_desc::query_md(#query::exec_arg_md, index) unless using
|
| 407 |
+
/// dynamic shapes (see #DNNL_RUNTIME_DIM_VAL).
|
| 408 |
+
///
|
| 409 |
+
/// @param aprimitive Primitive to execute.
|
| 410 |
+
/// @param astream Stream object. The stream must belong to the same engine
|
| 411 |
+
/// as the primitive.
|
| 412 |
+
/// @param args Arguments map.
|
| 413 |
+
/// @param deps Optional vector with `cl_event` dependencies.
|
| 414 |
+
///
|
| 415 |
+
/// @returns Output event. It's the user's responsibility to manage lifetime
|
| 416 |
+
/// of the event.
|
| 417 |
+
inline cl_event execute(const dnnl::primitive &aprimitive,
|
| 418 |
+
const stream &astream, const std::unordered_map<int, memory> &args,
|
| 419 |
+
const std::vector<cl_event> &deps = {}) {
|
| 420 |
+
std::vector<dnnl_exec_arg_t> c_args;
|
| 421 |
+
c_args.reserve(args.size());
|
| 422 |
+
for (const auto &a : args)
|
| 423 |
+
c_args.push_back({a.first, a.second.get()});
|
| 424 |
+
|
| 425 |
+
const cl_event *c_deps = deps.empty() ? nullptr : deps.data();
|
| 426 |
+
|
| 427 |
+
cl_event return_event;
|
| 428 |
+
error::wrap_c_api(dnnl_ocl_interop_primitive_execute(aprimitive.get(),
|
| 429 |
+
astream.get(), (int)c_args.size(), c_args.data(),
|
| 430 |
+
c_deps, (int)deps.size(), &return_event),
|
| 431 |
+
"could not execute a primitive");
|
| 432 |
+
return return_event;
|
| 433 |
+
}
|
| 434 |
+
|
| 435 |
+
} // namespace ocl_interop
|
| 436 |
+
|
| 437 |
+
/// @} dnnl_api_ocl_interop
|
| 438 |
+
|
| 439 |
+
/// @} dnnl_api_interop
|
| 440 |
+
|
| 441 |
+
} // namespace dnnl
|
| 442 |
+
|
| 443 |
+
/// @} dnnl_api
|
| 444 |
+
|
| 445 |
+
#endif
|
phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_ocl_types.h
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*******************************************************************************
|
| 2 |
+
* Copyright 2021 Intel Corporation
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*******************************************************************************/
|
| 16 |
+
|
| 17 |
+
#ifndef ONEAPI_DNNL_DNNL_OCL_TYPES_H
|
| 18 |
+
#define ONEAPI_DNNL_DNNL_OCL_TYPES_H
|
| 19 |
+
|
| 20 |
+
#ifdef __cplusplus
|
| 21 |
+
extern "C" {
|
| 22 |
+
#endif
|
| 23 |
+
|
| 24 |
+
/// @addtogroup dnnl_api
|
| 25 |
+
/// @{
|
| 26 |
+
|
| 27 |
+
/// @addtogroup dnnl_api_interop
|
| 28 |
+
/// @{
|
| 29 |
+
|
| 30 |
+
/// @addtogroup dnnl_api_ocl_interop
|
| 31 |
+
/// @{
|
| 32 |
+
|
| 33 |
+
/// Memory allocation kind.
|
| 34 |
+
typedef enum {
|
| 35 |
+
/// USM (device, shared, host, or unknown) memory allocation kind.
|
| 36 |
+
dnnl_ocl_interop_usm,
|
| 37 |
+
/// Buffer memory allocation kind - default.
|
| 38 |
+
dnnl_ocl_interop_buffer,
|
| 39 |
+
} dnnl_ocl_interop_memory_kind_t;
|
| 40 |
+
|
| 41 |
+
/// @} dnnl_api_ocl_interop
|
| 42 |
+
|
| 43 |
+
/// @} dnnl_api_interop
|
| 44 |
+
|
| 45 |
+
/// @} dnnl_api
|
| 46 |
+
|
| 47 |
+
#ifdef __cplusplus
|
| 48 |
+
}
|
| 49 |
+
#endif
|
| 50 |
+
|
| 51 |
+
#endif
|
phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_sycl.h
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*******************************************************************************
|
| 2 |
+
* Copyright 2020-2024 Intel Corporation
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*******************************************************************************/
|
| 16 |
+
|
| 17 |
+
#ifndef ONEAPI_DNNL_DNNL_SYCL_H
|
| 18 |
+
#define ONEAPI_DNNL_DNNL_SYCL_H
|
| 19 |
+
|
| 20 |
+
#include "oneapi/dnnl/dnnl.h"
|
| 21 |
+
|
| 22 |
+
#include "oneapi/dnnl/dnnl_sycl_types.h"
|
| 23 |
+
|
| 24 |
+
#ifdef __cplusplus
|
| 25 |
+
extern "C" {
|
| 26 |
+
#endif
|
| 27 |
+
|
| 28 |
+
/// @addtogroup dnnl_api
|
| 29 |
+
/// @{
|
| 30 |
+
|
| 31 |
+
/// @addtogroup dnnl_api_interop
|
| 32 |
+
/// @{
|
| 33 |
+
|
| 34 |
+
/// @addtogroup dnnl_api_sycl_interop
|
| 35 |
+
/// @{
|
| 36 |
+
|
| 37 |
+
/// Creates an engine associated with a SYCL device and a SYCL context.
|
| 38 |
+
///
|
| 39 |
+
/// @param engine Output engine.
|
| 40 |
+
/// @param device Pointer to the SYCL device to use for the engine.
|
| 41 |
+
/// @param context Pointer to the SYCL context to use for the engine.
|
| 42 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 43 |
+
/// otherwise.
|
| 44 |
+
dnnl_status_t DNNL_API dnnl_sycl_interop_engine_create(
|
| 45 |
+
dnnl_engine_t *engine, const void *device, const void *context);
|
| 46 |
+
|
| 47 |
+
/// Returns the SYCL context associated with an engine.
|
| 48 |
+
///
|
| 49 |
+
/// @param engine Engine to query.
|
| 50 |
+
/// @param context Pointer to the underlying SYCL context of the engine.
|
| 51 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 52 |
+
/// otherwise.
|
| 53 |
+
dnnl_status_t DNNL_API dnnl_sycl_interop_engine_get_context(
|
| 54 |
+
dnnl_engine_t engine, void **context);
|
| 55 |
+
|
| 56 |
+
/// Returns the SYCL device associated with an engine.
|
| 57 |
+
///
|
| 58 |
+
/// @param engine Engine to query.
|
| 59 |
+
/// @param device Pointer to the underlying SYCL device of the engine.
|
| 60 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 61 |
+
/// otherwise.
|
| 62 |
+
dnnl_status_t DNNL_API dnnl_sycl_interop_engine_get_device(
|
| 63 |
+
dnnl_engine_t engine, void **device);
|
| 64 |
+
|
| 65 |
+
/// Creates a memory object.
|
| 66 |
+
///
|
| 67 |
+
/// Unless @p handle is equal to DNNL_MEMORY_NONE or DNNL_MEMORY_ALLOCATE, the
|
| 68 |
+
/// constructed memory object will have the underlying buffer set. In this
|
| 69 |
+
/// case, the buffer will be initialized as if:
|
| 70 |
+
/// - dnnl_memory_set_data_handle() had been called, if @p memory_kind is equal
|
| 71 |
+
/// to dnnl_sycl_interop_usm, or
|
| 72 |
+
/// - dnnl_sycl_interop_memory_set_buffer() has been called, if @p memory_kind
|
| 73 |
+
/// is equal to dnnl_sycl_interop_buffer.
|
| 74 |
+
///
|
| 75 |
+
/// @param memory Output memory object.
|
| 76 |
+
/// @param memory_desc Memory descriptor.
|
| 77 |
+
/// @param engine Engine to use.
|
| 78 |
+
/// @param memory_kind Memory allocation kind to specify the type of handle.
|
| 79 |
+
/// @param handle Handle of the memory buffer to use as an underlying storage.
|
| 80 |
+
/// - A USM pointer to the user-allocated buffer. In this case the library
|
| 81 |
+
/// doesn't own the buffer. Requires @p memory_kind to be equal to
|
| 82 |
+
/// dnnl_sycl_interop_usm.
|
| 83 |
+
/// - A pointer to SYCL buffer. In this case the library doesn't own the
|
| 84 |
+
/// buffer. Requires @p memory_kind be equal to be equal to
|
| 85 |
+
/// dnnl_sycl_interop_buffer.
|
| 86 |
+
/// - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to
|
| 87 |
+
/// allocate the buffer that corresponds to the memory allocation kind
|
| 88 |
+
/// @p memory_kind for the memory object. In this case the library
|
| 89 |
+
/// owns the buffer.
|
| 90 |
+
/// - The DNNL_MEMORY_NONE specific value. Instructs the library to
|
| 91 |
+
/// create memory object without an underlying buffer.
|
| 92 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 93 |
+
/// otherwise.
|
| 94 |
+
dnnl_status_t DNNL_API dnnl_sycl_interop_memory_create(dnnl_memory_t *memory,
|
| 95 |
+
const_dnnl_memory_desc_t memory_desc, dnnl_engine_t engine,
|
| 96 |
+
dnnl_sycl_interop_memory_kind_t memory_kind, void *handle);
|
| 97 |
+
|
| 98 |
+
#ifdef DNNL_EXPERIMENTAL_SPARSE
|
| 99 |
+
/// Creates a memory object with multiple handles.
|
| 100 |
+
///
|
| 101 |
+
/// @param memory Output memory object.
|
| 102 |
+
/// @param memory_desc Memory descriptor.
|
| 103 |
+
/// @param engine Engine to use.
|
| 104 |
+
/// @param memory_kind Memory allocation kind to specify the type of handles.
|
| 105 |
+
/// @param nhandles Number of handles.
|
| 106 |
+
/// @param handles Handles of the memory buffers to use as underlying storages.
|
| 107 |
+
/// For each element of the @p handles array the following applies:
|
| 108 |
+
/// - A USM pointer to the user-allocated buffer. In this case the library
|
| 109 |
+
/// doesn't own the buffer. Requires @p memory_kind to be equal to
|
| 110 |
+
/// dnnl_sycl_interop_usm.
|
| 111 |
+
/// - A pointer to SYCL buffer. In this case the library doesn't own the
|
| 112 |
+
/// buffer. Requires @p memory_kind be equal to be equal to
|
| 113 |
+
/// dnnl_sycl_interop_buffer.
|
| 114 |
+
/// - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to
|
| 115 |
+
/// allocate the buffer that corresponds to the memory allocation kind
|
| 116 |
+
/// @p memory_kind for the memory object. In this case the library
|
| 117 |
+
/// owns the buffer.
|
| 118 |
+
/// - The DNNL_MEMORY_NONE specific value. Instructs the library to
|
| 119 |
+
/// create memory object without an underlying buffer.
|
| 120 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 121 |
+
/// otherwise.
|
| 122 |
+
dnnl_status_t DNNL_API dnnl_sycl_interop_memory_create_v2(dnnl_memory_t *memory,
|
| 123 |
+
const_dnnl_memory_desc_t memory_desc, dnnl_engine_t engine,
|
| 124 |
+
dnnl_sycl_interop_memory_kind_t memory_kind, int nhandles,
|
| 125 |
+
void **handles);
|
| 126 |
+
#endif
|
| 127 |
+
|
| 128 |
+
/// Returns the memory allocation kind associated with a memory object.
|
| 129 |
+
///
|
| 130 |
+
/// @param memory Memory to query.
|
| 131 |
+
/// @param memory_kind Output underlying memory allocation kind of the memory
|
| 132 |
+
/// object.
|
| 133 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 134 |
+
/// otherwise.
|
| 135 |
+
dnnl_status_t DNNL_API dnnl_sycl_interop_memory_get_memory_kind(
|
| 136 |
+
const_dnnl_memory_t memory,
|
| 137 |
+
dnnl_sycl_interop_memory_kind_t *memory_kind);
|
| 138 |
+
|
| 139 |
+
/// Sets a SYCL buffer for a memory object.
|
| 140 |
+
///
|
| 141 |
+
/// @param memory Memory object.
|
| 142 |
+
/// @param buffer SYCL buffer to be set in the memory object.
|
| 143 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 144 |
+
/// otherwise.
|
| 145 |
+
dnnl_status_t DNNL_API dnnl_sycl_interop_memory_set_buffer(
|
| 146 |
+
dnnl_memory_t memory, void *buffer);
|
| 147 |
+
|
| 148 |
+
/// Creates an execution stream for a given engine associated with a SYCL
|
| 149 |
+
/// queue.
|
| 150 |
+
///
|
| 151 |
+
/// @param stream Output execution stream.
|
| 152 |
+
/// @param engine Engine to create the execution stream on.
|
| 153 |
+
/// @param queue SYCL queue to use.
|
| 154 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 155 |
+
/// otherwise.
|
| 156 |
+
dnnl_status_t DNNL_API dnnl_sycl_interop_stream_create(
|
| 157 |
+
dnnl_stream_t *stream, dnnl_engine_t engine, void *queue);
|
| 158 |
+
|
| 159 |
+
/// Returns the SYCL queue associated with an execution stream.
|
| 160 |
+
///
|
| 161 |
+
/// @param stream Execution stream to query.
|
| 162 |
+
/// @param queue Output SYCL command queue.
|
| 163 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 164 |
+
/// otherwise.
|
| 165 |
+
dnnl_status_t DNNL_API dnnl_sycl_interop_stream_get_queue(
|
| 166 |
+
dnnl_stream_t stream, void **queue);
|
| 167 |
+
|
| 168 |
+
/// Executes computations specified by the primitive in a specified stream and
|
| 169 |
+
/// returns a SYCL event.
|
| 170 |
+
///
|
| 171 |
+
/// @param primitive Primitive to execute.
|
| 172 |
+
/// @param stream Stream to use.
|
| 173 |
+
/// @param nargs Number of arguments.
|
| 174 |
+
/// @param args Array of arguments. Each argument is an
|
| 175 |
+
/// <index, #dnnl_memory_t> pair. The index is one of the `DNNL_ARG_*`
|
| 176 |
+
/// values such as `DNNL_ARG_SRC`. Unless runtime shapes are used (see
|
| 177 |
+
/// #DNNL_RUNTIME_DIM_VAL), the memory object must have the same memory
|
| 178 |
+
/// descriptor as that returned by
|
| 179 |
+
/// #dnnl_primitive_desc_query_md(#dnnl_query_exec_arg_md, index).
|
| 180 |
+
/// @param deps A pointer to std::vector<sycl::event> that contains
|
| 181 |
+
/// dependencies.
|
| 182 |
+
/// @param return_event Output event.
|
| 183 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 184 |
+
/// otherwise.
|
| 185 |
+
dnnl_status_t DNNL_API dnnl_sycl_interop_primitive_execute(
|
| 186 |
+
const_dnnl_primitive_t primitive, dnnl_stream_t stream, int nargs,
|
| 187 |
+
const dnnl_exec_arg_t *args, const void *deps, void *return_event);
|
| 188 |
+
|
| 189 |
+
/// @} dnnl_api_sycl_interop
|
| 190 |
+
|
| 191 |
+
/// @} dnnl_api_interop
|
| 192 |
+
|
| 193 |
+
/// @} dnnl_api
|
| 194 |
+
|
| 195 |
+
#ifdef __cplusplus
|
| 196 |
+
}
|
| 197 |
+
#endif
|
| 198 |
+
|
| 199 |
+
#endif
|
phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_sycl.hpp
ADDED
|
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*******************************************************************************
|
| 2 |
+
* Copyright 2020-2025 Intel Corporation
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*******************************************************************************/
|
| 16 |
+
|
| 17 |
+
#ifndef ONEAPI_DNNL_DNNL_SYCL_HPP
|
| 18 |
+
#define ONEAPI_DNNL_DNNL_SYCL_HPP
|
| 19 |
+
|
| 20 |
+
/// @cond DO_NOT_DOCUMENT_THIS
|
| 21 |
+
#include <algorithm>
|
| 22 |
+
#include <cstdlib>
|
| 23 |
+
#include <iterator>
|
| 24 |
+
#include <memory>
|
| 25 |
+
#include <string>
|
| 26 |
+
#include <vector>
|
| 27 |
+
#include <unordered_map>
|
| 28 |
+
|
| 29 |
+
#if __has_include(<sycl/sycl.hpp>)
|
| 30 |
+
#include <sycl/sycl.hpp>
|
| 31 |
+
#else
|
| 32 |
+
#error "Unsupported compiler"
|
| 33 |
+
#endif
|
| 34 |
+
|
| 35 |
+
#include "oneapi/dnnl/dnnl.hpp"
|
| 36 |
+
#include "oneapi/dnnl/dnnl_sycl.h"
|
| 37 |
+
|
| 38 |
+
/// @endcond
|
| 39 |
+
|
| 40 |
+
/// @addtogroup dnnl_api
|
| 41 |
+
/// @{
|
| 42 |
+
|
| 43 |
+
namespace dnnl {
|
| 44 |
+
|
| 45 |
+
/// @addtogroup dnnl_api_interop
|
| 46 |
+
/// @{
|
| 47 |
+
|
| 48 |
+
/// @addtogroup dnnl_api_sycl_interop SYCL interoperability API
|
| 49 |
+
/// API extensions to interact with the underlying SYCL run-time.
|
| 50 |
+
///
|
| 51 |
+
/// @sa @ref dev_guide_dpcpp_interoperability in developer guide
|
| 52 |
+
/// @{
|
| 53 |
+
|
| 54 |
+
/// SYCL interoperability namespace
|
| 55 |
+
namespace sycl_interop {
|
| 56 |
+
|
| 57 |
+
/// Memory allocation kind.
|
| 58 |
+
enum class memory_kind {
|
| 59 |
+
/// USM (device, shared, host, or unknown) memory allocation kind - default.
|
| 60 |
+
usm = dnnl_sycl_interop_usm,
|
| 61 |
+
/// Buffer memory allocation kind.
|
| 62 |
+
buffer = dnnl_sycl_interop_buffer,
|
| 63 |
+
};
|
| 64 |
+
|
| 65 |
+
/// Converts a memory allocation kind enum value from C++ API to C API type.
|
| 66 |
+
///
|
| 67 |
+
/// @param akind C++ API memory allocation kind enum value.
|
| 68 |
+
/// @returns Corresponding C API memory allocation kind enum value.
|
| 69 |
+
inline dnnl_sycl_interop_memory_kind_t convert_to_c(memory_kind akind) {
|
| 70 |
+
return static_cast<dnnl_sycl_interop_memory_kind_t>(akind);
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
/// Constructs an engine from SYCL device and context objects.
|
| 74 |
+
///
|
| 75 |
+
/// @param adevice SYCL device.
|
| 76 |
+
/// @param acontext SYCL context.
|
| 77 |
+
///
|
| 78 |
+
/// @returns Created engine.
|
| 79 |
+
inline engine make_engine(
|
| 80 |
+
const sycl::device &adevice, const sycl::context &acontext) {
|
| 81 |
+
dnnl_engine_t aengine;
|
| 82 |
+
error::wrap_c_api(dnnl_sycl_interop_engine_create(&aengine,
|
| 83 |
+
static_cast<const void *>(&adevice),
|
| 84 |
+
static_cast<const void *>(&acontext)),
|
| 85 |
+
"could not create an engine");
|
| 86 |
+
return engine(aengine);
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
/// Returns the SYCL context associated with an engine.
|
| 90 |
+
///
|
| 91 |
+
/// @param aengine Engine to query.
|
| 92 |
+
///
|
| 93 |
+
/// @returns The underlying SYCL device of the engine.
|
| 94 |
+
inline sycl::context get_context(const engine &aengine) {
|
| 95 |
+
void *ctx_ptr;
|
| 96 |
+
error::wrap_c_api(
|
| 97 |
+
dnnl_sycl_interop_engine_get_context(aengine.get(), &ctx_ptr),
|
| 98 |
+
"could not get a context handle");
|
| 99 |
+
auto ctx = *static_cast<sycl::context *>(ctx_ptr);
|
| 100 |
+
return ctx;
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
/// Returns the SYCL device associated with an engine.
|
| 104 |
+
///
|
| 105 |
+
/// @param aengine Engine to query.
|
| 106 |
+
///
|
| 107 |
+
/// @returns The underlying SYCL context of the engine.
|
| 108 |
+
inline sycl::device get_device(const engine &aengine) {
|
| 109 |
+
void *dev_ptr;
|
| 110 |
+
error::wrap_c_api(
|
| 111 |
+
dnnl_sycl_interop_engine_get_device(aengine.get(), &dev_ptr),
|
| 112 |
+
"could not get a device handle");
|
| 113 |
+
auto dev = *static_cast<sycl::device *>(dev_ptr);
|
| 114 |
+
return dev;
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
/// Creates an execution stream for a given engine associated with a SYCL
|
| 118 |
+
/// queue.
|
| 119 |
+
///
|
| 120 |
+
/// @param aengine Engine object to use for the stream.
|
| 121 |
+
/// @param aqueue SYCL queue to use for the stream.
|
| 122 |
+
///
|
| 123 |
+
/// @returns An execution stream.
|
| 124 |
+
inline stream make_stream(const engine &aengine, sycl::queue &aqueue) {
|
| 125 |
+
dnnl_stream_t astream;
|
| 126 |
+
error::wrap_c_api(
|
| 127 |
+
dnnl_sycl_interop_stream_create(&astream, aengine.get(), &aqueue),
|
| 128 |
+
"could not create a stream");
|
| 129 |
+
return stream(astream);
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
/// Returns the SYCL queue associated with an execution stream.
|
| 133 |
+
///
|
| 134 |
+
/// @param astream Execution stream to query.
|
| 135 |
+
///
|
| 136 |
+
/// @returns SYCL queue object.
|
| 137 |
+
inline sycl::queue get_queue(const stream &astream) {
|
| 138 |
+
void *queue_ptr;
|
| 139 |
+
error::wrap_c_api(
|
| 140 |
+
dnnl_sycl_interop_stream_get_queue(astream.get(), &queue_ptr),
|
| 141 |
+
"could not get a stream handle");
|
| 142 |
+
auto queue = *static_cast<sycl::queue *>(queue_ptr);
|
| 143 |
+
return queue;
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
/// Returns the SYCL buffer associated with a memory object.
|
| 147 |
+
///
|
| 148 |
+
/// Throws an exception if the memory allocation kind associated with the
|
| 149 |
+
/// memory object is not equal to dnnl::sycl_interop::memory_kind::buffer.
|
| 150 |
+
///
|
| 151 |
+
/// @tparam T Type of the requested buffer.
|
| 152 |
+
/// @tparam ndims Number of dimensions of the requested buffer.
|
| 153 |
+
/// @param amemory Memory object.
|
| 154 |
+
///
|
| 155 |
+
/// @returns SYCL buffer associated with the memory object.
|
| 156 |
+
template <typename T, int ndims = 1>
|
| 157 |
+
sycl::buffer<T, ndims> get_buffer(const memory &amemory) {
|
| 158 |
+
static_assert(ndims == 1, "only 1D buffers supported");
|
| 159 |
+
|
| 160 |
+
// XXX: workaround: when CPU runtime is not SYCL and amemory was created
|
| 161 |
+
// for CPU engine `get_buffer` should return an error. Use interop API to
|
| 162 |
+
// implement the check.
|
| 163 |
+
dnnl_sycl_interop_memory_kind_t ckind;
|
| 164 |
+
error::wrap_c_api(
|
| 165 |
+
dnnl_sycl_interop_memory_get_memory_kind(amemory.get(), &ckind),
|
| 166 |
+
"could not get SYCL buffer object");
|
| 167 |
+
|
| 168 |
+
void *handle_ptr;
|
| 169 |
+
error::wrap_c_api(dnnl_memory_get_data_handle(amemory.get(), &handle_ptr),
|
| 170 |
+
"could not get SYCL buffer object");
|
| 171 |
+
|
| 172 |
+
// XXX: workaround: zero-range buffer cannot be constructed.
|
| 173 |
+
if (!handle_ptr) return sycl::buffer<T, ndims>(sycl::range<1>(1));
|
| 174 |
+
|
| 175 |
+
auto &buf_u8 = *static_cast<sycl::buffer<uint8_t, 1> *>(handle_ptr);
|
| 176 |
+
|
| 177 |
+
auto range = sycl::range<1>(buf_u8.byte_size() / sizeof(T));
|
| 178 |
+
return buf_u8.reinterpret<T, 1>(range);
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
/// Sets SYCL buffer associated with a memory object.
|
| 182 |
+
///
|
| 183 |
+
/// @tparam T Type of the buffer.
|
| 184 |
+
/// @tparam ndims Number of dimensions of the buffer.
|
| 185 |
+
/// @param amemory Memory object to change.
|
| 186 |
+
/// @param abuffer SYCL buffer.
|
| 187 |
+
template <typename T, int ndims>
|
| 188 |
+
void set_buffer(memory &amemory, sycl::buffer<T, ndims> &abuffer) {
|
| 189 |
+
auto range = sycl::range<1>(abuffer.byte_size());
|
| 190 |
+
auto buf_u8 = abuffer.template reinterpret<uint8_t, 1>(range);
|
| 191 |
+
error::wrap_c_api(dnnl_sycl_interop_memory_set_buffer(
|
| 192 |
+
amemory.get(), static_cast<void *>(&buf_u8)),
|
| 193 |
+
"could not set SYCL buffer object");
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
/// Returns the memory allocation kind associated with a memory object.
|
| 197 |
+
///
|
| 198 |
+
/// @param amemory A memory object.
|
| 199 |
+
///
|
| 200 |
+
/// @returns The underlying memory allocation kind of the memory object.
|
| 201 |
+
inline memory_kind get_memory_kind(const memory &amemory) {
|
| 202 |
+
dnnl_sycl_interop_memory_kind_t ckind;
|
| 203 |
+
error::wrap_c_api(
|
| 204 |
+
dnnl_sycl_interop_memory_get_memory_kind(amemory.get(), &ckind),
|
| 205 |
+
"could not get memory kind");
|
| 206 |
+
return static_cast<memory_kind>(ckind);
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
#ifdef DNNL_EXPERIMENTAL_SPARSE
|
| 210 |
+
/// Creates a memory object with multiple handles.
|
| 211 |
+
///
|
| 212 |
+
/// @param memory_desc Memory descriptor.
|
| 213 |
+
/// @param aengine Engine to use.
|
| 214 |
+
/// @param kind Memory allocation kind to specify the type of handles.
|
| 215 |
+
/// @param handles Handles of the memory buffers to use as underlying storages.
|
| 216 |
+
/// For each element of the @p handles array the following applies:
|
| 217 |
+
/// - A USM pointer to the user-allocated buffer. In this case the library
|
| 218 |
+
/// doesn't own the buffer. Requires @p memory_kind to be equal to
|
| 219 |
+
/// dnnl::sycl_interop::memory_kind::usm.
|
| 220 |
+
/// - A pointer to SYCL buffer. In this case the library doesn't own the
|
| 221 |
+
/// buffer. Requires @p memory_kind be equal to be equal to
|
| 222 |
+
/// dnnl::sycl_interop::memory_kind::buffer.
|
| 223 |
+
/// - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to
|
| 224 |
+
/// allocate the buffer that corresponds to the memory allocation kind
|
| 225 |
+
/// @p memory_kind for the memory object. In this case the library
|
| 226 |
+
/// owns the buffer.
|
| 227 |
+
/// - The DNNL_MEMORY_NONE specific value. Instructs the library to
|
| 228 |
+
/// create memory object without an underlying buffer.
|
| 229 |
+
///
|
| 230 |
+
/// If the @p handles vector is not provided the library will allocate all
|
| 231 |
+
/// buffers as if all handles have the special value DNNL_MEMORY_ALLOCATE.
|
| 232 |
+
///
|
| 233 |
+
/// @returns Created memory object.
|
| 234 |
+
inline memory make_memory(const memory::desc &memory_desc,
|
| 235 |
+
const engine &aengine, memory_kind kind,
|
| 236 |
+
std::vector<void *> handles = {}) {
|
| 237 |
+
if (handles.empty()) {
|
| 238 |
+
const int nhandles = memory_desc.get_num_handles();
|
| 239 |
+
handles.resize(nhandles, DNNL_MEMORY_ALLOCATE);
|
| 240 |
+
}
|
| 241 |
+
|
| 242 |
+
dnnl_memory_t c_memory;
|
| 243 |
+
error::wrap_c_api(
|
| 244 |
+
dnnl_sycl_interop_memory_create_v2(&c_memory, memory_desc.get(),
|
| 245 |
+
aengine.get(), convert_to_c(kind), (int)handles.size(),
|
| 246 |
+
handles.data()),
|
| 247 |
+
"could not create a memory");
|
| 248 |
+
return memory(c_memory);
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
/// Creates a memory object.
|
| 252 |
+
///
|
| 253 |
+
/// Unless @p handle is equal to DNNL_MEMORY_NONE or DNNL_MEMORY_ALLOCATE, the
|
| 254 |
+
/// constructed memory object will have the underlying buffer set. In this
|
| 255 |
+
/// case, the buffer will be initialized as if:
|
| 256 |
+
/// - dnnl::memory::set_data_handle() had been called, if @p memory_kind is
|
| 257 |
+
/// equal to dnnl::sycl_interop::memory_kind::usm, or
|
| 258 |
+
/// - dnnl::sycl_interop::set_buffer() has been called, if @p memory_kind is
|
| 259 |
+
/// equal to dnnl::sycl_interop::memory_kind::buffer.
|
| 260 |
+
///
|
| 261 |
+
/// @param memory_desc Memory descriptor.
|
| 262 |
+
/// @param aengine Engine to use.
|
| 263 |
+
/// @param kind Memory allocation kind to specify the type of handle.
|
| 264 |
+
/// @param handle Handle of the memory buffer to use as an underlying storage.
|
| 265 |
+
/// - A USM pointer to the user-allocated buffer. In this case the library
|
| 266 |
+
/// doesn't own the buffer. Requires @p memory_kind to be equal to
|
| 267 |
+
/// dnnl::sycl_interop::memory_kind::usm.
|
| 268 |
+
/// - A pointer to SYCL buffer. In this case the library doesn't own the
|
| 269 |
+
/// buffer. Requires @p memory_kind be equal to be equal to
|
| 270 |
+
/// dnnl::sycl_interop::memory_kind::buffer.
|
| 271 |
+
/// - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to
|
| 272 |
+
/// allocate the buffer that corresponds to the memory allocation kind
|
| 273 |
+
/// @p memory_kind for the memory object. In this case the library
|
| 274 |
+
/// owns the buffer.
|
| 275 |
+
/// - The DNNL_MEMORY_NONE specific value. Instructs the library to
|
| 276 |
+
/// create memory object without an underlying buffer.
|
| 277 |
+
///
|
| 278 |
+
/// @returns Created memory object.
|
| 279 |
+
inline memory make_memory(const memory::desc &memory_desc,
|
| 280 |
+
const engine &aengine, memory_kind kind, void *handle) {
|
| 281 |
+
return make_memory(
|
| 282 |
+
memory_desc, aengine, kind, std::vector<void *> {handle});
|
| 283 |
+
}
|
| 284 |
+
#else
|
| 285 |
+
|
| 286 |
+
/// Creates a memory object.
|
| 287 |
+
///
|
| 288 |
+
/// Unless @p handle is equal to DNNL_MEMORY_NONE or DNNL_MEMORY_ALLOCATE, the
|
| 289 |
+
/// constructed memory object will have the underlying buffer set. In this
|
| 290 |
+
/// case, the buffer will be initialized as if:
|
| 291 |
+
/// - dnnl::memory::set_data_handle() had been called, if @p memory_kind is
|
| 292 |
+
/// equal to dnnl::sycl_interop::memory_kind::usm, or
|
| 293 |
+
/// - dnnl::sycl_interop::set_buffer() has been called, if @p memory_kind is
|
| 294 |
+
/// equal to dnnl::sycl_interop::memory_kind::buffer.
|
| 295 |
+
///
|
| 296 |
+
/// @param memory_desc Memory descriptor.
|
| 297 |
+
/// @param aengine Engine to use.
|
| 298 |
+
/// @param kind Memory allocation kind to specify the type of handle.
|
| 299 |
+
/// @param handle Handle of the memory buffer to use as an underlying storage.
|
| 300 |
+
/// - A USM pointer to the user-allocated buffer. In this case the library
|
| 301 |
+
/// doesn't own the buffer. Requires @p memory_kind to be equal to
|
| 302 |
+
/// dnnl::sycl_interop::memory_kind::usm.
|
| 303 |
+
/// - A pointer to SYCL buffer. In this case the library doesn't own the
|
| 304 |
+
/// buffer. Requires @p memory_kind be equal to be equal to
|
| 305 |
+
/// dnnl::sycl_interop::memory_kind::buffer.
|
| 306 |
+
/// - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to
|
| 307 |
+
/// allocate the buffer that corresponds to the memory allocation kind
|
| 308 |
+
/// @p memory_kind for the memory object. In this case the library
|
| 309 |
+
/// owns the buffer.
|
| 310 |
+
/// - The DNNL_MEMORY_NONE specific value. Instructs the library to
|
| 311 |
+
/// create memory object without an underlying buffer.
|
| 312 |
+
///
|
| 313 |
+
/// @returns Created memory object.
|
| 314 |
+
inline memory make_memory(const memory::desc &memory_desc,
|
| 315 |
+
const engine &aengine, memory_kind kind,
|
| 316 |
+
void *handle = DNNL_MEMORY_ALLOCATE) {
|
| 317 |
+
dnnl_memory_t c_memory;
|
| 318 |
+
error::wrap_c_api(
|
| 319 |
+
dnnl_sycl_interop_memory_create(&c_memory, memory_desc.get(),
|
| 320 |
+
aengine.get(), convert_to_c(kind), handle),
|
| 321 |
+
"could not create a memory");
|
| 322 |
+
return memory(c_memory);
|
| 323 |
+
}
|
| 324 |
+
#endif
|
| 325 |
+
|
| 326 |
+
/// Constructs a memory object from a SYCL buffer.
|
| 327 |
+
///
|
| 328 |
+
/// @param memory_desc Memory descriptor.
|
| 329 |
+
/// @param aengine Engine to use.
|
| 330 |
+
/// @param abuffer A SYCL buffer to use.
|
| 331 |
+
///
|
| 332 |
+
/// @returns Created memory object.
|
| 333 |
+
template <typename T, int ndims = 1>
|
| 334 |
+
memory make_memory(const memory::desc &memory_desc, const engine &aengine,
|
| 335 |
+
sycl::buffer<T, ndims> &abuffer) {
|
| 336 |
+
memory amemory(memory_desc, aengine, DNNL_MEMORY_NONE);
|
| 337 |
+
set_buffer(amemory, abuffer);
|
| 338 |
+
return amemory;
|
| 339 |
+
}
|
| 340 |
+
|
| 341 |
+
/// Executes computations specified by the primitive in a specified stream and
|
| 342 |
+
/// returns a SYCL event.
|
| 343 |
+
///
|
| 344 |
+
/// Arguments are passed via an arguments map containing
|
| 345 |
+
/// <index, memory object> pairs. The index must be one of the `DNNL_ARG_*`
|
| 346 |
+
/// values such as `DNNL_ARG_SRC`, and the memory must have a memory descriptor
|
| 347 |
+
/// matching the one returned by
|
| 348 |
+
/// #dnnl::primitive_desc::query_md(#query::exec_arg_md, index) unless using
|
| 349 |
+
/// dynamic shapes (see #DNNL_RUNTIME_DIM_VAL).
|
| 350 |
+
///
|
| 351 |
+
/// @param aprimitive Primitive to execute.
|
| 352 |
+
/// @param astream Stream object. The stream must belong to the same engine
|
| 353 |
+
/// as the primitive.
|
| 354 |
+
/// @param args Arguments map.
|
| 355 |
+
/// @param deps Optional vector with `sycl::event` dependencies.
|
| 356 |
+
///
|
| 357 |
+
/// @returns Output event.
|
| 358 |
+
inline sycl::event execute(const dnnl::primitive &aprimitive,
|
| 359 |
+
const stream &astream, const std::unordered_map<int, memory> &args,
|
| 360 |
+
const std::vector<sycl::event> &deps = {}) {
|
| 361 |
+
std::vector<dnnl_exec_arg_t> c_args;
|
| 362 |
+
c_args.reserve(args.size());
|
| 363 |
+
for (const auto &a : args)
|
| 364 |
+
c_args.push_back({a.first, a.second.get()});
|
| 365 |
+
|
| 366 |
+
sycl::event return_event;
|
| 367 |
+
error::wrap_c_api(
|
| 368 |
+
dnnl_sycl_interop_primitive_execute(aprimitive.get(), astream.get(),
|
| 369 |
+
(int)c_args.size(), c_args.data(), &deps, &return_event),
|
| 370 |
+
"could not execute a primitive");
|
| 371 |
+
return return_event;
|
| 372 |
+
}
|
| 373 |
+
|
| 374 |
+
} // namespace sycl_interop
|
| 375 |
+
|
| 376 |
+
/// @} dnnl_api_sycl_interop
|
| 377 |
+
|
| 378 |
+
/// @} dnnl_api_interop
|
| 379 |
+
|
| 380 |
+
} // namespace dnnl
|
| 381 |
+
|
| 382 |
+
/// @} dnnl_api
|
| 383 |
+
|
| 384 |
+
#endif // DNNL_SYCL_HPP
|
phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_sycl_types.h
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*******************************************************************************
|
| 2 |
+
* Copyright 2020-2021 Intel Corporation
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*******************************************************************************/
|
| 16 |
+
|
| 17 |
+
#ifndef ONEAPI_DNNL_DNNL_SYCL_TYPES_H
|
| 18 |
+
#define ONEAPI_DNNL_DNNL_SYCL_TYPES_H
|
| 19 |
+
|
| 20 |
+
#ifdef __cplusplus
|
| 21 |
+
extern "C" {
|
| 22 |
+
#endif
|
| 23 |
+
|
| 24 |
+
/// @addtogroup dnnl_api
|
| 25 |
+
/// @{
|
| 26 |
+
|
| 27 |
+
/// @addtogroup dnnl_api_interop
|
| 28 |
+
/// @{
|
| 29 |
+
|
| 30 |
+
/// @addtogroup dnnl_api_sycl_interop
|
| 31 |
+
/// @{
|
| 32 |
+
|
| 33 |
+
/// Memory allocation kind.
|
| 34 |
+
typedef enum {
|
| 35 |
+
/// USM (device, shared, host, or unknown) memory allocation kind - default.
|
| 36 |
+
dnnl_sycl_interop_usm,
|
| 37 |
+
/// Buffer memory allocation kind.
|
| 38 |
+
dnnl_sycl_interop_buffer,
|
| 39 |
+
} dnnl_sycl_interop_memory_kind_t;
|
| 40 |
+
|
| 41 |
+
/// @} dnnl_api_sycl_interop
|
| 42 |
+
|
| 43 |
+
/// @} dnnl_api_interop
|
| 44 |
+
|
| 45 |
+
/// @} dnnl_api
|
| 46 |
+
|
| 47 |
+
#ifdef __cplusplus
|
| 48 |
+
}
|
| 49 |
+
#endif
|
| 50 |
+
|
| 51 |
+
#endif
|
phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_threadpool.h
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*******************************************************************************
|
| 2 |
+
* Copyright 2020-2022 Intel Corporation
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*******************************************************************************/
|
| 16 |
+
|
| 17 |
+
#ifndef ONEAPI_DNNL_DNNL_THREADPOOL_H
|
| 18 |
+
#define ONEAPI_DNNL_DNNL_THREADPOOL_H
|
| 19 |
+
|
| 20 |
+
#include "oneapi/dnnl/dnnl_config.h"
|
| 21 |
+
#include "oneapi/dnnl/dnnl_types.h"
|
| 22 |
+
|
| 23 |
+
#ifdef __cplusplus
|
| 24 |
+
extern "C" {
|
| 25 |
+
#endif
|
| 26 |
+
|
| 27 |
+
/// @addtogroup dnnl_api
|
| 28 |
+
/// @{
|
| 29 |
+
|
| 30 |
+
/// @addtogroup dnnl_api_interop
|
| 31 |
+
/// @{
|
| 32 |
+
|
| 33 |
+
/// @addtogroup dnnl_api_threadpool_interop
|
| 34 |
+
/// @{
|
| 35 |
+
|
| 36 |
+
/// Creates an execution stream with specified threadpool.
|
| 37 |
+
///
|
| 38 |
+
/// @sa @ref dev_guide_threadpool
|
| 39 |
+
///
|
| 40 |
+
/// @param stream Output execution stream.
|
| 41 |
+
/// @param engine Engine to create the execution stream on.
|
| 42 |
+
/// @param threadpool Pointer to an instance of a C++ class that implements
|
| 43 |
+
/// dnnl::threapdool_iface interface.
|
| 44 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 45 |
+
/// otherwise.
|
| 46 |
+
dnnl_status_t DNNL_API dnnl_threadpool_interop_stream_create(
|
| 47 |
+
dnnl_stream_t *stream, dnnl_engine_t engine, void *threadpool);
|
| 48 |
+
|
| 49 |
+
/// Returns a threadpool to be used by the execution stream.
|
| 50 |
+
///
|
| 51 |
+
/// @sa @ref dev_guide_threadpool
|
| 52 |
+
///
|
| 53 |
+
/// @param astream Execution stream.
|
| 54 |
+
/// @param threadpool Output pointer to an instance of a C++ class that
|
| 55 |
+
/// implements dnnl::threapdool_iface interface. Set to NULL if the
|
| 56 |
+
/// stream was created without threadpool.
|
| 57 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 58 |
+
/// otherwise.
|
| 59 |
+
dnnl_status_t DNNL_API dnnl_threadpool_interop_stream_get_threadpool(
|
| 60 |
+
dnnl_stream_t astream, void **threadpool);
|
| 61 |
+
|
| 62 |
+
/// Sets the maximum concurrency assumed by oneDNN when outside a
|
| 63 |
+
/// parallel call.
|
| 64 |
+
///
|
| 65 |
+
/// @param max_concurrency The maximum concurrency assumed by oneDNN
|
| 66 |
+
/// when outside a parallel call. This is a threadlocal setting.
|
| 67 |
+
/// @returns #dnnl_success on success and a status describing the
|
| 68 |
+
/// error otherwise.
|
| 69 |
+
dnnl_status_t DNNL_API dnnl_threadpool_interop_set_max_concurrency(
|
| 70 |
+
int max_concurrency);
|
| 71 |
+
|
| 72 |
+
/// Gets the maximum concurrency assumed by oneDNN when outside a
|
| 73 |
+
/// parallel call.
|
| 74 |
+
///
|
| 75 |
+
/// @param max_concurrency The maximum concurrency assumed by oneDNN
|
| 76 |
+
/// when outside a parallel call. This is a threadlocal setting.
|
| 77 |
+
/// @returns #dnnl_success on success and a status describing the
|
| 78 |
+
/// error otherwise.
|
| 79 |
+
dnnl_status_t DNNL_API dnnl_threadpool_interop_get_max_concurrency(
|
| 80 |
+
int *max_concurrency);
|
| 81 |
+
|
| 82 |
+
/// @copydoc dnnl_sgemm()
|
| 83 |
+
/// @param threadpool A pointer to a threadpool interface (only when built with
|
| 84 |
+
/// the THREADPOOL CPU runtime).
|
| 85 |
+
dnnl_status_t DNNL_API dnnl_threadpool_interop_sgemm(char transa, char transb,
|
| 86 |
+
dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const float *A,
|
| 87 |
+
dnnl_dim_t lda, const float *B, dnnl_dim_t ldb, float beta, float *C,
|
| 88 |
+
dnnl_dim_t ldc, void *threadpool);
|
| 89 |
+
|
| 90 |
+
/// @copydoc dnnl_gemm_u8s8s32()
|
| 91 |
+
/// @param threadpool A pointer to a threadpool interface (only when built with
|
| 92 |
+
/// the THREADPOOL CPU runtime).
|
| 93 |
+
dnnl_status_t DNNL_API dnnl_threadpool_interop_gemm_u8s8s32(char transa,
|
| 94 |
+
char transb, char offsetc, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K,
|
| 95 |
+
float alpha, const uint8_t *A, dnnl_dim_t lda, uint8_t ao,
|
| 96 |
+
const int8_t *B, dnnl_dim_t ldb, int8_t bo, float beta, int32_t *C,
|
| 97 |
+
dnnl_dim_t ldc, const int32_t *co, void *threadpool);
|
| 98 |
+
|
| 99 |
+
/// @copydoc dnnl_gemm_s8s8s32()
|
| 100 |
+
/// @param threadpool A pointer to a threadpool interface (only when built with
|
| 101 |
+
/// the THREADPOOL CPU runtime).
|
| 102 |
+
dnnl_status_t DNNL_API dnnl_threadpool_interop_gemm_s8s8s32(char transa,
|
| 103 |
+
char transb, char offsetc, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K,
|
| 104 |
+
float alpha, const int8_t *A, dnnl_dim_t lda, int8_t ao,
|
| 105 |
+
const int8_t *B, dnnl_dim_t ldb, int8_t bo, float beta, int32_t *C,
|
| 106 |
+
dnnl_dim_t ldc, const int32_t *co, void *threadpool);
|
| 107 |
+
|
| 108 |
+
/// @} dnnl_api_threadpool_interop
|
| 109 |
+
|
| 110 |
+
/// @} dnnl_api_interop
|
| 111 |
+
|
| 112 |
+
/// @} dnnl_api
|
| 113 |
+
|
| 114 |
+
#ifdef __cplusplus
|
| 115 |
+
}
|
| 116 |
+
#endif
|
| 117 |
+
|
| 118 |
+
#endif
|
phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_threadpool.hpp
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*******************************************************************************
|
| 2 |
+
* Copyright 2020-2025 Intel Corporation
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*******************************************************************************/
|
| 16 |
+
|
| 17 |
+
#ifndef ONEAPI_DNNL_DNNL_THREADPOOL_HPP
|
| 18 |
+
#define ONEAPI_DNNL_DNNL_THREADPOOL_HPP
|
| 19 |
+
|
| 20 |
+
#include "oneapi/dnnl/dnnl.hpp"
|
| 21 |
+
#include "oneapi/dnnl/dnnl_threadpool.h"
|
| 22 |
+
|
| 23 |
+
#include "oneapi/dnnl/dnnl_threadpool_iface.hpp"
|
| 24 |
+
|
| 25 |
+
/// @addtogroup dnnl_api
|
| 26 |
+
/// @{
|
| 27 |
+
|
| 28 |
+
namespace dnnl {
|
| 29 |
+
|
| 30 |
+
/// @addtogroup dnnl_api_interop
|
| 31 |
+
/// @{
|
| 32 |
+
|
| 33 |
+
/// @addtogroup dnnl_api_threadpool_interop Threadpool interoperability API
|
| 34 |
+
/// API extensions to interact with the underlying Threadpool run-time.
|
| 35 |
+
/// @{
|
| 36 |
+
|
| 37 |
+
/// Threadpool interoperability namespace
|
| 38 |
+
namespace threadpool_interop {
|
| 39 |
+
|
| 40 |
+
/// Constructs an execution stream for the specified engine and threadpool.
|
| 41 |
+
///
|
| 42 |
+
/// @sa @ref dev_guide_threadpool
|
| 43 |
+
///
|
| 44 |
+
/// @param aengine Engine to create the stream on.
|
| 45 |
+
/// @param threadpool Pointer to an instance of a C++ class that implements
|
| 46 |
+
/// dnnl::threapdool_iface interface.
|
| 47 |
+
/// @returns An execution stream.
|
| 48 |
+
inline dnnl::stream make_stream(
|
| 49 |
+
const dnnl::engine &aengine, threadpool_iface *threadpool) {
|
| 50 |
+
dnnl_stream_t c_stream;
|
| 51 |
+
dnnl::error::wrap_c_api(dnnl_threadpool_interop_stream_create(
|
| 52 |
+
&c_stream, aengine.get(), threadpool),
|
| 53 |
+
"could not create stream");
|
| 54 |
+
return dnnl::stream(c_stream);
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
/// Returns the pointer to a threadpool that is used by an execution stream.
|
| 58 |
+
///
|
| 59 |
+
/// @sa @ref dev_guide_threadpool
|
| 60 |
+
///
|
| 61 |
+
/// @param astream An execution stream.
|
| 62 |
+
/// @returns Output pointer to an instance of a C++ class that implements
|
| 63 |
+
/// dnnl::threapdool_iface interface or NULL if the stream was created
|
| 64 |
+
/// without threadpool.
|
| 65 |
+
inline threadpool_iface *get_threadpool(const dnnl::stream &astream) {
|
| 66 |
+
void *tp;
|
| 67 |
+
dnnl::error::wrap_c_api(
|
| 68 |
+
dnnl_threadpool_interop_stream_get_threadpool(astream.get(), &tp),
|
| 69 |
+
"could not get stream threadpool");
|
| 70 |
+
return static_cast<threadpool_iface *>(tp);
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
/// @copydoc dnnl_threadpool_interop_sgemm()
|
| 74 |
+
inline status sgemm(char transa, char transb, dnnl_dim_t M, dnnl_dim_t N,
|
| 75 |
+
dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda,
|
| 76 |
+
const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc,
|
| 77 |
+
threadpool_iface *threadpool) {
|
| 78 |
+
return static_cast<status>(dnnl_threadpool_interop_sgemm(transa, transb, M,
|
| 79 |
+
N, K, alpha, A, lda, B, ldb, beta, C, ldc, threadpool));
|
| 80 |
+
}
|
| 81 |
+
/// @copydoc dnnl_threadpool_interop_gemm_u8s8s32()
|
| 82 |
+
inline status gemm_u8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
|
| 83 |
+
dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A,
|
| 84 |
+
dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
|
| 85 |
+
float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co,
|
| 86 |
+
threadpool_iface *threadpool) {
|
| 87 |
+
return static_cast<status>(dnnl_threadpool_interop_gemm_u8s8s32(transa,
|
| 88 |
+
transb, offsetc, M, N, K, alpha, A, lda, ao, B, ldb, bo, beta, C,
|
| 89 |
+
ldc, co, threadpool));
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
/// @copydoc dnnl_threadpool_interop_gemm_s8s8s32()
|
| 93 |
+
inline status gemm_s8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
|
| 94 |
+
dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A,
|
| 95 |
+
dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
|
| 96 |
+
float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co,
|
| 97 |
+
threadpool_iface *threadpool) {
|
| 98 |
+
return static_cast<status>(dnnl_threadpool_interop_gemm_s8s8s32(transa,
|
| 99 |
+
transb, offsetc, M, N, K, alpha, A, lda, ao, B, ldb, bo, beta, C,
|
| 100 |
+
ldc, co, threadpool));
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
} // namespace threadpool_interop
|
| 104 |
+
|
| 105 |
+
/// @} dnnl_api_threadpool_interop
|
| 106 |
+
|
| 107 |
+
/// @} dnnl_api_interop
|
| 108 |
+
|
| 109 |
+
} // namespace dnnl
|
| 110 |
+
|
| 111 |
+
/// @} dnnl_api
|
| 112 |
+
|
| 113 |
+
#endif
|
phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_threadpool_iface.hpp
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*******************************************************************************
|
| 2 |
+
* Copyright 2020-2024 Intel Corporation
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*******************************************************************************/
|
| 16 |
+
|
| 17 |
+
#ifndef ONEAPI_DNNL_DNNL_THREADPOOL_IFACE_HPP
|
| 18 |
+
#define ONEAPI_DNNL_DNNL_THREADPOOL_IFACE_HPP
|
| 19 |
+
|
| 20 |
+
#include <cstdint>
|
| 21 |
+
#include <functional>
|
| 22 |
+
|
| 23 |
+
/// @addtogroup dnnl_api
|
| 24 |
+
/// @{
|
| 25 |
+
|
| 26 |
+
namespace dnnl {
|
| 27 |
+
|
| 28 |
+
/// @addtogroup dnnl_api_interop
|
| 29 |
+
/// @{
|
| 30 |
+
|
| 31 |
+
/// @addtogroup dnnl_api_threadpool_interop
|
| 32 |
+
/// @{
|
| 33 |
+
|
| 34 |
+
namespace threadpool_interop {
|
| 35 |
+
|
| 36 |
+
/// Abstract threadpool interface. The users are expected to subclass this
|
| 37 |
+
/// interface and pass an object to the library during CPU stream creation or
|
| 38 |
+
/// directly in case of BLAS functions.
|
| 39 |
+
struct threadpool_iface {
|
| 40 |
+
/// Returns the number of worker threads.
|
| 41 |
+
virtual int get_num_threads() const = 0;
|
| 42 |
+
|
| 43 |
+
/// Returns true if the calling thread belongs to this threadpool.
|
| 44 |
+
virtual bool get_in_parallel() const = 0;
|
| 45 |
+
|
| 46 |
+
/// Submits n instances of a closure for execution in parallel:
|
| 47 |
+
///
|
| 48 |
+
/// for (int i = 0; i < n; i++) fn(i, n);
|
| 49 |
+
///
|
| 50 |
+
virtual void parallel_for(int n, const std::function<void(int, int)> &fn)
|
| 51 |
+
= 0;
|
| 52 |
+
|
| 53 |
+
/// Returns threadpool behavior flags bit mask (see below).
|
| 54 |
+
virtual uint64_t get_flags() const = 0;
|
| 55 |
+
|
| 56 |
+
/// If set, parallel_for() returns immediately and oneDNN needs implement
|
| 57 |
+
/// waiting for the submitted closures to finish execution on its own.
|
| 58 |
+
static constexpr uint64_t ASYNCHRONOUS = 1;
|
| 59 |
+
|
| 60 |
+
virtual ~threadpool_iface() {}
|
| 61 |
+
};
|
| 62 |
+
|
| 63 |
+
} // namespace threadpool_interop
|
| 64 |
+
|
| 65 |
+
/// @} dnnl_api_threadpool_interop
|
| 66 |
+
|
| 67 |
+
/// @} dnnl_api_interop
|
| 68 |
+
|
| 69 |
+
} // namespace dnnl
|
| 70 |
+
|
| 71 |
+
/// @} dnnl_api
|
| 72 |
+
|
| 73 |
+
#endif
|
phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_types.h
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_ukernel.h
ADDED
|
@@ -0,0 +1,337 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*******************************************************************************
|
| 2 |
+
* Copyright 2024 Intel Corporation
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*******************************************************************************/
|
| 16 |
+
|
| 17 |
+
/// @file
|
| 18 |
+
/// ukernel C API
|
| 19 |
+
|
| 20 |
+
#ifndef ONEAPI_DNNL_DNNL_UKERNEL_H
|
| 21 |
+
#define ONEAPI_DNNL_DNNL_UKERNEL_H
|
| 22 |
+
|
| 23 |
+
#include "oneapi/dnnl/dnnl.h"
|
| 24 |
+
#include "oneapi/dnnl/dnnl_ukernel_types.h"
|
| 25 |
+
|
| 26 |
+
#ifdef __cplusplus
|
| 27 |
+
extern "C" {
|
| 28 |
+
#endif
|
| 29 |
+
|
| 30 |
+
/// @addtogroup dnnl_api
|
| 31 |
+
/// @{
|
| 32 |
+
|
| 33 |
+
/// @addtogroup dnnl_api_ukernel
|
| 34 |
+
/// @{
|
| 35 |
+
|
| 36 |
+
#ifdef DNNL_EXPERIMENTAL_UKERNEL
|
| 37 |
+
|
| 38 |
+
/// Creates a ukernel attributes memory storage.
|
| 39 |
+
///
|
| 40 |
+
/// @param attr_params Output ukernel attributes memory storage.
|
| 41 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 42 |
+
/// otherwise.
|
| 43 |
+
dnnl_status_t DNNL_API dnnl_ukernel_attr_params_create(
|
| 44 |
+
dnnl_ukernel_attr_params_t *attr_params);
|
| 45 |
+
|
| 46 |
+
/// Sets post-operations arguments to a storage.
|
| 47 |
+
///
|
| 48 |
+
/// @param attr_params Memory pointers storage object.
|
| 49 |
+
/// @param post_ops_args A pointer to pointers of post_ops storages. Expected to
|
| 50 |
+
/// be packed together.
|
| 51 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 52 |
+
/// otherwise.
|
| 53 |
+
dnnl_status_t DNNL_API dnnl_ukernel_attr_params_set_post_ops_args(
|
| 54 |
+
dnnl_ukernel_attr_params_t attr_params, const void **post_ops_args);
|
| 55 |
+
|
| 56 |
+
/// Sets tensor A scales argument to a storage.
|
| 57 |
+
///
|
| 58 |
+
/// @param attr_params Memory pointers storage object.
|
| 59 |
+
/// @param a_scales Pointer to the scales storage.
|
| 60 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 61 |
+
/// otherwise.
|
| 62 |
+
dnnl_status_t DNNL_API dnnl_ukernel_attr_params_set_A_scales(
|
| 63 |
+
dnnl_ukernel_attr_params_t attr_params, const void *a_scales);
|
| 64 |
+
|
| 65 |
+
/// Sets tensor B scales argument to a storage.
|
| 66 |
+
///
|
| 67 |
+
/// If `dnnl_brgemm_set_B_scales` used mask of 2, then at least N values of
|
| 68 |
+
/// selected data type are expected.
|
| 69 |
+
///
|
| 70 |
+
/// @param attr_params Memory pointers storage object.
|
| 71 |
+
/// @param b_scales Pointer to the scales storage.
|
| 72 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 73 |
+
/// otherwise.
|
| 74 |
+
dnnl_status_t DNNL_API dnnl_ukernel_attr_params_set_B_scales(
|
| 75 |
+
dnnl_ukernel_attr_params_t attr_params, const void *b_scales);
|
| 76 |
+
|
| 77 |
+
/// Sets tensor D scales argument to a storage.
|
| 78 |
+
///
|
| 79 |
+
/// @param attr_params Memory pointers storage object.
|
| 80 |
+
/// @param d_scales Pointer to the scales storage.
|
| 81 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 82 |
+
/// otherwise.
|
| 83 |
+
dnnl_status_t DNNL_API dnnl_ukernel_attr_params_set_D_scales(
|
| 84 |
+
dnnl_ukernel_attr_params_t attr_params, const void *d_scales);
|
| 85 |
+
|
| 86 |
+
/// Destroys a ukernel attributes memory storage.
|
| 87 |
+
///
|
| 88 |
+
/// @param attr_params Memory pointers storage object to destroy.
|
| 89 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 90 |
+
/// otherwise.
|
| 91 |
+
dnnl_status_t DNNL_API dnnl_ukernel_attr_params_destroy(
|
| 92 |
+
dnnl_ukernel_attr_params_t attr_params);
|
| 93 |
+
|
| 94 |
+
/// @addtogroup dnnl_api_ukernel_brgemm
|
| 95 |
+
/// @{
|
| 96 |
+
|
| 97 |
+
/// Creates a BRGeMM ukernel object. Operates by the following formula:
|
| 98 |
+
/// `C = [A x B]`.
|
| 99 |
+
///
|
| 100 |
+
/// @param brgemm Output BRGeMM ukernel object.
|
| 101 |
+
/// @param M Dimension M of tensor A.
|
| 102 |
+
/// @param N Dimension N of tensor B.
|
| 103 |
+
/// @param K Dimension K of tensors A and B.
|
| 104 |
+
/// @param batch_size Number of batches to process.
|
| 105 |
+
/// @param lda Leading dimension of tensor A.
|
| 106 |
+
/// @param ldb Leading dimension of tensor B.
|
| 107 |
+
/// @param ldc Leading dimension of tensor C.
|
| 108 |
+
/// @param a_dt Data type of tensor A.
|
| 109 |
+
/// @param b_dt Data type of tensor B.
|
| 110 |
+
/// @param c_dt Data type of tensor C. Must be dnnl_f32.
|
| 111 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 112 |
+
/// otherwise.
|
| 113 |
+
dnnl_status_t DNNL_API dnnl_brgemm_create(dnnl_brgemm_t *brgemm, dnnl_dim_t M,
|
| 114 |
+
dnnl_dim_t N, dnnl_dim_t K, dnnl_dim_t batch_size, dnnl_dim_t lda,
|
| 115 |
+
dnnl_dim_t ldb, dnnl_dim_t ldc, dnnl_data_type_t a_dt,
|
| 116 |
+
dnnl_data_type_t b_dt, dnnl_data_type_t c_dt);
|
| 117 |
+
|
| 118 |
+
/// Sets adding an intermediate result to the output tensor C instead of
|
| 119 |
+
/// writing: `C += [A x B]`.
|
| 120 |
+
///
|
| 121 |
+
/// @param brgemm BRGeMM ukernel object.
|
| 122 |
+
/// @param add_C Value to indicate addition. Can be `0` to skip addition, and
|
| 123 |
+
/// `1` to apply addition.
|
| 124 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 125 |
+
/// otherwise.
|
| 126 |
+
dnnl_status_t DNNL_API dnnl_brgemm_set_add_C(dnnl_brgemm_t brgemm, int add_C);
|
| 127 |
+
|
| 128 |
+
/// Sets post-operations to a BRGeMM ukernel object: `D = post-operations(C)`.
|
| 129 |
+
///
|
| 130 |
+
/// Post-operations applies if one of the following holds:
|
| 131 |
+
/// * Non-empty attributes are specified.
|
| 132 |
+
/// * Output data type `d_dt` is different from accumulation data type `c_dt`.
|
| 133 |
+
///
|
| 134 |
+
/// If any of conditions happens, the final call of the accumulation chain
|
| 135 |
+
/// must be `dnnl_brgemm_execute_postops`, and `dnnl_brgemm_execute`, otherwise.
|
| 136 |
+
///
|
| 137 |
+
/// @param brgemm BRGeMM ukernel object.
|
| 138 |
+
/// @param ldd Leading dimension of tensor D.
|
| 139 |
+
/// @param d_dt Data type of tensor D.
|
| 140 |
+
/// @param post_ops Primitive post operations attribute to extend the kernel
|
| 141 |
+
/// operations.
|
| 142 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 143 |
+
/// otherwise.
|
| 144 |
+
dnnl_status_t DNNL_API dnnl_brgemm_set_post_ops(dnnl_brgemm_t brgemm,
|
| 145 |
+
dnnl_dim_t ldd, dnnl_data_type_t d_dt, const_dnnl_post_ops_t post_ops);
|
| 146 |
+
|
| 147 |
+
/// Sets tensor A scales mask to a BRGeMM ukernel object.
|
| 148 |
+
///
|
| 149 |
+
/// For quantization flavor tensor A scales apply to accumulation buffer once C
|
| 150 |
+
/// is ready.
|
| 151 |
+
///
|
| 152 |
+
/// @param brgemm BRGeMM ukernel object.
|
| 153 |
+
/// @param a_scale_mask Tensor A scale mask. Can be `0` only.
|
| 154 |
+
dnnl_status_t DNNL_API dnnl_brgemm_set_A_scales(
|
| 155 |
+
dnnl_brgemm_t brgemm, int a_scale_mask);
|
| 156 |
+
|
| 157 |
+
/// Sets tensor B scales mask to a BRGeMM ukernel object.
|
| 158 |
+
///
|
| 159 |
+
/// For quantization flavor tensor B scales apply to accumulation buffer once C
|
| 160 |
+
/// is ready.
|
| 161 |
+
///
|
| 162 |
+
/// @param brgemm BRGeMM ukernel object.
|
| 163 |
+
/// @param b_scale_mask Tensor B scale mask. Can be `0` and `2` only.
|
| 164 |
+
dnnl_status_t DNNL_API dnnl_brgemm_set_B_scales(
|
| 165 |
+
dnnl_brgemm_t brgemm, int b_scale_mask);
|
| 166 |
+
|
| 167 |
+
/// Sets tensor D scales mask to a BRGeMM ukernel object.
|
| 168 |
+
///
|
| 169 |
+
/// For quantization flavor tensor D scales apply after all post-ops are
|
| 170 |
+
/// applied.
|
| 171 |
+
///
|
| 172 |
+
/// @param brgemm BRGeMM ukernel object.
|
| 173 |
+
/// @param d_scale_mask Tensor D scale mask. Can be `0` only.
|
| 174 |
+
dnnl_status_t DNNL_API dnnl_brgemm_set_D_scales(
|
| 175 |
+
dnnl_brgemm_t brgemm, int d_scale_mask);
|
| 176 |
+
|
| 177 |
+
/// Finalizes initialization of a BRGeMM ukernel object.
|
| 178 |
+
///
|
| 179 |
+
/// This step is mandatory to query information from the object.
|
| 180 |
+
///
|
| 181 |
+
/// @param brgemm Output BRGeMM ukernel object.
|
| 182 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 183 |
+
/// otherwise.
|
| 184 |
+
dnnl_status_t DNNL_API dnnl_brgemm_finalize(dnnl_brgemm_t brgemm);
|
| 185 |
+
|
| 186 |
+
/// Returns the packing type expected by a tensor B of a BRGeMM ukernel object.
|
| 187 |
+
///
|
| 188 |
+
/// @param brgemm BRGeMM ukernel object.
|
| 189 |
+
/// @param pack_type Output packing type. Can be `dnnl_brgemm_no_pack` if
|
| 190 |
+
/// packing is not expected, and `dnnl_brgemm_pack_32`, otherwise.
|
| 191 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 192 |
+
/// otherwise.
|
| 193 |
+
dnnl_status_t DNNL_API dnnl_brgemm_get_B_pack_type(
|
| 194 |
+
const_dnnl_brgemm_t brgemm, dnnl_pack_type_t *pack_type);
|
| 195 |
+
|
| 196 |
+
/// Returns the size of a scratchpad memory needed for the BRGeMM ukernel
|
| 197 |
+
/// object.
|
| 198 |
+
///
|
| 199 |
+
/// @param brgemm BRGeMM ukernel object.
|
| 200 |
+
/// @param size Output size of a buffer required for the BRGeMM ukernel object.
|
| 201 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 202 |
+
/// otherwise.
|
| 203 |
+
dnnl_status_t DNNL_API dnnl_brgemm_get_scratchpad_size(
|
| 204 |
+
const_dnnl_brgemm_t brgemm, size_t *size);
|
| 205 |
+
|
| 206 |
+
/// Returns the flag indicating when the call to `dnnl_brgemm_execute_postops`
|
| 207 |
+
/// is valid.
|
| 208 |
+
///
|
| 209 |
+
/// @param brgemm BRGeMM ukernel object.
|
| 210 |
+
/// @param valid The flag indicating if `dnnl_brgemm_execute_postops` is valid
|
| 211 |
+
/// for a given ukernel object. `1` is for valid and `0`, otherwise.
|
| 212 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 213 |
+
/// otherwise.
|
| 214 |
+
dnnl_status_t DNNL_API dnnl_brgemm_is_execute_postops_valid(
|
| 215 |
+
const_dnnl_brgemm_t brgemm, int *valid);
|
| 216 |
+
|
| 217 |
+
/// Initializes the hardware-specific context. If no initialization required,
|
| 218 |
+
/// returns the success status.
|
| 219 |
+
///
|
| 220 |
+
/// @param brgemm BRGeMM ukernel object.
|
| 221 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 222 |
+
/// otherwise.
|
| 223 |
+
dnnl_status_t DNNL_API dnnl_brgemm_set_hw_context(const_dnnl_brgemm_t brgemm);
|
| 224 |
+
|
| 225 |
+
/// Releases the hardware-specific context. Must be used after all the execution
|
| 226 |
+
/// calls to BRGeMM ukernel objects.
|
| 227 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 228 |
+
/// otherwise.
|
| 229 |
+
dnnl_status_t DNNL_API dnnl_brgemm_release_hw_context();
|
| 230 |
+
|
| 231 |
+
/// Generates an executable part of BRGeMM ukernel object.
|
| 232 |
+
/// @param brgemm BRGeMM ukernel object.
|
| 233 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 234 |
+
/// otherwise.
|
| 235 |
+
dnnl_status_t DNNL_API dnnl_brgemm_generate(dnnl_brgemm_t brgemm);
|
| 236 |
+
|
| 237 |
+
/// Executes a BRGeMM ukernel object.
|
| 238 |
+
///
|
| 239 |
+
/// @param brgemm BRGeMM ukernel object.
|
| 240 |
+
/// @param A_ptr Base pointer to a tensor A.
|
| 241 |
+
/// @param B_ptr Base pointer to a tensor B.
|
| 242 |
+
/// @param A_B_offsets Pointer to the set of tensor A and tensor B offsets for
|
| 243 |
+
/// each batch; the set must be contiguous in memory. Single batch should
|
| 244 |
+
/// supply offsets for both tensors A and B simultaneously. The number of
|
| 245 |
+
/// batches must coincide with the `batch_size` value passed at the creation
|
| 246 |
+
/// stage.
|
| 247 |
+
/// @param C_ptr Pointer to a tensor C (accumulation buffer).
|
| 248 |
+
/// @param scratchpad_ptr Pointer to a scratchpad buffer.
|
| 249 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 250 |
+
/// otherwise.
|
| 251 |
+
dnnl_status_t DNNL_API dnnl_brgemm_execute(const_dnnl_brgemm_t brgemm,
|
| 252 |
+
const void *A_ptr, const void *B_ptr, const dnnl_dim_t *A_B_offsets,
|
| 253 |
+
void *C_ptr, void *scratchpad_ptr);
|
| 254 |
+
|
| 255 |
+
/// Executes a BRGeMM ukernel object with post operations.
|
| 256 |
+
///
|
| 257 |
+
/// @param brgemm BRGeMM ukernel object.
|
| 258 |
+
/// @param A Base pointer to a tensor A.
|
| 259 |
+
/// @param B Base pointer to a tensor B.
|
| 260 |
+
/// @param A_B_offsets Pointer to a set of tensor A and tensor B offsets for
|
| 261 |
+
/// each batch. A set must be contiguous in memory. A single batch should
|
| 262 |
+
/// supply offsets for both tensors A and B simultaneously. The number of
|
| 263 |
+
/// batches must coincide with the `batch_size` value passed at the creation
|
| 264 |
+
/// stage.
|
| 265 |
+
/// @param C_ptr Pointer to a tensor C (accumulation buffer).
|
| 266 |
+
/// @param D_ptr Pointer to a tensor D (output buffer).
|
| 267 |
+
/// @param scratchpad_ptr Pointer to a scratchpad buffer.
|
| 268 |
+
/// @param attr_params Ukernel attributes memory storage.
|
| 269 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 270 |
+
/// otherwise.
|
| 271 |
+
dnnl_status_t DNNL_API dnnl_brgemm_execute_postops(const_dnnl_brgemm_t brgemm,
|
| 272 |
+
const void *A, const void *B, const dnnl_dim_t *A_B_offsets,
|
| 273 |
+
const void *C_ptr, void *D_ptr, void *scratchpad_ptr,
|
| 274 |
+
const_dnnl_ukernel_attr_params_t attr_params);
|
| 275 |
+
|
| 276 |
+
/// Destroys a BRGeMM ukernel object.
|
| 277 |
+
///
|
| 278 |
+
/// @param brgemm BRGeMM ukernel object to destroy.
|
| 279 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 280 |
+
/// otherwise.
|
| 281 |
+
dnnl_status_t DNNL_API dnnl_brgemm_destroy(dnnl_brgemm_t brgemm);
|
| 282 |
+
|
| 283 |
+
/// Creates a transform object.
|
| 284 |
+
///
|
| 285 |
+
/// @param transform Output transform object.
|
| 286 |
+
/// @param K Dimension K.
|
| 287 |
+
/// @param N Dimension N.
|
| 288 |
+
/// @param in_pack_type Input packing type. Must be one of
|
| 289 |
+
/// `dnnl_pack_type_no_trans`, or `dnnl_pack_type_trans`.
|
| 290 |
+
/// @param in_ld Input leading dimension.
|
| 291 |
+
/// @param out_ld Output leading dimension. When packing data, it specifies a
|
| 292 |
+
/// block by N dimension.
|
| 293 |
+
/// @param in_dt Input data type.
|
| 294 |
+
/// @param out_dt Output data type.
|
| 295 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 296 |
+
/// otherwise.
|
| 297 |
+
dnnl_status_t DNNL_API dnnl_transform_create(dnnl_transform_t *transform,
|
| 298 |
+
dnnl_dim_t K, dnnl_dim_t N, dnnl_pack_type_t in_pack_type,
|
| 299 |
+
dnnl_dim_t in_ld, dnnl_dim_t out_ld, dnnl_data_type_t in_dt,
|
| 300 |
+
dnnl_data_type_t out_dt);
|
| 301 |
+
|
| 302 |
+
/// Generates an executable part of transform object.
|
| 303 |
+
/// @param transform Transform object.
|
| 304 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 305 |
+
/// otherwise.
|
| 306 |
+
dnnl_status_t DNNL_API dnnl_transform_generate(dnnl_transform_t transform);
|
| 307 |
+
|
| 308 |
+
/// Executes a transform object.
|
| 309 |
+
///
|
| 310 |
+
/// @param transform Transform object.
|
| 311 |
+
/// @param in_ptr Pointer to an input buffer.
|
| 312 |
+
/// @param out_ptr Pointer to an output buffer.
|
| 313 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 314 |
+
/// otherwise.
|
| 315 |
+
dnnl_status_t DNNL_API dnnl_transform_execute(
|
| 316 |
+
const_dnnl_transform_t transform, const void *in_ptr, void *out_ptr);
|
| 317 |
+
|
| 318 |
+
/// Destroys a transform object.
|
| 319 |
+
///
|
| 320 |
+
/// @param transform Transform object.
|
| 321 |
+
/// @returns #dnnl_success on success and a status describing the error
|
| 322 |
+
/// otherwise.
|
| 323 |
+
dnnl_status_t DNNL_API dnnl_transform_destroy(dnnl_transform_t transform);
|
| 324 |
+
|
| 325 |
+
/// @} dnnl_api_ukernel_brgemm
|
| 326 |
+
|
| 327 |
+
#endif
|
| 328 |
+
|
| 329 |
+
/// @} dnnl_api_ukernel
|
| 330 |
+
|
| 331 |
+
/// @} dnnl_api
|
| 332 |
+
|
| 333 |
+
#ifdef __cplusplus
|
| 334 |
+
}
|
| 335 |
+
#endif
|
| 336 |
+
|
| 337 |
+
#endif /* ONEAPI_DNNL_DNNL_UKERNEL_H */
|
phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_ukernel.hpp
ADDED
|
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*******************************************************************************
|
| 2 |
+
* Copyright 2024-2025 Intel Corporation
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*******************************************************************************/
|
| 16 |
+
|
| 17 |
+
/// @file
|
| 18 |
+
/// ukernel C++ API
|
| 19 |
+
|
| 20 |
+
#ifndef ONEAPI_DNNL_DNNL_UKERNEL_HPP
|
| 21 |
+
#define ONEAPI_DNNL_DNNL_UKERNEL_HPP
|
| 22 |
+
|
| 23 |
+
#include "oneapi/dnnl/dnnl.hpp"
|
| 24 |
+
#include "oneapi/dnnl/dnnl_ukernel.h"
|
| 25 |
+
|
| 26 |
+
/// @addtogroup dnnl_api oneDNN API
|
| 27 |
+
/// @{
|
| 28 |
+
|
| 29 |
+
/// oneDNN namespace
|
| 30 |
+
namespace dnnl {
|
| 31 |
+
|
| 32 |
+
#ifdef DNNL_EXPERIMENTAL_UKERNEL
|
| 33 |
+
|
| 34 |
+
/// @addtogroup dnnl_api_utils
|
| 35 |
+
/// @{
|
| 36 |
+
|
| 37 |
+
/// @cond DO_NOT_DOCUMENT_THIS
|
| 38 |
+
|
| 39 |
+
template <>
|
| 40 |
+
struct handle_traits<dnnl_brgemm_t> {
|
| 41 |
+
static dnnl_status_t destructor(dnnl_brgemm_t p) {
|
| 42 |
+
return dnnl_brgemm_destroy(p);
|
| 43 |
+
}
|
| 44 |
+
};
|
| 45 |
+
|
| 46 |
+
template <>
|
| 47 |
+
struct handle_traits<dnnl_transform_t> {
|
| 48 |
+
static dnnl_status_t destructor(dnnl_transform_t p) {
|
| 49 |
+
return dnnl_transform_destroy(p);
|
| 50 |
+
}
|
| 51 |
+
};
|
| 52 |
+
|
| 53 |
+
template <>
|
| 54 |
+
struct handle_traits<dnnl_ukernel_attr_params_t> {
|
| 55 |
+
static dnnl_status_t destructor(dnnl_ukernel_attr_params_t p) {
|
| 56 |
+
return dnnl_ukernel_attr_params_destroy(p);
|
| 57 |
+
}
|
| 58 |
+
};
|
| 59 |
+
|
| 60 |
+
/// @endcond
|
| 61 |
+
|
| 62 |
+
/// @} dnnl_api_utils
|
| 63 |
+
|
| 64 |
+
#endif
|
| 65 |
+
|
| 66 |
+
/// @addtogroup dnnl_api_ukernel Ukernels
|
| 67 |
+
/// Collection of ukernels
|
| 68 |
+
/// @{
|
| 69 |
+
|
| 70 |
+
/// ukernel namespace
|
| 71 |
+
namespace ukernel {
|
| 72 |
+
|
| 73 |
+
#ifdef DNNL_EXPERIMENTAL_UKERNEL
|
| 74 |
+
|
| 75 |
+
/// @addtogroup dnnl_api_ukernel_utils ukernel utils
|
| 76 |
+
/// ukernel utility functions
|
| 77 |
+
/// @{
|
| 78 |
+
|
| 79 |
+
/// Packing specification
|
| 80 |
+
enum class pack_type {
|
| 81 |
+
/// Undefined pack type. A guard value.
|
| 82 |
+
undef = dnnl_pack_type_undef,
|
| 83 |
+
/// Plain, not transposed layout. Similar to format_tag::ab.
|
| 84 |
+
no_trans = dnnl_pack_type_no_trans,
|
| 85 |
+
/// Plain, transposed layout. Similar to format_tag::ba.
|
| 86 |
+
trans = dnnl_pack_type_trans,
|
| 87 |
+
/// Packed by 32 bits along K dimension layout.
|
| 88 |
+
pack32 = dnnl_pack_type_pack32,
|
| 89 |
+
};
|
| 90 |
+
|
| 91 |
+
/// Ukernel attributes memory storage
|
| 92 |
+
struct attr_params : public handle<dnnl_ukernel_attr_params_t> {
|
| 93 |
+
/// Constructs a ukernel attributes memory storage.
|
| 94 |
+
attr_params() {
|
| 95 |
+
dnnl_ukernel_attr_params_t c_params = nullptr;
|
| 96 |
+
dnnl_status_t status = dnnl_ukernel_attr_params_create(&c_params);
|
| 97 |
+
error::wrap_c_api(
|
| 98 |
+
status, "could not create an attributes memory storage");
|
| 99 |
+
reset(c_params);
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
/// Sets post-operations arguments to a storage.
|
| 103 |
+
///
|
| 104 |
+
/// @param post_ops_args Pointer to pointers of post_ops storages.
|
| 105 |
+
/// Expected to be packed together.
|
| 106 |
+
void set_post_ops_args(const void **post_ops_args) {
|
| 107 |
+
dnnl_status_t status = dnnl_ukernel_attr_params_set_post_ops_args(
|
| 108 |
+
get(), post_ops_args);
|
| 109 |
+
if (status != dnnl_success)
|
| 110 |
+
error::wrap_c_api(
|
| 111 |
+
status, "could not set post operations arguments");
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
/// Sets tensor A scales arguments to a storage.
|
| 115 |
+
///
|
| 116 |
+
/// @param a_scales Pointer to scales storage.
|
| 117 |
+
void set_A_scales(const void *a_scales) {
|
| 118 |
+
dnnl_status_t status
|
| 119 |
+
= dnnl_ukernel_attr_params_set_A_scales(get(), a_scales);
|
| 120 |
+
if (status != dnnl_success)
|
| 121 |
+
error::wrap_c_api(status, "could not set A scales argument");
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
/// Sets tensor B scales arguments to a storage.
|
| 125 |
+
///
|
| 126 |
+
/// If @ref attr_params::set_B_scales used mask of 2, then at
|
| 127 |
+
/// least N values of selected data type are expected.
|
| 128 |
+
///
|
| 129 |
+
/// @param b_scales Pointer to scales storage.
|
| 130 |
+
void set_B_scales(const void *b_scales) {
|
| 131 |
+
dnnl_status_t status
|
| 132 |
+
= dnnl_ukernel_attr_params_set_B_scales(get(), b_scales);
|
| 133 |
+
if (status != dnnl_success)
|
| 134 |
+
error::wrap_c_api(status, "could not set B scales argument");
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
/// Sets tensor D scales arguments to a storage.
|
| 138 |
+
///
|
| 139 |
+
/// @param d_scales Pointer to scales storage.
|
| 140 |
+
void set_D_scales(const void *d_scales) {
|
| 141 |
+
dnnl_status_t status
|
| 142 |
+
= dnnl_ukernel_attr_params_set_D_scales(get(), d_scales);
|
| 143 |
+
if (status != dnnl_success)
|
| 144 |
+
error::wrap_c_api(status, "could not set D scales argument");
|
| 145 |
+
}
|
| 146 |
+
};
|
| 147 |
+
/// @} dnnl_api_ukernel_utils
|
| 148 |
+
|
| 149 |
+
/// @addtogroup dnnl_api_ukernel_brgemm BRGeMM ukernel
|
| 150 |
+
/// BRGeMM ukernel routines
|
| 151 |
+
/// @{
|
| 152 |
+
|
| 153 |
+
/// BRGeMM ukernel
|
| 154 |
+
struct brgemm : public handle<dnnl_brgemm_t> {
|
| 155 |
+
/// Default constructor. Produces an empty object.
|
| 156 |
+
brgemm() = default;
|
| 157 |
+
|
| 158 |
+
/// Constructs a BRGeMM ukernel object. Operates by the following formula:
|
| 159 |
+
/// `C = [A x B]`.
|
| 160 |
+
///
|
| 161 |
+
/// @param M Dimension M of tensor A.
|
| 162 |
+
/// @param N Dimension N of tensor B.
|
| 163 |
+
/// @param K Dimension K of tensors A and B.
|
| 164 |
+
/// @param batch_size Number of batches to process.
|
| 165 |
+
/// @param lda Leading dimension of tensor A.
|
| 166 |
+
/// @param ldb Leading dimension of tensor B.
|
| 167 |
+
/// @param ldc Leading dimension of tensor C.
|
| 168 |
+
/// @param a_dt Data type of tensor A.
|
| 169 |
+
/// @param b_dt Data type of tensor B.
|
| 170 |
+
/// @param c_dt Data type of tensor C.
|
| 171 |
+
/// @param allow_empty A flag signifying whether construction is
|
| 172 |
+
/// allowed to fail without throwing an exception. In this case an
|
| 173 |
+
/// empty object will be produced. This flag is optional and
|
| 174 |
+
/// defaults to false.
|
| 175 |
+
brgemm(memory::dim M, memory::dim N, memory::dim K, memory::dim batch_size,
|
| 176 |
+
memory::dim lda, memory::dim ldb, memory::dim ldc,
|
| 177 |
+
memory::data_type a_dt, memory::data_type b_dt,
|
| 178 |
+
memory::data_type c_dt, bool allow_empty = false) {
|
| 179 |
+
|
| 180 |
+
dnnl_brgemm_t brgemm = nullptr;
|
| 181 |
+
dnnl_status_t status = dnnl_brgemm_create(&brgemm, M, N, K, batch_size,
|
| 182 |
+
lda, ldb, ldc, memory::convert_to_c(a_dt),
|
| 183 |
+
memory::convert_to_c(b_dt), memory::convert_to_c(c_dt));
|
| 184 |
+
|
| 185 |
+
if (!allow_empty)
|
| 186 |
+
error::wrap_c_api(
|
| 187 |
+
status, "could not create a BRGeMM ukernel object");
|
| 188 |
+
reset(brgemm);
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
/// Sets adding an intermediate result to the output tensor C instead of
|
| 192 |
+
/// writing: `C += [A x B]`.
|
| 193 |
+
///
|
| 194 |
+
/// @param add_C Value to indicate addition. `false` to skip addition, and
|
| 195 |
+
/// `true` to apply addition.
|
| 196 |
+
void set_add_C(bool add_C) {
|
| 197 |
+
dnnl_status_t status
|
| 198 |
+
= dnnl_brgemm_set_add_C(get(), static_cast<int>(add_C));
|
| 199 |
+
if (status != dnnl_success)
|
| 200 |
+
error::wrap_c_api(status, "could not set add_C attribute");
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
/// Sets post-operations to a BRGeMM ukernel object:
|
| 204 |
+
/// `D = post-operations(C)`.
|
| 205 |
+
///
|
| 206 |
+
/// Post-operations applies if one of the following holds:
|
| 207 |
+
/// * Non-empty post-operations are specified.
|
| 208 |
+
/// * Output data type `d_dt` is different from accumulation data type
|
| 209 |
+
/// `c_dt`.
|
| 210 |
+
///
|
| 211 |
+
/// @param ldd Leading dimension of tensor D.
|
| 212 |
+
/// @param d_dt Data type of tensor D.
|
| 213 |
+
/// @param po Primitive post-operation attributes to extend the kernel
|
| 214 |
+
/// operations.
|
| 215 |
+
void set_post_ops(memory::dim ldd, memory::data_type d_dt,
|
| 216 |
+
const post_ops &po = default_post_ops()) {
|
| 217 |
+
dnnl_status_t status = dnnl_brgemm_set_post_ops(
|
| 218 |
+
get(), ldd, memory::convert_to_c(d_dt), po.get());
|
| 219 |
+
if (status != dnnl_success)
|
| 220 |
+
error::wrap_c_api(status, "could not set post operations");
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
/// Sets tensor A scales mask to a BRGeMM ukernel object.
|
| 224 |
+
///
|
| 225 |
+
/// For quantization flavor tensor A scales apply to accumulation buffer
|
| 226 |
+
/// once C is ready.
|
| 227 |
+
///
|
| 228 |
+
/// @param a_scale_mask Tensor A scale mask. Can be `0` only.
|
| 229 |
+
void set_A_scales(int a_scale_mask) {
|
| 230 |
+
dnnl_status_t status = dnnl_brgemm_set_A_scales(get(), a_scale_mask);
|
| 231 |
+
if (status != dnnl_success)
|
| 232 |
+
error::wrap_c_api(status, "could not set A scales");
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
/// Sets tensor B scales mask to a BRGeMM ukernel object.
|
| 236 |
+
///
|
| 237 |
+
/// For quantization flavor tensor B scales apply to accumulation buffer
|
| 238 |
+
/// once C is ready.
|
| 239 |
+
///
|
| 240 |
+
/// @param b_scale_mask Tensor B scale mask. Can be `0` and `2` only.
|
| 241 |
+
void set_B_scales(int b_scale_mask) {
|
| 242 |
+
dnnl_status_t status = dnnl_brgemm_set_B_scales(get(), b_scale_mask);
|
| 243 |
+
if (status != dnnl_success)
|
| 244 |
+
error::wrap_c_api(status, "could not set B scales");
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
/// Sets tensor D scales mask to a BRGeMM ukernel object.
|
| 248 |
+
///
|
| 249 |
+
/// For quantization flavor tensor D scales apply after all post-ops are
|
| 250 |
+
/// applied.
|
| 251 |
+
///
|
| 252 |
+
/// @param d_scale_mask Tensor D scale mask. Can be `0` only.
|
| 253 |
+
void set_D_scales(int d_scale_mask) {
|
| 254 |
+
dnnl_status_t status = dnnl_brgemm_set_D_scales(get(), d_scale_mask);
|
| 255 |
+
if (status != dnnl_success)
|
| 256 |
+
error::wrap_c_api(status, "could not set D scales");
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
/// Finalizes initialization of a BRGeMM ukernel object.
|
| 260 |
+
///
|
| 261 |
+
/// This step must be performed prior to querying information from the
|
| 262 |
+
/// object.
|
| 263 |
+
void finalize() {
|
| 264 |
+
dnnl_status_t status = dnnl_brgemm_finalize(get());
|
| 265 |
+
if (status != dnnl_success)
|
| 266 |
+
error::wrap_c_api(status, "could not finalize an object");
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
/// Returns the packing type expected by a tensor B of a BRGeMM ukernel
|
| 270 |
+
/// object.
|
| 271 |
+
pack_type get_B_pack_type() const {
|
| 272 |
+
dnnl_pack_type_t c_pack_type;
|
| 273 |
+
dnnl_status_t status = dnnl_brgemm_get_B_pack_type(get(), &c_pack_type);
|
| 274 |
+
if (status != dnnl_success)
|
| 275 |
+
error::wrap_c_api(status, "could not query B pack type");
|
| 276 |
+
|
| 277 |
+
return static_cast<pack_type>(c_pack_type);
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
/// Returns the size of a scratchpad memory needed for the BRGeMM ukernel
|
| 281 |
+
/// object.
|
| 282 |
+
size_t get_scratchpad_size() const {
|
| 283 |
+
size_t size;
|
| 284 |
+
dnnl_status_t status = dnnl_brgemm_get_scratchpad_size(get(), &size);
|
| 285 |
+
if (status != dnnl_success)
|
| 286 |
+
error::wrap_c_api(status,
|
| 287 |
+
"could not query a scratchpad size from a BRGeMM ukernel "
|
| 288 |
+
"object");
|
| 289 |
+
return size;
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
/// Returns the flag indicating when the call to execute with post
|
| 293 |
+
/// operations is valid.
|
| 294 |
+
///
|
| 295 |
+
/// `True` is for a valid call, `false`, otherwise.
|
| 296 |
+
bool is_execute_postops_valid() const {
|
| 297 |
+
int valid;
|
| 298 |
+
dnnl_status_t status
|
| 299 |
+
= dnnl_brgemm_is_execute_postops_valid(get(), &valid);
|
| 300 |
+
if (status != dnnl_success)
|
| 301 |
+
error::wrap_c_api(status,
|
| 302 |
+
"could not query a flag for execute postops from a BRGeMM "
|
| 303 |
+
"ukernel object");
|
| 304 |
+
return static_cast<bool>(valid);
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
/// Initializes the hardware-specific context. Affects the global state for
|
| 308 |
+
/// all BRGeMM ukernel objects. If no initialization required, returns.
|
| 309 |
+
void set_hw_context() const {
|
| 310 |
+
dnnl_status_t status = dnnl_brgemm_set_hw_context(get());
|
| 311 |
+
if (status != dnnl_success)
|
| 312 |
+
error::wrap_c_api(status, "could not set hardware context");
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
/// Releases the hardware-specific context. Affects the global state for
|
| 316 |
+
/// all BRGeMM ukernel objects. Must be used after all the execution calls
|
| 317 |
+
/// to BRGeMM ukernel objects.
|
| 318 |
+
static void release_hw_context() {
|
| 319 |
+
dnnl_status_t status = dnnl_brgemm_release_hw_context();
|
| 320 |
+
if (status != dnnl_success)
|
| 321 |
+
error::wrap_c_api(status, "could not release hardware context");
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
/// Generates an executable part of BRGeMM ukernel object.
|
| 325 |
+
void generate() {
|
| 326 |
+
dnnl_status_t status = dnnl_brgemm_generate(get());
|
| 327 |
+
if (status != dnnl_success)
|
| 328 |
+
error::wrap_c_api(status, "could not generate a kernel");
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
/// Executes a BRGeMM ukernel object.
|
| 332 |
+
///
|
| 333 |
+
/// @param A Base pointer to a tensor A.
|
| 334 |
+
/// @param B Base pointer to a tensor B.
|
| 335 |
+
/// @param A_B_offsets Vector of pairs of tensors A and B offsets for
|
| 336 |
+
/// each batch. The number of batches must coincide with the
|
| 337 |
+
/// `batch_size` value passed at object construction stage.
|
| 338 |
+
/// @param C Pointer to a tensor C (accumulation buffer).
|
| 339 |
+
/// @param scratchpad Pointer to a scratchpad buffer.
|
| 340 |
+
void execute(const void *A, const void *B,
|
| 341 |
+
const std::vector<std::pair<memory::dim, memory::dim>> &A_B_offsets,
|
| 342 |
+
void *C, void *scratchpad) const {
|
| 343 |
+
// TODO: export batch_element to C API later for user to fill it and
|
| 344 |
+
// pass directly to the call.
|
| 345 |
+
dnnl_status_t status = dnnl_brgemm_execute(get(), A, B,
|
| 346 |
+
(const dnnl_dim_t *)A_B_offsets.data(), C, scratchpad);
|
| 347 |
+
if (status != dnnl_success)
|
| 348 |
+
error::wrap_c_api(
|
| 349 |
+
status, "could not execute a BRGeMM ukernel object");
|
| 350 |
+
}
|
| 351 |
+
|
| 352 |
+
/// Executes a BRGeMM ukernel object with post operations.
|
| 353 |
+
///
|
| 354 |
+
/// @param A Base pointer to a tensor A.
|
| 355 |
+
/// @param B Base pointer to a tensor B.
|
| 356 |
+
/// @param A_B_offsets Vector of pairs of tensors A and B offsets for
|
| 357 |
+
/// each batch. The number of batches must coincide with the
|
| 358 |
+
/// `batch_size` value passed at object construction stage.
|
| 359 |
+
/// @param C Pointer to a tensor C (accumulation buffer).
|
| 360 |
+
/// @param D Pointer to a tensor D (output buffer).
|
| 361 |
+
/// @param scratchpad Pointer to a scratchpad buffer.
|
| 362 |
+
/// @param params Post-op memory arguments. Must be passed If binary
|
| 363 |
+
/// post-op or scales were set.
|
| 364 |
+
void execute(const void *A, const void *B,
|
| 365 |
+
const std::vector<std::pair<memory::dim, memory::dim>> &A_B_offsets,
|
| 366 |
+
const void *C, void *D, void *scratchpad,
|
| 367 |
+
const attr_params ¶ms = default_attr_params()) const {
|
| 368 |
+
// TODO: export batch_element to C API later for user to fill it and
|
| 369 |
+
// pass directly to the call.
|
| 370 |
+
dnnl_status_t status = dnnl_brgemm_execute_postops(get(), A, B,
|
| 371 |
+
(const dnnl_dim_t *)A_B_offsets.data(), C, D, scratchpad,
|
| 372 |
+
params.get());
|
| 373 |
+
if (status != dnnl_success)
|
| 374 |
+
error::wrap_c_api(
|
| 375 |
+
status, "could not execute a BRGeMM ukernel object");
|
| 376 |
+
}
|
| 377 |
+
|
| 378 |
+
/// Returns a constant reference to a static instance of default constructed
|
| 379 |
+
/// primitive post-operations attribute.
|
| 380 |
+
static const post_ops &default_post_ops() {
|
| 381 |
+
static const post_ops po;
|
| 382 |
+
return po;
|
| 383 |
+
}
|
| 384 |
+
|
| 385 |
+
/// Returns a constant reference to a static instance of default constructed
|
| 386 |
+
/// ukernel attributes parameters.
|
| 387 |
+
static const attr_params &default_attr_params() {
|
| 388 |
+
static const attr_params ap;
|
| 389 |
+
return ap;
|
| 390 |
+
}
|
| 391 |
+
};
|
| 392 |
+
/// @} dnnl_api_ukernel_brgemm
|
| 393 |
+
|
| 394 |
+
/// @addtogroup dnnl_api_ukernel_transform Transform ukernel
|
| 395 |
+
/// Transform routines
|
| 396 |
+
/// @{
|
| 397 |
+
|
| 398 |
+
/// Transform ukernel
|
| 399 |
+
struct transform : public handle<dnnl_transform_t> {
|
| 400 |
+
/// Default constructor. Produces an empty object.
|
| 401 |
+
transform() = default;
|
| 402 |
+
|
| 403 |
+
/// Constructs a transform object.
|
| 404 |
+
///
|
| 405 |
+
/// @param K Dimension K.
|
| 406 |
+
/// @param N Dimension N.
|
| 407 |
+
/// @param in_pack_type Input packing type. Must be one of
|
| 408 |
+
/// `pack_type::no_trans`, or `pack_type::trans`.
|
| 409 |
+
/// @param in_ld Input leading dimension.
|
| 410 |
+
/// @param out_ld Output leading dimension. Specifies a block by N dimension
|
| 411 |
+
/// during data packing.
|
| 412 |
+
/// @param in_dt Input data type.
|
| 413 |
+
/// @param out_dt Output data type.
|
| 414 |
+
/// @param allow_empty A flag signifying whether construction is
|
| 415 |
+
/// allowed to fail without throwing an exception. In this case an
|
| 416 |
+
/// empty object will be produced. This flag is optional and
|
| 417 |
+
/// defaults to false.
|
| 418 |
+
transform(memory::dim K, memory::dim N, pack_type in_pack_type,
|
| 419 |
+
memory::dim in_ld, memory::dim out_ld, memory::data_type in_dt,
|
| 420 |
+
memory::data_type out_dt, bool allow_empty = false) {
|
| 421 |
+
|
| 422 |
+
dnnl_transform_t transform = nullptr;
|
| 423 |
+
dnnl_status_t status = dnnl_transform_create(&transform, K, N,
|
| 424 |
+
static_cast<dnnl_pack_type_t>(in_pack_type), in_ld, out_ld,
|
| 425 |
+
memory::convert_to_c(in_dt), memory::convert_to_c(out_dt));
|
| 426 |
+
|
| 427 |
+
if (!allow_empty)
|
| 428 |
+
error::wrap_c_api(status,
|
| 429 |
+
"could not create a BRGeMM ukernel packing B object");
|
| 430 |
+
reset(transform);
|
| 431 |
+
}
|
| 432 |
+
|
| 433 |
+
/// Generates an executable part of transform object.
|
| 434 |
+
void generate() {
|
| 435 |
+
dnnl_status_t status = dnnl_transform_generate(get());
|
| 436 |
+
if (status != dnnl_success)
|
| 437 |
+
error::wrap_c_api(status,
|
| 438 |
+
"could not generate a BRGeMM ukernel packing B object");
|
| 439 |
+
}
|
| 440 |
+
|
| 441 |
+
/// Executes a transform object.
|
| 442 |
+
///
|
| 443 |
+
/// @param in Pointer to an input buffer.
|
| 444 |
+
/// @param out Pointer to an output buffer.
|
| 445 |
+
void execute(const void *in, void *out) const {
|
| 446 |
+
dnnl_status_t status = dnnl_transform_execute(get(), in, out);
|
| 447 |
+
if (status != dnnl_success)
|
| 448 |
+
error::wrap_c_api(status,
|
| 449 |
+
"could not execute a BRGeMM ukernel packing B object");
|
| 450 |
+
}
|
| 451 |
+
};
|
| 452 |
+
|
| 453 |
+
/// @} dnnl_api_ukernel_transform
|
| 454 |
+
|
| 455 |
+
#endif
|
| 456 |
+
|
| 457 |
+
} // namespace ukernel
|
| 458 |
+
|
| 459 |
+
/// @} dnnl_api_ukernel
|
| 460 |
+
|
| 461 |
+
} // namespace dnnl
|
| 462 |
+
|
| 463 |
+
/// @} dnnl_api
|
| 464 |
+
|
| 465 |
+
#endif /* ONEAPI_DNNL_DNNL_UKERNEL_HPP */
|
phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_ukernel_types.h
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*******************************************************************************
|
| 2 |
+
* Copyright 2024 Intel Corporation
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*******************************************************************************/
|
| 16 |
+
|
| 17 |
+
/// @file
|
| 18 |
+
/// ukernel C API types definitions
|
| 19 |
+
|
| 20 |
+
#ifndef ONEAPI_DNNL_DNNL_UKERNEL_TYPES_H
|
| 21 |
+
#define ONEAPI_DNNL_DNNL_UKERNEL_TYPES_H
|
| 22 |
+
|
| 23 |
+
#ifdef __cplusplus
|
| 24 |
+
extern "C" {
|
| 25 |
+
#endif
|
| 26 |
+
|
| 27 |
+
#include "oneapi/dnnl/dnnl_types.h"
|
| 28 |
+
|
| 29 |
+
/// @addtogroup dnnl_api
|
| 30 |
+
/// @{
|
| 31 |
+
|
| 32 |
+
/// @addtogroup dnnl_api_ukernel
|
| 33 |
+
/// @{
|
| 34 |
+
|
| 35 |
+
#ifdef DNNL_EXPERIMENTAL_UKERNEL
|
| 36 |
+
|
| 37 |
+
/// Packing specification
|
| 38 |
+
typedef enum {
|
| 39 |
+
/// Undefined pack type. A guard value.
|
| 40 |
+
dnnl_pack_type_undef = 0,
|
| 41 |
+
/// Plain, not transposed layout. Similar to format_tag::ab.
|
| 42 |
+
dnnl_pack_type_no_trans,
|
| 43 |
+
/// Plain, transposed layout. Similar to format_tag::ba.
|
| 44 |
+
dnnl_pack_type_trans,
|
| 45 |
+
/// Packed by 32 bits along K dimension layout.
|
| 46 |
+
dnnl_pack_type_pack32,
|
| 47 |
+
} dnnl_pack_type_t;
|
| 48 |
+
|
| 49 |
+
/// @struct dnnl_ukernel_attr_params
|
| 50 |
+
/// An opaque structure to describe ukernel attributes memory storage.
|
| 51 |
+
struct dnnl_ukernel_attr_params;
|
| 52 |
+
|
| 53 |
+
/// A ukernel attributes memory storage handle.
|
| 54 |
+
typedef struct dnnl_ukernel_attr_params *dnnl_ukernel_attr_params_t;
|
| 55 |
+
|
| 56 |
+
/// A constant ukernel attributes memory storage handle.
|
| 57 |
+
typedef const struct dnnl_ukernel_attr_params *const_dnnl_ukernel_attr_params_t;
|
| 58 |
+
|
| 59 |
+
/// @addtogroup dnnl_api_ukernel_brgemm
|
| 60 |
+
/// @{
|
| 61 |
+
|
| 62 |
+
/// @struct dnnl_brgemm
|
| 63 |
+
/// An opaque structure to describe a brgemm ukernel.
|
| 64 |
+
struct dnnl_brgemm;
|
| 65 |
+
|
| 66 |
+
/// A brgemm ukernel handle.
|
| 67 |
+
typedef struct dnnl_brgemm *dnnl_brgemm_t;
|
| 68 |
+
|
| 69 |
+
/// A constant brgemm ukernel handle.
|
| 70 |
+
typedef const struct dnnl_brgemm *const_dnnl_brgemm_t;
|
| 71 |
+
|
| 72 |
+
/// @struct dnnl_transform
|
| 73 |
+
/// An opaque structure to describe a transform routine.
|
| 74 |
+
struct dnnl_transform;
|
| 75 |
+
|
| 76 |
+
/// A transform routine handle.
|
| 77 |
+
typedef struct dnnl_transform *dnnl_transform_t;
|
| 78 |
+
|
| 79 |
+
/// A constant transform routine handle.
|
| 80 |
+
typedef const struct dnnl_transform *const_dnnl_transform_t;
|
| 81 |
+
|
| 82 |
+
/// @} dnnl_api_ukernel_brgemm
|
| 83 |
+
#endif
|
| 84 |
+
|
| 85 |
+
/// @} dnnl_api_ukernel
|
| 86 |
+
|
| 87 |
+
/// @} dnnl_api
|
| 88 |
+
|
| 89 |
+
#ifdef __cplusplus
|
| 90 |
+
}
|
| 91 |
+
#endif
|
| 92 |
+
|
| 93 |
+
#endif /* ONEAPI_DNNL_DNNL_UKERNEL_TYPES_H */
|
phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_version.h
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*******************************************************************************
|
| 2 |
+
* Copyright 2019-2024 Intel Corporation
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*******************************************************************************/
|
| 16 |
+
|
| 17 |
+
#ifndef ONEAPI_DNNL_DNNL_VERSION_H
|
| 18 |
+
#define ONEAPI_DNNL_DNNL_VERSION_H
|
| 19 |
+
|
| 20 |
+
// clang-format off
|
| 21 |
+
|
| 22 |
+
/// Major version
|
| 23 |
+
#define DNNL_VERSION_MAJOR 3
|
| 24 |
+
|
| 25 |
+
/// Minor version
|
| 26 |
+
#define DNNL_VERSION_MINOR 7
|
| 27 |
+
|
| 28 |
+
/// Patch version
|
| 29 |
+
#define DNNL_VERSION_PATCH 1
|
| 30 |
+
|
| 31 |
+
// clang-format on
|
| 32 |
+
|
| 33 |
+
#endif
|
phivenv/Lib/site-packages/torch/include/oneapi/dnnl/dnnl_version_hash.h
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*******************************************************************************
|
| 2 |
+
* Copyright 2024 Intel Corporation
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*******************************************************************************/
|
| 16 |
+
|
| 17 |
+
#ifndef ONEAPI_DNNL_DNNL_VERSION_HASH_H
|
| 18 |
+
#define ONEAPI_DNNL_DNNL_VERSION_HASH_H
|
| 19 |
+
|
| 20 |
+
// clang-format off
|
| 21 |
+
|
| 22 |
+
/// Note: this macro and header file were moved to a separate instance to avoid
|
| 23 |
+
/// incremental build issues as moving from commit to commit would trigger a
|
| 24 |
+
/// complete library rebuild. Including a generated header file in a single
|
| 25 |
+
/// translation unit makes this problem go away.
|
| 26 |
+
/// Git commit hash
|
| 27 |
+
#define DNNL_VERSION_HASH "8d263e693366ef8db40acc569cc7d8edf644556d"
|
| 28 |
+
|
| 29 |
+
// clang-format on
|
| 30 |
+
|
| 31 |
+
#endif
|
phivenv/Lib/site-packages/torch/include/pybind11/attr.h
ADDED
|
@@ -0,0 +1,690 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
pybind11/attr.h: Infrastructure for processing custom
|
| 3 |
+
type and function attributes
|
| 4 |
+
|
| 5 |
+
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
| 6 |
+
|
| 7 |
+
All rights reserved. Use of this source code is governed by a
|
| 8 |
+
BSD-style license that can be found in the LICENSE file.
|
| 9 |
+
*/
|
| 10 |
+
|
| 11 |
+
#pragma once
|
| 12 |
+
|
| 13 |
+
#include "detail/common.h"
|
| 14 |
+
#include "cast.h"
|
| 15 |
+
|
| 16 |
+
#include <functional>
|
| 17 |
+
|
| 18 |
+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
| 19 |
+
|
| 20 |
+
/// \addtogroup annotations
|
| 21 |
+
/// @{
|
| 22 |
+
|
| 23 |
+
/// Annotation for methods
|
| 24 |
+
struct is_method {
|
| 25 |
+
handle class_;
|
| 26 |
+
explicit is_method(const handle &c) : class_(c) {}
|
| 27 |
+
};
|
| 28 |
+
|
| 29 |
+
/// Annotation for setters
|
| 30 |
+
struct is_setter {};
|
| 31 |
+
|
| 32 |
+
/// Annotation for operators
|
| 33 |
+
struct is_operator {};
|
| 34 |
+
|
| 35 |
+
/// Annotation for classes that cannot be subclassed
|
| 36 |
+
struct is_final {};
|
| 37 |
+
|
| 38 |
+
/// Annotation for parent scope
|
| 39 |
+
struct scope {
|
| 40 |
+
handle value;
|
| 41 |
+
explicit scope(const handle &s) : value(s) {}
|
| 42 |
+
};
|
| 43 |
+
|
| 44 |
+
/// Annotation for documentation
|
| 45 |
+
struct doc {
|
| 46 |
+
const char *value;
|
| 47 |
+
explicit doc(const char *value) : value(value) {}
|
| 48 |
+
};
|
| 49 |
+
|
| 50 |
+
/// Annotation for function names
|
| 51 |
+
struct name {
|
| 52 |
+
const char *value;
|
| 53 |
+
explicit name(const char *value) : value(value) {}
|
| 54 |
+
};
|
| 55 |
+
|
| 56 |
+
/// Annotation indicating that a function is an overload associated with a given "sibling"
|
| 57 |
+
struct sibling {
|
| 58 |
+
handle value;
|
| 59 |
+
explicit sibling(const handle &value) : value(value.ptr()) {}
|
| 60 |
+
};
|
| 61 |
+
|
| 62 |
+
/// Annotation indicating that a class derives from another given type
|
| 63 |
+
template <typename T>
|
| 64 |
+
struct base {
|
| 65 |
+
|
| 66 |
+
PYBIND11_DEPRECATED(
|
| 67 |
+
"base<T>() was deprecated in favor of specifying 'T' as a template argument to class_")
|
| 68 |
+
base() = default;
|
| 69 |
+
};
|
| 70 |
+
|
| 71 |
+
/// Keep patient alive while nurse lives
|
| 72 |
+
template <size_t Nurse, size_t Patient>
|
| 73 |
+
struct keep_alive {};
|
| 74 |
+
|
| 75 |
+
/// Annotation indicating that a class is involved in a multiple inheritance relationship
|
| 76 |
+
struct multiple_inheritance {};
|
| 77 |
+
|
| 78 |
+
/// Annotation which enables dynamic attributes, i.e. adds `__dict__` to a class
|
| 79 |
+
struct dynamic_attr {};
|
| 80 |
+
|
| 81 |
+
/// Annotation which enables the buffer protocol for a type
|
| 82 |
+
struct buffer_protocol {};
|
| 83 |
+
|
| 84 |
+
/// Annotation which requests that a special metaclass is created for a type
|
| 85 |
+
struct metaclass {
|
| 86 |
+
handle value;
|
| 87 |
+
|
| 88 |
+
PYBIND11_DEPRECATED("py::metaclass() is no longer required. It's turned on by default now.")
|
| 89 |
+
metaclass() = default;
|
| 90 |
+
|
| 91 |
+
/// Override pybind11's default metaclass
|
| 92 |
+
explicit metaclass(handle value) : value(value) {}
|
| 93 |
+
};
|
| 94 |
+
|
| 95 |
+
/// Specifies a custom callback with signature `void (PyHeapTypeObject*)` that
|
| 96 |
+
/// may be used to customize the Python type.
|
| 97 |
+
///
|
| 98 |
+
/// The callback is invoked immediately before `PyType_Ready`.
|
| 99 |
+
///
|
| 100 |
+
/// Note: This is an advanced interface, and uses of it may require changes to
|
| 101 |
+
/// work with later versions of pybind11. You may wish to consult the
|
| 102 |
+
/// implementation of `make_new_python_type` in `detail/classes.h` to understand
|
| 103 |
+
/// the context in which the callback will be run.
|
| 104 |
+
struct custom_type_setup {
|
| 105 |
+
using callback = std::function<void(PyHeapTypeObject *heap_type)>;
|
| 106 |
+
|
| 107 |
+
explicit custom_type_setup(callback value) : value(std::move(value)) {}
|
| 108 |
+
|
| 109 |
+
callback value;
|
| 110 |
+
};
|
| 111 |
+
|
| 112 |
+
/// Annotation that marks a class as local to the module:
|
| 113 |
+
struct module_local {
|
| 114 |
+
const bool value;
|
| 115 |
+
constexpr explicit module_local(bool v = true) : value(v) {}
|
| 116 |
+
};
|
| 117 |
+
|
| 118 |
+
/// Annotation to mark enums as an arithmetic type
|
| 119 |
+
struct arithmetic {};
|
| 120 |
+
|
| 121 |
+
/// Mark a function for addition at the beginning of the existing overload chain instead of the end
|
| 122 |
+
struct prepend {};
|
| 123 |
+
|
| 124 |
+
/** \rst
|
| 125 |
+
A call policy which places one or more guard variables (``Ts...``) around the function call.
|
| 126 |
+
|
| 127 |
+
For example, this definition:
|
| 128 |
+
|
| 129 |
+
.. code-block:: cpp
|
| 130 |
+
|
| 131 |
+
m.def("foo", foo, py::call_guard<T>());
|
| 132 |
+
|
| 133 |
+
is equivalent to the following pseudocode:
|
| 134 |
+
|
| 135 |
+
.. code-block:: cpp
|
| 136 |
+
|
| 137 |
+
m.def("foo", [](args...) {
|
| 138 |
+
T scope_guard;
|
| 139 |
+
return foo(args...); // forwarded arguments
|
| 140 |
+
});
|
| 141 |
+
\endrst */
|
| 142 |
+
template <typename... Ts>
|
| 143 |
+
struct call_guard;
|
| 144 |
+
|
| 145 |
+
template <>
|
| 146 |
+
struct call_guard<> {
|
| 147 |
+
using type = detail::void_type;
|
| 148 |
+
};
|
| 149 |
+
|
| 150 |
+
template <typename T>
|
| 151 |
+
struct call_guard<T> {
|
| 152 |
+
static_assert(std::is_default_constructible<T>::value,
|
| 153 |
+
"The guard type must be default constructible");
|
| 154 |
+
|
| 155 |
+
using type = T;
|
| 156 |
+
};
|
| 157 |
+
|
| 158 |
+
template <typename T, typename... Ts>
|
| 159 |
+
struct call_guard<T, Ts...> {
|
| 160 |
+
struct type {
|
| 161 |
+
T guard{}; // Compose multiple guard types with left-to-right default-constructor order
|
| 162 |
+
typename call_guard<Ts...>::type next{};
|
| 163 |
+
};
|
| 164 |
+
};
|
| 165 |
+
|
| 166 |
+
/// @} annotations
|
| 167 |
+
|
| 168 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 169 |
+
/* Forward declarations */
|
| 170 |
+
enum op_id : int;
|
| 171 |
+
enum op_type : int;
|
| 172 |
+
struct undefined_t;
|
| 173 |
+
template <op_id id, op_type ot, typename L = undefined_t, typename R = undefined_t>
|
| 174 |
+
struct op_;
|
| 175 |
+
void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret);
|
| 176 |
+
|
| 177 |
+
/// Internal data structure which holds metadata about a keyword argument
|
| 178 |
+
struct argument_record {
|
| 179 |
+
const char *name; ///< Argument name
|
| 180 |
+
const char *descr; ///< Human-readable version of the argument value
|
| 181 |
+
handle value; ///< Associated Python object
|
| 182 |
+
bool convert : 1; ///< True if the argument is allowed to convert when loading
|
| 183 |
+
bool none : 1; ///< True if None is allowed when loading
|
| 184 |
+
|
| 185 |
+
argument_record(const char *name, const char *descr, handle value, bool convert, bool none)
|
| 186 |
+
: name(name), descr(descr), value(value), convert(convert), none(none) {}
|
| 187 |
+
};
|
| 188 |
+
|
| 189 |
+
/// Internal data structure which holds metadata about a bound function (signature, overloads,
|
| 190 |
+
/// etc.)
|
| 191 |
+
struct function_record {
|
| 192 |
+
function_record()
|
| 193 |
+
: is_constructor(false), is_new_style_constructor(false), is_stateless(false),
|
| 194 |
+
is_operator(false), is_method(false), is_setter(false), has_args(false),
|
| 195 |
+
has_kwargs(false), prepend(false) {}
|
| 196 |
+
|
| 197 |
+
/// Function name
|
| 198 |
+
char *name = nullptr; /* why no C++ strings? They generate heavier code.. */
|
| 199 |
+
|
| 200 |
+
// User-specified documentation string
|
| 201 |
+
char *doc = nullptr;
|
| 202 |
+
|
| 203 |
+
/// Human-readable version of the function signature
|
| 204 |
+
char *signature = nullptr;
|
| 205 |
+
|
| 206 |
+
/// List of registered keyword arguments
|
| 207 |
+
std::vector<argument_record> args;
|
| 208 |
+
|
| 209 |
+
/// Pointer to lambda function which converts arguments and performs the actual call
|
| 210 |
+
handle (*impl)(function_call &) = nullptr;
|
| 211 |
+
|
| 212 |
+
/// Storage for the wrapped function pointer and captured data, if any
|
| 213 |
+
void *data[3] = {};
|
| 214 |
+
|
| 215 |
+
/// Pointer to custom destructor for 'data' (if needed)
|
| 216 |
+
void (*free_data)(function_record *ptr) = nullptr;
|
| 217 |
+
|
| 218 |
+
/// Return value policy associated with this function
|
| 219 |
+
return_value_policy policy = return_value_policy::automatic;
|
| 220 |
+
|
| 221 |
+
/// True if name == '__init__'
|
| 222 |
+
bool is_constructor : 1;
|
| 223 |
+
|
| 224 |
+
/// True if this is a new-style `__init__` defined in `detail/init.h`
|
| 225 |
+
bool is_new_style_constructor : 1;
|
| 226 |
+
|
| 227 |
+
/// True if this is a stateless function pointer
|
| 228 |
+
bool is_stateless : 1;
|
| 229 |
+
|
| 230 |
+
/// True if this is an operator (__add__), etc.
|
| 231 |
+
bool is_operator : 1;
|
| 232 |
+
|
| 233 |
+
/// True if this is a method
|
| 234 |
+
bool is_method : 1;
|
| 235 |
+
|
| 236 |
+
/// True if this is a setter
|
| 237 |
+
bool is_setter : 1;
|
| 238 |
+
|
| 239 |
+
/// True if the function has a '*args' argument
|
| 240 |
+
bool has_args : 1;
|
| 241 |
+
|
| 242 |
+
/// True if the function has a '**kwargs' argument
|
| 243 |
+
bool has_kwargs : 1;
|
| 244 |
+
|
| 245 |
+
/// True if this function is to be inserted at the beginning of the overload resolution chain
|
| 246 |
+
bool prepend : 1;
|
| 247 |
+
|
| 248 |
+
/// Number of arguments (including py::args and/or py::kwargs, if present)
|
| 249 |
+
std::uint16_t nargs;
|
| 250 |
+
|
| 251 |
+
/// Number of leading positional arguments, which are terminated by a py::args or py::kwargs
|
| 252 |
+
/// argument or by a py::kw_only annotation.
|
| 253 |
+
std::uint16_t nargs_pos = 0;
|
| 254 |
+
|
| 255 |
+
/// Number of leading arguments (counted in `nargs`) that are positional-only
|
| 256 |
+
std::uint16_t nargs_pos_only = 0;
|
| 257 |
+
|
| 258 |
+
/// Python method object
|
| 259 |
+
PyMethodDef *def = nullptr;
|
| 260 |
+
|
| 261 |
+
/// Python handle to the parent scope (a class or a module)
|
| 262 |
+
handle scope;
|
| 263 |
+
|
| 264 |
+
/// Python handle to the sibling function representing an overload chain
|
| 265 |
+
handle sibling;
|
| 266 |
+
|
| 267 |
+
/// Pointer to next overload
|
| 268 |
+
function_record *next = nullptr;
|
| 269 |
+
};
|
| 270 |
+
|
| 271 |
+
/// Special data structure which (temporarily) holds metadata about a bound class
|
| 272 |
+
struct type_record {
|
| 273 |
+
PYBIND11_NOINLINE type_record()
|
| 274 |
+
: multiple_inheritance(false), dynamic_attr(false), buffer_protocol(false),
|
| 275 |
+
default_holder(true), module_local(false), is_final(false) {}
|
| 276 |
+
|
| 277 |
+
/// Handle to the parent scope
|
| 278 |
+
handle scope;
|
| 279 |
+
|
| 280 |
+
/// Name of the class
|
| 281 |
+
const char *name = nullptr;
|
| 282 |
+
|
| 283 |
+
// Pointer to RTTI type_info data structure
|
| 284 |
+
const std::type_info *type = nullptr;
|
| 285 |
+
|
| 286 |
+
/// How large is the underlying C++ type?
|
| 287 |
+
size_t type_size = 0;
|
| 288 |
+
|
| 289 |
+
/// What is the alignment of the underlying C++ type?
|
| 290 |
+
size_t type_align = 0;
|
| 291 |
+
|
| 292 |
+
/// How large is the type's holder?
|
| 293 |
+
size_t holder_size = 0;
|
| 294 |
+
|
| 295 |
+
/// The global operator new can be overridden with a class-specific variant
|
| 296 |
+
void *(*operator_new)(size_t) = nullptr;
|
| 297 |
+
|
| 298 |
+
/// Function pointer to class_<..>::init_instance
|
| 299 |
+
void (*init_instance)(instance *, const void *) = nullptr;
|
| 300 |
+
|
| 301 |
+
/// Function pointer to class_<..>::dealloc
|
| 302 |
+
void (*dealloc)(detail::value_and_holder &) = nullptr;
|
| 303 |
+
|
| 304 |
+
/// List of base classes of the newly created type
|
| 305 |
+
list bases;
|
| 306 |
+
|
| 307 |
+
/// Optional docstring
|
| 308 |
+
const char *doc = nullptr;
|
| 309 |
+
|
| 310 |
+
/// Custom metaclass (optional)
|
| 311 |
+
handle metaclass;
|
| 312 |
+
|
| 313 |
+
/// Custom type setup.
|
| 314 |
+
custom_type_setup::callback custom_type_setup_callback;
|
| 315 |
+
|
| 316 |
+
/// Multiple inheritance marker
|
| 317 |
+
bool multiple_inheritance : 1;
|
| 318 |
+
|
| 319 |
+
/// Does the class manage a __dict__?
|
| 320 |
+
bool dynamic_attr : 1;
|
| 321 |
+
|
| 322 |
+
/// Does the class implement the buffer protocol?
|
| 323 |
+
bool buffer_protocol : 1;
|
| 324 |
+
|
| 325 |
+
/// Is the default (unique_ptr) holder type used?
|
| 326 |
+
bool default_holder : 1;
|
| 327 |
+
|
| 328 |
+
/// Is the class definition local to the module shared object?
|
| 329 |
+
bool module_local : 1;
|
| 330 |
+
|
| 331 |
+
/// Is the class inheritable from python classes?
|
| 332 |
+
bool is_final : 1;
|
| 333 |
+
|
| 334 |
+
PYBIND11_NOINLINE void add_base(const std::type_info &base, void *(*caster)(void *) ) {
|
| 335 |
+
auto *base_info = detail::get_type_info(base, false);
|
| 336 |
+
if (!base_info) {
|
| 337 |
+
std::string tname(base.name());
|
| 338 |
+
detail::clean_type_id(tname);
|
| 339 |
+
pybind11_fail("generic_type: type \"" + std::string(name)
|
| 340 |
+
+ "\" referenced unknown base type \"" + tname + "\"");
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
+
if (default_holder != base_info->default_holder) {
|
| 344 |
+
std::string tname(base.name());
|
| 345 |
+
detail::clean_type_id(tname);
|
| 346 |
+
pybind11_fail("generic_type: type \"" + std::string(name) + "\" "
|
| 347 |
+
+ (default_holder ? "does not have" : "has")
|
| 348 |
+
+ " a non-default holder type while its base \"" + tname + "\" "
|
| 349 |
+
+ (base_info->default_holder ? "does not" : "does"));
|
| 350 |
+
}
|
| 351 |
+
|
| 352 |
+
bases.append((PyObject *) base_info->type);
|
| 353 |
+
|
| 354 |
+
#if PY_VERSION_HEX < 0x030B0000
|
| 355 |
+
dynamic_attr |= base_info->type->tp_dictoffset != 0;
|
| 356 |
+
#else
|
| 357 |
+
dynamic_attr |= (base_info->type->tp_flags & Py_TPFLAGS_MANAGED_DICT) != 0;
|
| 358 |
+
#endif
|
| 359 |
+
|
| 360 |
+
if (caster) {
|
| 361 |
+
base_info->implicit_casts.emplace_back(type, caster);
|
| 362 |
+
}
|
| 363 |
+
}
|
| 364 |
+
};
|
| 365 |
+
|
| 366 |
+
inline function_call::function_call(const function_record &f, handle p) : func(f), parent(p) {
|
| 367 |
+
args.reserve(f.nargs);
|
| 368 |
+
args_convert.reserve(f.nargs);
|
| 369 |
+
}
|
| 370 |
+
|
| 371 |
+
/// Tag for a new-style `__init__` defined in `detail/init.h`
|
| 372 |
+
struct is_new_style_constructor {};
|
| 373 |
+
|
| 374 |
+
/**
|
| 375 |
+
* Partial template specializations to process custom attributes provided to
|
| 376 |
+
* cpp_function_ and class_. These are either used to initialize the respective
|
| 377 |
+
* fields in the type_record and function_record data structures or executed at
|
| 378 |
+
* runtime to deal with custom call policies (e.g. keep_alive).
|
| 379 |
+
*/
|
| 380 |
+
template <typename T, typename SFINAE = void>
|
| 381 |
+
struct process_attribute;
|
| 382 |
+
|
| 383 |
+
template <typename T>
|
| 384 |
+
struct process_attribute_default {
|
| 385 |
+
/// Default implementation: do nothing
|
| 386 |
+
static void init(const T &, function_record *) {}
|
| 387 |
+
static void init(const T &, type_record *) {}
|
| 388 |
+
static void precall(function_call &) {}
|
| 389 |
+
static void postcall(function_call &, handle) {}
|
| 390 |
+
};
|
| 391 |
+
|
| 392 |
+
/// Process an attribute specifying the function's name
|
| 393 |
+
template <>
|
| 394 |
+
struct process_attribute<name> : process_attribute_default<name> {
|
| 395 |
+
static void init(const name &n, function_record *r) { r->name = const_cast<char *>(n.value); }
|
| 396 |
+
};
|
| 397 |
+
|
| 398 |
+
/// Process an attribute specifying the function's docstring
|
| 399 |
+
template <>
|
| 400 |
+
struct process_attribute<doc> : process_attribute_default<doc> {
|
| 401 |
+
static void init(const doc &n, function_record *r) { r->doc = const_cast<char *>(n.value); }
|
| 402 |
+
};
|
| 403 |
+
|
| 404 |
+
/// Process an attribute specifying the function's docstring (provided as a C-style string)
|
| 405 |
+
template <>
|
| 406 |
+
struct process_attribute<const char *> : process_attribute_default<const char *> {
|
| 407 |
+
static void init(const char *d, function_record *r) { r->doc = const_cast<char *>(d); }
|
| 408 |
+
static void init(const char *d, type_record *r) { r->doc = d; }
|
| 409 |
+
};
|
| 410 |
+
template <>
|
| 411 |
+
struct process_attribute<char *> : process_attribute<const char *> {};
|
| 412 |
+
|
| 413 |
+
/// Process an attribute indicating the function's return value policy
|
| 414 |
+
template <>
|
| 415 |
+
struct process_attribute<return_value_policy> : process_attribute_default<return_value_policy> {
|
| 416 |
+
static void init(const return_value_policy &p, function_record *r) { r->policy = p; }
|
| 417 |
+
};
|
| 418 |
+
|
| 419 |
+
/// Process an attribute which indicates that this is an overloaded function associated with a
|
| 420 |
+
/// given sibling
|
| 421 |
+
template <>
|
| 422 |
+
struct process_attribute<sibling> : process_attribute_default<sibling> {
|
| 423 |
+
static void init(const sibling &s, function_record *r) { r->sibling = s.value; }
|
| 424 |
+
};
|
| 425 |
+
|
| 426 |
+
/// Process an attribute which indicates that this function is a method
|
| 427 |
+
template <>
|
| 428 |
+
struct process_attribute<is_method> : process_attribute_default<is_method> {
|
| 429 |
+
static void init(const is_method &s, function_record *r) {
|
| 430 |
+
r->is_method = true;
|
| 431 |
+
r->scope = s.class_;
|
| 432 |
+
}
|
| 433 |
+
};
|
| 434 |
+
|
| 435 |
+
/// Process an attribute which indicates that this function is a setter
|
| 436 |
+
template <>
|
| 437 |
+
struct process_attribute<is_setter> : process_attribute_default<is_setter> {
|
| 438 |
+
static void init(const is_setter &, function_record *r) { r->is_setter = true; }
|
| 439 |
+
};
|
| 440 |
+
|
| 441 |
+
/// Process an attribute which indicates the parent scope of a method
|
| 442 |
+
template <>
|
| 443 |
+
struct process_attribute<scope> : process_attribute_default<scope> {
|
| 444 |
+
static void init(const scope &s, function_record *r) { r->scope = s.value; }
|
| 445 |
+
};
|
| 446 |
+
|
| 447 |
+
/// Process an attribute which indicates that this function is an operator
|
| 448 |
+
template <>
|
| 449 |
+
struct process_attribute<is_operator> : process_attribute_default<is_operator> {
|
| 450 |
+
static void init(const is_operator &, function_record *r) { r->is_operator = true; }
|
| 451 |
+
};
|
| 452 |
+
|
| 453 |
+
template <>
|
| 454 |
+
struct process_attribute<is_new_style_constructor>
|
| 455 |
+
: process_attribute_default<is_new_style_constructor> {
|
| 456 |
+
static void init(const is_new_style_constructor &, function_record *r) {
|
| 457 |
+
r->is_new_style_constructor = true;
|
| 458 |
+
}
|
| 459 |
+
};
|
| 460 |
+
|
| 461 |
+
inline void check_kw_only_arg(const arg &a, function_record *r) {
|
| 462 |
+
if (r->args.size() > r->nargs_pos && (!a.name || a.name[0] == '\0')) {
|
| 463 |
+
pybind11_fail("arg(): cannot specify an unnamed argument after a kw_only() annotation or "
|
| 464 |
+
"args() argument");
|
| 465 |
+
}
|
| 466 |
+
}
|
| 467 |
+
|
| 468 |
+
inline void append_self_arg_if_needed(function_record *r) {
|
| 469 |
+
if (r->is_method && r->args.empty()) {
|
| 470 |
+
r->args.emplace_back("self", nullptr, handle(), /*convert=*/true, /*none=*/false);
|
| 471 |
+
}
|
| 472 |
+
}
|
| 473 |
+
|
| 474 |
+
/// Process a keyword argument attribute (*without* a default value)
|
| 475 |
+
template <>
|
| 476 |
+
struct process_attribute<arg> : process_attribute_default<arg> {
|
| 477 |
+
static void init(const arg &a, function_record *r) {
|
| 478 |
+
append_self_arg_if_needed(r);
|
| 479 |
+
r->args.emplace_back(a.name, nullptr, handle(), !a.flag_noconvert, a.flag_none);
|
| 480 |
+
|
| 481 |
+
check_kw_only_arg(a, r);
|
| 482 |
+
}
|
| 483 |
+
};
|
| 484 |
+
|
| 485 |
+
/// Process a keyword argument attribute (*with* a default value)
|
| 486 |
+
template <>
|
| 487 |
+
struct process_attribute<arg_v> : process_attribute_default<arg_v> {
|
| 488 |
+
static void init(const arg_v &a, function_record *r) {
|
| 489 |
+
if (r->is_method && r->args.empty()) {
|
| 490 |
+
r->args.emplace_back(
|
| 491 |
+
"self", /*descr=*/nullptr, /*parent=*/handle(), /*convert=*/true, /*none=*/false);
|
| 492 |
+
}
|
| 493 |
+
|
| 494 |
+
if (!a.value) {
|
| 495 |
+
#if defined(PYBIND11_DETAILED_ERROR_MESSAGES)
|
| 496 |
+
std::string descr("'");
|
| 497 |
+
if (a.name) {
|
| 498 |
+
descr += std::string(a.name) + ": ";
|
| 499 |
+
}
|
| 500 |
+
descr += a.type + "'";
|
| 501 |
+
if (r->is_method) {
|
| 502 |
+
if (r->name) {
|
| 503 |
+
descr += " in method '" + (std::string) str(r->scope) + "."
|
| 504 |
+
+ (std::string) r->name + "'";
|
| 505 |
+
} else {
|
| 506 |
+
descr += " in method of '" + (std::string) str(r->scope) + "'";
|
| 507 |
+
}
|
| 508 |
+
} else if (r->name) {
|
| 509 |
+
descr += " in function '" + (std::string) r->name + "'";
|
| 510 |
+
}
|
| 511 |
+
pybind11_fail("arg(): could not convert default argument " + descr
|
| 512 |
+
+ " into a Python object (type not registered yet?)");
|
| 513 |
+
#else
|
| 514 |
+
pybind11_fail("arg(): could not convert default argument "
|
| 515 |
+
"into a Python object (type not registered yet?). "
|
| 516 |
+
"#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for "
|
| 517 |
+
"more information.");
|
| 518 |
+
#endif
|
| 519 |
+
}
|
| 520 |
+
r->args.emplace_back(a.name, a.descr, a.value.inc_ref(), !a.flag_noconvert, a.flag_none);
|
| 521 |
+
|
| 522 |
+
check_kw_only_arg(a, r);
|
| 523 |
+
}
|
| 524 |
+
};
|
| 525 |
+
|
| 526 |
+
/// Process a keyword-only-arguments-follow pseudo argument
|
| 527 |
+
template <>
|
| 528 |
+
struct process_attribute<kw_only> : process_attribute_default<kw_only> {
|
| 529 |
+
static void init(const kw_only &, function_record *r) {
|
| 530 |
+
append_self_arg_if_needed(r);
|
| 531 |
+
if (r->has_args && r->nargs_pos != static_cast<std::uint16_t>(r->args.size())) {
|
| 532 |
+
pybind11_fail("Mismatched args() and kw_only(): they must occur at the same relative "
|
| 533 |
+
"argument location (or omit kw_only() entirely)");
|
| 534 |
+
}
|
| 535 |
+
r->nargs_pos = static_cast<std::uint16_t>(r->args.size());
|
| 536 |
+
}
|
| 537 |
+
};
|
| 538 |
+
|
| 539 |
+
/// Process a positional-only-argument maker
|
| 540 |
+
template <>
|
| 541 |
+
struct process_attribute<pos_only> : process_attribute_default<pos_only> {
|
| 542 |
+
static void init(const pos_only &, function_record *r) {
|
| 543 |
+
append_self_arg_if_needed(r);
|
| 544 |
+
r->nargs_pos_only = static_cast<std::uint16_t>(r->args.size());
|
| 545 |
+
if (r->nargs_pos_only > r->nargs_pos) {
|
| 546 |
+
pybind11_fail("pos_only(): cannot follow a py::args() argument");
|
| 547 |
+
}
|
| 548 |
+
// It also can't follow a kw_only, but a static_assert in pybind11.h checks that
|
| 549 |
+
}
|
| 550 |
+
};
|
| 551 |
+
|
| 552 |
+
/// Process a parent class attribute. Single inheritance only (class_ itself already guarantees
|
| 553 |
+
/// that)
|
| 554 |
+
template <typename T>
|
| 555 |
+
struct process_attribute<T, enable_if_t<is_pyobject<T>::value>>
|
| 556 |
+
: process_attribute_default<handle> {
|
| 557 |
+
static void init(const handle &h, type_record *r) { r->bases.append(h); }
|
| 558 |
+
};
|
| 559 |
+
|
| 560 |
+
/// Process a parent class attribute (deprecated, does not support multiple inheritance)
|
| 561 |
+
template <typename T>
|
| 562 |
+
struct process_attribute<base<T>> : process_attribute_default<base<T>> {
|
| 563 |
+
static void init(const base<T> &, type_record *r) { r->add_base(typeid(T), nullptr); }
|
| 564 |
+
};
|
| 565 |
+
|
| 566 |
+
/// Process a multiple inheritance attribute
|
| 567 |
+
template <>
|
| 568 |
+
struct process_attribute<multiple_inheritance> : process_attribute_default<multiple_inheritance> {
|
| 569 |
+
static void init(const multiple_inheritance &, type_record *r) {
|
| 570 |
+
r->multiple_inheritance = true;
|
| 571 |
+
}
|
| 572 |
+
};
|
| 573 |
+
|
| 574 |
+
template <>
|
| 575 |
+
struct process_attribute<dynamic_attr> : process_attribute_default<dynamic_attr> {
|
| 576 |
+
static void init(const dynamic_attr &, type_record *r) { r->dynamic_attr = true; }
|
| 577 |
+
};
|
| 578 |
+
|
| 579 |
+
template <>
|
| 580 |
+
struct process_attribute<custom_type_setup> {
|
| 581 |
+
static void init(const custom_type_setup &value, type_record *r) {
|
| 582 |
+
r->custom_type_setup_callback = value.value;
|
| 583 |
+
}
|
| 584 |
+
};
|
| 585 |
+
|
| 586 |
+
template <>
|
| 587 |
+
struct process_attribute<is_final> : process_attribute_default<is_final> {
|
| 588 |
+
static void init(const is_final &, type_record *r) { r->is_final = true; }
|
| 589 |
+
};
|
| 590 |
+
|
| 591 |
+
template <>
|
| 592 |
+
struct process_attribute<buffer_protocol> : process_attribute_default<buffer_protocol> {
|
| 593 |
+
static void init(const buffer_protocol &, type_record *r) { r->buffer_protocol = true; }
|
| 594 |
+
};
|
| 595 |
+
|
| 596 |
+
template <>
|
| 597 |
+
struct process_attribute<metaclass> : process_attribute_default<metaclass> {
|
| 598 |
+
static void init(const metaclass &m, type_record *r) { r->metaclass = m.value; }
|
| 599 |
+
};
|
| 600 |
+
|
| 601 |
+
template <>
|
| 602 |
+
struct process_attribute<module_local> : process_attribute_default<module_local> {
|
| 603 |
+
static void init(const module_local &l, type_record *r) { r->module_local = l.value; }
|
| 604 |
+
};
|
| 605 |
+
|
| 606 |
+
/// Process a 'prepend' attribute, putting this at the beginning of the overload chain
|
| 607 |
+
template <>
|
| 608 |
+
struct process_attribute<prepend> : process_attribute_default<prepend> {
|
| 609 |
+
static void init(const prepend &, function_record *r) { r->prepend = true; }
|
| 610 |
+
};
|
| 611 |
+
|
| 612 |
+
/// Process an 'arithmetic' attribute for enums (does nothing here)
|
| 613 |
+
template <>
|
| 614 |
+
struct process_attribute<arithmetic> : process_attribute_default<arithmetic> {};
|
| 615 |
+
|
| 616 |
+
template <typename... Ts>
|
| 617 |
+
struct process_attribute<call_guard<Ts...>> : process_attribute_default<call_guard<Ts...>> {};
|
| 618 |
+
|
| 619 |
+
/**
|
| 620 |
+
* Process a keep_alive call policy -- invokes keep_alive_impl during the
|
| 621 |
+
* pre-call handler if both Nurse, Patient != 0 and use the post-call handler
|
| 622 |
+
* otherwise
|
| 623 |
+
*/
|
| 624 |
+
template <size_t Nurse, size_t Patient>
|
| 625 |
+
struct process_attribute<keep_alive<Nurse, Patient>>
|
| 626 |
+
: public process_attribute_default<keep_alive<Nurse, Patient>> {
|
| 627 |
+
template <size_t N = Nurse, size_t P = Patient, enable_if_t<N != 0 && P != 0, int> = 0>
|
| 628 |
+
static void precall(function_call &call) {
|
| 629 |
+
keep_alive_impl(Nurse, Patient, call, handle());
|
| 630 |
+
}
|
| 631 |
+
template <size_t N = Nurse, size_t P = Patient, enable_if_t<N != 0 && P != 0, int> = 0>
|
| 632 |
+
static void postcall(function_call &, handle) {}
|
| 633 |
+
template <size_t N = Nurse, size_t P = Patient, enable_if_t<N == 0 || P == 0, int> = 0>
|
| 634 |
+
static void precall(function_call &) {}
|
| 635 |
+
template <size_t N = Nurse, size_t P = Patient, enable_if_t<N == 0 || P == 0, int> = 0>
|
| 636 |
+
static void postcall(function_call &call, handle ret) {
|
| 637 |
+
keep_alive_impl(Nurse, Patient, call, ret);
|
| 638 |
+
}
|
| 639 |
+
};
|
| 640 |
+
|
| 641 |
+
/// Recursively iterate over variadic template arguments
|
| 642 |
+
template <typename... Args>
|
| 643 |
+
struct process_attributes {
|
| 644 |
+
static void init(const Args &...args, function_record *r) {
|
| 645 |
+
PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(r);
|
| 646 |
+
PYBIND11_WORKAROUND_INCORRECT_GCC_UNUSED_BUT_SET_PARAMETER(r);
|
| 647 |
+
using expander = int[];
|
| 648 |
+
(void) expander{
|
| 649 |
+
0, ((void) process_attribute<typename std::decay<Args>::type>::init(args, r), 0)...};
|
| 650 |
+
}
|
| 651 |
+
static void init(const Args &...args, type_record *r) {
|
| 652 |
+
PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(r);
|
| 653 |
+
PYBIND11_WORKAROUND_INCORRECT_GCC_UNUSED_BUT_SET_PARAMETER(r);
|
| 654 |
+
using expander = int[];
|
| 655 |
+
(void) expander{0,
|
| 656 |
+
(process_attribute<typename std::decay<Args>::type>::init(args, r), 0)...};
|
| 657 |
+
}
|
| 658 |
+
static void precall(function_call &call) {
|
| 659 |
+
PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(call);
|
| 660 |
+
using expander = int[];
|
| 661 |
+
(void) expander{0,
|
| 662 |
+
(process_attribute<typename std::decay<Args>::type>::precall(call), 0)...};
|
| 663 |
+
}
|
| 664 |
+
static void postcall(function_call &call, handle fn_ret) {
|
| 665 |
+
PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(call, fn_ret);
|
| 666 |
+
PYBIND11_WORKAROUND_INCORRECT_GCC_UNUSED_BUT_SET_PARAMETER(fn_ret);
|
| 667 |
+
using expander = int[];
|
| 668 |
+
(void) expander{
|
| 669 |
+
0, (process_attribute<typename std::decay<Args>::type>::postcall(call, fn_ret), 0)...};
|
| 670 |
+
}
|
| 671 |
+
};
|
| 672 |
+
|
| 673 |
+
template <typename T>
|
| 674 |
+
using is_call_guard = is_instantiation<call_guard, T>;
|
| 675 |
+
|
| 676 |
+
/// Extract the ``type`` from the first `call_guard` in `Extras...` (or `void_type` if none found)
|
| 677 |
+
template <typename... Extra>
|
| 678 |
+
using extract_guard_t = typename exactly_one_t<is_call_guard, call_guard<>, Extra...>::type;
|
| 679 |
+
|
| 680 |
+
/// Check the number of named arguments at compile time
|
| 681 |
+
template <typename... Extra,
|
| 682 |
+
size_t named = constexpr_sum(std::is_base_of<arg, Extra>::value...),
|
| 683 |
+
size_t self = constexpr_sum(std::is_same<is_method, Extra>::value...)>
|
| 684 |
+
constexpr bool expected_num_args(size_t nargs, bool has_args, bool has_kwargs) {
|
| 685 |
+
PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(nargs, has_args, has_kwargs);
|
| 686 |
+
return named == 0 || (self + named + size_t(has_args) + size_t(has_kwargs)) == nargs;
|
| 687 |
+
}
|
| 688 |
+
|
| 689 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 690 |
+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
|
phivenv/Lib/site-packages/torch/include/pybind11/buffer_info.h
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
pybind11/buffer_info.h: Python buffer object interface
|
| 3 |
+
|
| 4 |
+
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
| 5 |
+
|
| 6 |
+
All rights reserved. Use of this source code is governed by a
|
| 7 |
+
BSD-style license that can be found in the LICENSE file.
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include "detail/common.h"
|
| 13 |
+
|
| 14 |
+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
| 15 |
+
|
| 16 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 17 |
+
|
| 18 |
+
// Default, C-style strides
|
| 19 |
+
inline std::vector<ssize_t> c_strides(const std::vector<ssize_t> &shape, ssize_t itemsize) {
|
| 20 |
+
auto ndim = shape.size();
|
| 21 |
+
std::vector<ssize_t> strides(ndim, itemsize);
|
| 22 |
+
if (ndim > 0) {
|
| 23 |
+
for (size_t i = ndim - 1; i > 0; --i) {
|
| 24 |
+
strides[i - 1] = strides[i] * shape[i];
|
| 25 |
+
}
|
| 26 |
+
}
|
| 27 |
+
return strides;
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
// F-style strides; default when constructing an array_t with `ExtraFlags & f_style`
|
| 31 |
+
inline std::vector<ssize_t> f_strides(const std::vector<ssize_t> &shape, ssize_t itemsize) {
|
| 32 |
+
auto ndim = shape.size();
|
| 33 |
+
std::vector<ssize_t> strides(ndim, itemsize);
|
| 34 |
+
for (size_t i = 1; i < ndim; ++i) {
|
| 35 |
+
strides[i] = strides[i - 1] * shape[i - 1];
|
| 36 |
+
}
|
| 37 |
+
return strides;
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
template <typename T, typename SFINAE = void>
|
| 41 |
+
struct compare_buffer_info;
|
| 42 |
+
|
| 43 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 44 |
+
|
| 45 |
+
/// Information record describing a Python buffer object
|
| 46 |
+
struct buffer_info {
|
| 47 |
+
void *ptr = nullptr; // Pointer to the underlying storage
|
| 48 |
+
ssize_t itemsize = 0; // Size of individual items in bytes
|
| 49 |
+
ssize_t size = 0; // Total number of entries
|
| 50 |
+
std::string format; // For homogeneous buffers, this should be set to
|
| 51 |
+
// format_descriptor<T>::format()
|
| 52 |
+
ssize_t ndim = 0; // Number of dimensions
|
| 53 |
+
std::vector<ssize_t> shape; // Shape of the tensor (1 entry per dimension)
|
| 54 |
+
std::vector<ssize_t> strides; // Number of bytes between adjacent entries
|
| 55 |
+
// (for each per dimension)
|
| 56 |
+
bool readonly = false; // flag to indicate if the underlying storage may be written to
|
| 57 |
+
|
| 58 |
+
buffer_info() = default;
|
| 59 |
+
|
| 60 |
+
buffer_info(void *ptr,
|
| 61 |
+
ssize_t itemsize,
|
| 62 |
+
const std::string &format,
|
| 63 |
+
ssize_t ndim,
|
| 64 |
+
detail::any_container<ssize_t> shape_in,
|
| 65 |
+
detail::any_container<ssize_t> strides_in,
|
| 66 |
+
bool readonly = false)
|
| 67 |
+
: ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim),
|
| 68 |
+
shape(std::move(shape_in)), strides(std::move(strides_in)), readonly(readonly) {
|
| 69 |
+
if (ndim != (ssize_t) shape.size() || ndim != (ssize_t) strides.size()) {
|
| 70 |
+
pybind11_fail("buffer_info: ndim doesn't match shape and/or strides length");
|
| 71 |
+
}
|
| 72 |
+
for (size_t i = 0; i < (size_t) ndim; ++i) {
|
| 73 |
+
size *= shape[i];
|
| 74 |
+
}
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
template <typename T>
|
| 78 |
+
buffer_info(T *ptr,
|
| 79 |
+
detail::any_container<ssize_t> shape_in,
|
| 80 |
+
detail::any_container<ssize_t> strides_in,
|
| 81 |
+
bool readonly = false)
|
| 82 |
+
: buffer_info(private_ctr_tag(),
|
| 83 |
+
ptr,
|
| 84 |
+
sizeof(T),
|
| 85 |
+
format_descriptor<T>::format(),
|
| 86 |
+
static_cast<ssize_t>(shape_in->size()),
|
| 87 |
+
std::move(shape_in),
|
| 88 |
+
std::move(strides_in),
|
| 89 |
+
readonly) {}
|
| 90 |
+
|
| 91 |
+
buffer_info(void *ptr,
|
| 92 |
+
ssize_t itemsize,
|
| 93 |
+
const std::string &format,
|
| 94 |
+
ssize_t size,
|
| 95 |
+
bool readonly = false)
|
| 96 |
+
: buffer_info(ptr, itemsize, format, 1, {size}, {itemsize}, readonly) {}
|
| 97 |
+
|
| 98 |
+
template <typename T>
|
| 99 |
+
buffer_info(T *ptr, ssize_t size, bool readonly = false)
|
| 100 |
+
: buffer_info(ptr, sizeof(T), format_descriptor<T>::format(), size, readonly) {}
|
| 101 |
+
|
| 102 |
+
template <typename T>
|
| 103 |
+
buffer_info(const T *ptr, ssize_t size, bool readonly = true)
|
| 104 |
+
: buffer_info(
|
| 105 |
+
const_cast<T *>(ptr), sizeof(T), format_descriptor<T>::format(), size, readonly) {}
|
| 106 |
+
|
| 107 |
+
explicit buffer_info(Py_buffer *view, bool ownview = true)
|
| 108 |
+
: buffer_info(
|
| 109 |
+
view->buf,
|
| 110 |
+
view->itemsize,
|
| 111 |
+
view->format,
|
| 112 |
+
view->ndim,
|
| 113 |
+
{view->shape, view->shape + view->ndim},
|
| 114 |
+
/* Though buffer::request() requests PyBUF_STRIDES, ctypes objects
|
| 115 |
+
* ignore this flag and return a view with NULL strides.
|
| 116 |
+
* When strides are NULL, build them manually. */
|
| 117 |
+
view->strides
|
| 118 |
+
? std::vector<ssize_t>(view->strides, view->strides + view->ndim)
|
| 119 |
+
: detail::c_strides({view->shape, view->shape + view->ndim}, view->itemsize),
|
| 120 |
+
(view->readonly != 0)) {
|
| 121 |
+
// NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
|
| 122 |
+
this->m_view = view;
|
| 123 |
+
// NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
|
| 124 |
+
this->ownview = ownview;
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
buffer_info(const buffer_info &) = delete;
|
| 128 |
+
buffer_info &operator=(const buffer_info &) = delete;
|
| 129 |
+
|
| 130 |
+
buffer_info(buffer_info &&other) noexcept { (*this) = std::move(other); }
|
| 131 |
+
|
| 132 |
+
buffer_info &operator=(buffer_info &&rhs) noexcept {
|
| 133 |
+
ptr = rhs.ptr;
|
| 134 |
+
itemsize = rhs.itemsize;
|
| 135 |
+
size = rhs.size;
|
| 136 |
+
format = std::move(rhs.format);
|
| 137 |
+
ndim = rhs.ndim;
|
| 138 |
+
shape = std::move(rhs.shape);
|
| 139 |
+
strides = std::move(rhs.strides);
|
| 140 |
+
std::swap(m_view, rhs.m_view);
|
| 141 |
+
std::swap(ownview, rhs.ownview);
|
| 142 |
+
readonly = rhs.readonly;
|
| 143 |
+
return *this;
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
~buffer_info() {
|
| 147 |
+
if (m_view && ownview) {
|
| 148 |
+
PyBuffer_Release(m_view);
|
| 149 |
+
delete m_view;
|
| 150 |
+
}
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
Py_buffer *view() const { return m_view; }
|
| 154 |
+
Py_buffer *&view() { return m_view; }
|
| 155 |
+
|
| 156 |
+
/* True if the buffer item type is equivalent to `T`. */
|
| 157 |
+
// To define "equivalent" by example:
|
| 158 |
+
// `buffer_info::item_type_is_equivalent_to<int>(b)` and
|
| 159 |
+
// `buffer_info::item_type_is_equivalent_to<long>(b)` may both be true
|
| 160 |
+
// on some platforms, but `int` and `unsigned` will never be equivalent.
|
| 161 |
+
// For the ground truth, please inspect `detail::compare_buffer_info<>`.
|
| 162 |
+
template <typename T>
|
| 163 |
+
bool item_type_is_equivalent_to() const {
|
| 164 |
+
return detail::compare_buffer_info<T>::compare(*this);
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
private:
|
| 168 |
+
struct private_ctr_tag {};
|
| 169 |
+
|
| 170 |
+
buffer_info(private_ctr_tag,
|
| 171 |
+
void *ptr,
|
| 172 |
+
ssize_t itemsize,
|
| 173 |
+
const std::string &format,
|
| 174 |
+
ssize_t ndim,
|
| 175 |
+
detail::any_container<ssize_t> &&shape_in,
|
| 176 |
+
detail::any_container<ssize_t> &&strides_in,
|
| 177 |
+
bool readonly)
|
| 178 |
+
: buffer_info(
|
| 179 |
+
ptr, itemsize, format, ndim, std::move(shape_in), std::move(strides_in), readonly) {}
|
| 180 |
+
|
| 181 |
+
Py_buffer *m_view = nullptr;
|
| 182 |
+
bool ownview = false;
|
| 183 |
+
};
|
| 184 |
+
|
| 185 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 186 |
+
|
| 187 |
+
template <typename T, typename SFINAE>
|
| 188 |
+
struct compare_buffer_info {
|
| 189 |
+
static bool compare(const buffer_info &b) {
|
| 190 |
+
// NOLINTNEXTLINE(bugprone-sizeof-expression) Needed for `PyObject *`
|
| 191 |
+
return b.format == format_descriptor<T>::format() && b.itemsize == (ssize_t) sizeof(T);
|
| 192 |
+
}
|
| 193 |
+
};
|
| 194 |
+
|
| 195 |
+
template <typename T>
|
| 196 |
+
struct compare_buffer_info<T, detail::enable_if_t<std::is_integral<T>::value>> {
|
| 197 |
+
static bool compare(const buffer_info &b) {
|
| 198 |
+
return (size_t) b.itemsize == sizeof(T)
|
| 199 |
+
&& (b.format == format_descriptor<T>::value
|
| 200 |
+
|| ((sizeof(T) == sizeof(long))
|
| 201 |
+
&& b.format == (std::is_unsigned<T>::value ? "L" : "l"))
|
| 202 |
+
|| ((sizeof(T) == sizeof(size_t))
|
| 203 |
+
&& b.format == (std::is_unsigned<T>::value ? "N" : "n")));
|
| 204 |
+
}
|
| 205 |
+
};
|
| 206 |
+
|
| 207 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 208 |
+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
|
phivenv/Lib/site-packages/torch/include/pybind11/cast.h
ADDED
|
@@ -0,0 +1,1855 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
pybind11/cast.h: Partial template specializations to cast between
|
| 3 |
+
C++ and Python types
|
| 4 |
+
|
| 5 |
+
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
| 6 |
+
|
| 7 |
+
All rights reserved. Use of this source code is governed by a
|
| 8 |
+
BSD-style license that can be found in the LICENSE file.
|
| 9 |
+
*/
|
| 10 |
+
|
| 11 |
+
#pragma once
|
| 12 |
+
|
| 13 |
+
#include "detail/common.h"
|
| 14 |
+
#include "detail/descr.h"
|
| 15 |
+
#include "detail/type_caster_base.h"
|
| 16 |
+
#include "detail/typeid.h"
|
| 17 |
+
#include "pytypes.h"
|
| 18 |
+
|
| 19 |
+
#include <array>
|
| 20 |
+
#include <cstring>
|
| 21 |
+
#include <functional>
|
| 22 |
+
#include <iosfwd>
|
| 23 |
+
#include <iterator>
|
| 24 |
+
#include <memory>
|
| 25 |
+
#include <string>
|
| 26 |
+
#include <tuple>
|
| 27 |
+
#include <type_traits>
|
| 28 |
+
#include <utility>
|
| 29 |
+
#include <vector>
|
| 30 |
+
|
| 31 |
+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
| 32 |
+
|
| 33 |
+
PYBIND11_WARNING_DISABLE_MSVC(4127)
|
| 34 |
+
|
| 35 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 36 |
+
|
| 37 |
+
template <typename type, typename SFINAE = void>
|
| 38 |
+
class type_caster : public type_caster_base<type> {};
|
| 39 |
+
template <typename type>
|
| 40 |
+
using make_caster = type_caster<intrinsic_t<type>>;
|
| 41 |
+
|
| 42 |
+
// Shortcut for calling a caster's `cast_op_type` cast operator for casting a type_caster to a T
|
| 43 |
+
template <typename T>
|
| 44 |
+
typename make_caster<T>::template cast_op_type<T> cast_op(make_caster<T> &caster) {
|
| 45 |
+
using result_t = typename make_caster<T>::template cast_op_type<T>; // See PR #4893
|
| 46 |
+
return caster.operator result_t();
|
| 47 |
+
}
|
| 48 |
+
template <typename T>
|
| 49 |
+
typename make_caster<T>::template cast_op_type<typename std::add_rvalue_reference<T>::type>
|
| 50 |
+
cast_op(make_caster<T> &&caster) {
|
| 51 |
+
using result_t = typename make_caster<T>::template cast_op_type<
|
| 52 |
+
typename std::add_rvalue_reference<T>::type>; // See PR #4893
|
| 53 |
+
return std::move(caster).operator result_t();
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
template <typename type>
|
| 57 |
+
class type_caster<std::reference_wrapper<type>> {
|
| 58 |
+
private:
|
| 59 |
+
using caster_t = make_caster<type>;
|
| 60 |
+
caster_t subcaster;
|
| 61 |
+
using reference_t = type &;
|
| 62 |
+
using subcaster_cast_op_type = typename caster_t::template cast_op_type<reference_t>;
|
| 63 |
+
|
| 64 |
+
static_assert(
|
| 65 |
+
std::is_same<typename std::remove_const<type>::type &, subcaster_cast_op_type>::value
|
| 66 |
+
|| std::is_same<reference_t, subcaster_cast_op_type>::value,
|
| 67 |
+
"std::reference_wrapper<T> caster requires T to have a caster with an "
|
| 68 |
+
"`operator T &()` or `operator const T &()`");
|
| 69 |
+
|
| 70 |
+
public:
|
| 71 |
+
bool load(handle src, bool convert) { return subcaster.load(src, convert); }
|
| 72 |
+
static constexpr auto name = caster_t::name;
|
| 73 |
+
static handle
|
| 74 |
+
cast(const std::reference_wrapper<type> &src, return_value_policy policy, handle parent) {
|
| 75 |
+
// It is definitely wrong to take ownership of this pointer, so mask that rvp
|
| 76 |
+
if (policy == return_value_policy::take_ownership
|
| 77 |
+
|| policy == return_value_policy::automatic) {
|
| 78 |
+
policy = return_value_policy::automatic_reference;
|
| 79 |
+
}
|
| 80 |
+
return caster_t::cast(&src.get(), policy, parent);
|
| 81 |
+
}
|
| 82 |
+
template <typename T>
|
| 83 |
+
using cast_op_type = std::reference_wrapper<type>;
|
| 84 |
+
explicit operator std::reference_wrapper<type>() { return cast_op<type &>(subcaster); }
|
| 85 |
+
};
|
| 86 |
+
|
| 87 |
+
#define PYBIND11_TYPE_CASTER(type, py_name) \
|
| 88 |
+
protected: \
|
| 89 |
+
type value; \
|
| 90 |
+
\
|
| 91 |
+
public: \
|
| 92 |
+
static constexpr auto name = py_name; \
|
| 93 |
+
template <typename T_, \
|
| 94 |
+
::pybind11::detail::enable_if_t< \
|
| 95 |
+
std::is_same<type, ::pybind11::detail::remove_cv_t<T_>>::value, \
|
| 96 |
+
int> \
|
| 97 |
+
= 0> \
|
| 98 |
+
static ::pybind11::handle cast( \
|
| 99 |
+
T_ *src, ::pybind11::return_value_policy policy, ::pybind11::handle parent) { \
|
| 100 |
+
if (!src) \
|
| 101 |
+
return ::pybind11::none().release(); \
|
| 102 |
+
if (policy == ::pybind11::return_value_policy::take_ownership) { \
|
| 103 |
+
auto h = cast(std::move(*src), policy, parent); \
|
| 104 |
+
delete src; \
|
| 105 |
+
return h; \
|
| 106 |
+
} \
|
| 107 |
+
return cast(*src, policy, parent); \
|
| 108 |
+
} \
|
| 109 |
+
operator type *() { return &value; } /* NOLINT(bugprone-macro-parentheses) */ \
|
| 110 |
+
operator type &() { return value; } /* NOLINT(bugprone-macro-parentheses) */ \
|
| 111 |
+
operator type &&() && { return std::move(value); } /* NOLINT(bugprone-macro-parentheses) */ \
|
| 112 |
+
template <typename T_> \
|
| 113 |
+
using cast_op_type = ::pybind11::detail::movable_cast_op_type<T_>
|
| 114 |
+
|
| 115 |
+
template <typename CharT>
|
| 116 |
+
using is_std_char_type = any_of<std::is_same<CharT, char>, /* std::string */
|
| 117 |
+
#if defined(PYBIND11_HAS_U8STRING)
|
| 118 |
+
std::is_same<CharT, char8_t>, /* std::u8string */
|
| 119 |
+
#endif
|
| 120 |
+
std::is_same<CharT, char16_t>, /* std::u16string */
|
| 121 |
+
std::is_same<CharT, char32_t>, /* std::u32string */
|
| 122 |
+
std::is_same<CharT, wchar_t> /* std::wstring */
|
| 123 |
+
>;
|
| 124 |
+
|
| 125 |
+
template <typename T>
|
| 126 |
+
struct type_caster<T, enable_if_t<std::is_arithmetic<T>::value && !is_std_char_type<T>::value>> {
|
| 127 |
+
using _py_type_0 = conditional_t<sizeof(T) <= sizeof(long), long, long long>;
|
| 128 |
+
using _py_type_1 = conditional_t<std::is_signed<T>::value,
|
| 129 |
+
_py_type_0,
|
| 130 |
+
typename std::make_unsigned<_py_type_0>::type>;
|
| 131 |
+
using py_type = conditional_t<std::is_floating_point<T>::value, double, _py_type_1>;
|
| 132 |
+
|
| 133 |
+
public:
|
| 134 |
+
bool load(handle src, bool convert) {
|
| 135 |
+
py_type py_value;
|
| 136 |
+
|
| 137 |
+
if (!src) {
|
| 138 |
+
return false;
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
#if !defined(PYPY_VERSION)
|
| 142 |
+
auto index_check = [](PyObject *o) { return PyIndex_Check(o); };
|
| 143 |
+
#else
|
| 144 |
+
// In PyPy 7.3.3, `PyIndex_Check` is implemented by calling `__index__`,
|
| 145 |
+
// while CPython only considers the existence of `nb_index`/`__index__`.
|
| 146 |
+
auto index_check = [](PyObject *o) { return hasattr(o, "__index__"); };
|
| 147 |
+
#endif
|
| 148 |
+
|
| 149 |
+
if (std::is_floating_point<T>::value) {
|
| 150 |
+
if (convert || PyFloat_Check(src.ptr())) {
|
| 151 |
+
py_value = (py_type) PyFloat_AsDouble(src.ptr());
|
| 152 |
+
} else {
|
| 153 |
+
return false;
|
| 154 |
+
}
|
| 155 |
+
} else if (PyFloat_Check(src.ptr())
|
| 156 |
+
|| (!convert && !PYBIND11_LONG_CHECK(src.ptr()) && !index_check(src.ptr()))) {
|
| 157 |
+
return false;
|
| 158 |
+
} else {
|
| 159 |
+
handle src_or_index = src;
|
| 160 |
+
// PyPy: 7.3.7's 3.8 does not implement PyLong_*'s __index__ calls.
|
| 161 |
+
#if PY_VERSION_HEX < 0x03080000 || defined(PYPY_VERSION)
|
| 162 |
+
object index;
|
| 163 |
+
if (!PYBIND11_LONG_CHECK(src.ptr())) { // So: index_check(src.ptr())
|
| 164 |
+
index = reinterpret_steal<object>(PyNumber_Index(src.ptr()));
|
| 165 |
+
if (!index) {
|
| 166 |
+
PyErr_Clear();
|
| 167 |
+
if (!convert)
|
| 168 |
+
return false;
|
| 169 |
+
} else {
|
| 170 |
+
src_or_index = index;
|
| 171 |
+
}
|
| 172 |
+
}
|
| 173 |
+
#endif
|
| 174 |
+
if (std::is_unsigned<py_type>::value) {
|
| 175 |
+
py_value = as_unsigned<py_type>(src_or_index.ptr());
|
| 176 |
+
} else { // signed integer:
|
| 177 |
+
py_value = sizeof(T) <= sizeof(long)
|
| 178 |
+
? (py_type) PyLong_AsLong(src_or_index.ptr())
|
| 179 |
+
: (py_type) PYBIND11_LONG_AS_LONGLONG(src_or_index.ptr());
|
| 180 |
+
}
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
// Python API reported an error
|
| 184 |
+
bool py_err = py_value == (py_type) -1 && PyErr_Occurred();
|
| 185 |
+
|
| 186 |
+
// Check to see if the conversion is valid (integers should match exactly)
|
| 187 |
+
// Signed/unsigned checks happen elsewhere
|
| 188 |
+
if (py_err
|
| 189 |
+
|| (std::is_integral<T>::value && sizeof(py_type) != sizeof(T)
|
| 190 |
+
&& py_value != (py_type) (T) py_value)) {
|
| 191 |
+
PyErr_Clear();
|
| 192 |
+
if (py_err && convert && (PyNumber_Check(src.ptr()) != 0)) {
|
| 193 |
+
auto tmp = reinterpret_steal<object>(std::is_floating_point<T>::value
|
| 194 |
+
? PyNumber_Float(src.ptr())
|
| 195 |
+
: PyNumber_Long(src.ptr()));
|
| 196 |
+
PyErr_Clear();
|
| 197 |
+
return load(tmp, false);
|
| 198 |
+
}
|
| 199 |
+
return false;
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
value = (T) py_value;
|
| 203 |
+
return true;
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
template <typename U = T>
|
| 207 |
+
static typename std::enable_if<std::is_floating_point<U>::value, handle>::type
|
| 208 |
+
cast(U src, return_value_policy /* policy */, handle /* parent */) {
|
| 209 |
+
return PyFloat_FromDouble((double) src);
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
template <typename U = T>
|
| 213 |
+
static typename std::enable_if<!std::is_floating_point<U>::value && std::is_signed<U>::value
|
| 214 |
+
&& (sizeof(U) <= sizeof(long)),
|
| 215 |
+
handle>::type
|
| 216 |
+
cast(U src, return_value_policy /* policy */, handle /* parent */) {
|
| 217 |
+
return PYBIND11_LONG_FROM_SIGNED((long) src);
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
template <typename U = T>
|
| 221 |
+
static typename std::enable_if<!std::is_floating_point<U>::value && std::is_unsigned<U>::value
|
| 222 |
+
&& (sizeof(U) <= sizeof(unsigned long)),
|
| 223 |
+
handle>::type
|
| 224 |
+
cast(U src, return_value_policy /* policy */, handle /* parent */) {
|
| 225 |
+
return PYBIND11_LONG_FROM_UNSIGNED((unsigned long) src);
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
template <typename U = T>
|
| 229 |
+
static typename std::enable_if<!std::is_floating_point<U>::value && std::is_signed<U>::value
|
| 230 |
+
&& (sizeof(U) > sizeof(long)),
|
| 231 |
+
handle>::type
|
| 232 |
+
cast(U src, return_value_policy /* policy */, handle /* parent */) {
|
| 233 |
+
return PyLong_FromLongLong((long long) src);
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
template <typename U = T>
|
| 237 |
+
static typename std::enable_if<!std::is_floating_point<U>::value && std::is_unsigned<U>::value
|
| 238 |
+
&& (sizeof(U) > sizeof(unsigned long)),
|
| 239 |
+
handle>::type
|
| 240 |
+
cast(U src, return_value_policy /* policy */, handle /* parent */) {
|
| 241 |
+
return PyLong_FromUnsignedLongLong((unsigned long long) src);
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
PYBIND11_TYPE_CASTER(T, const_name<std::is_integral<T>::value>("int", "float"));
|
| 245 |
+
};
|
| 246 |
+
|
| 247 |
+
template <typename T>
|
| 248 |
+
struct void_caster {
|
| 249 |
+
public:
|
| 250 |
+
bool load(handle src, bool) {
|
| 251 |
+
if (src && src.is_none()) {
|
| 252 |
+
return true;
|
| 253 |
+
}
|
| 254 |
+
return false;
|
| 255 |
+
}
|
| 256 |
+
static handle cast(T, return_value_policy /* policy */, handle /* parent */) {
|
| 257 |
+
return none().release();
|
| 258 |
+
}
|
| 259 |
+
PYBIND11_TYPE_CASTER(T, const_name("None"));
|
| 260 |
+
};
|
| 261 |
+
|
| 262 |
+
template <>
|
| 263 |
+
class type_caster<void_type> : public void_caster<void_type> {};
|
| 264 |
+
|
| 265 |
+
template <>
|
| 266 |
+
class type_caster<void> : public type_caster<void_type> {
|
| 267 |
+
public:
|
| 268 |
+
using type_caster<void_type>::cast;
|
| 269 |
+
|
| 270 |
+
bool load(handle h, bool) {
|
| 271 |
+
if (!h) {
|
| 272 |
+
return false;
|
| 273 |
+
}
|
| 274 |
+
if (h.is_none()) {
|
| 275 |
+
value = nullptr;
|
| 276 |
+
return true;
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
/* Check if this is a capsule */
|
| 280 |
+
if (isinstance<capsule>(h)) {
|
| 281 |
+
value = reinterpret_borrow<capsule>(h);
|
| 282 |
+
return true;
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
/* Check if this is a C++ type */
|
| 286 |
+
const auto &bases = all_type_info((PyTypeObject *) type::handle_of(h).ptr());
|
| 287 |
+
if (bases.size() == 1) { // Only allowing loading from a single-value type
|
| 288 |
+
value = values_and_holders(reinterpret_cast<instance *>(h.ptr())).begin()->value_ptr();
|
| 289 |
+
return true;
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
/* Fail */
|
| 293 |
+
return false;
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
static handle cast(const void *ptr, return_value_policy /* policy */, handle /* parent */) {
|
| 297 |
+
if (ptr) {
|
| 298 |
+
return capsule(ptr).release();
|
| 299 |
+
}
|
| 300 |
+
return none().release();
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
template <typename T>
|
| 304 |
+
using cast_op_type = void *&;
|
| 305 |
+
explicit operator void *&() { return value; }
|
| 306 |
+
static constexpr auto name = const_name("capsule");
|
| 307 |
+
|
| 308 |
+
private:
|
| 309 |
+
void *value = nullptr;
|
| 310 |
+
};
|
| 311 |
+
|
| 312 |
+
template <>
|
| 313 |
+
class type_caster<std::nullptr_t> : public void_caster<std::nullptr_t> {};
|
| 314 |
+
|
| 315 |
+
template <>
|
| 316 |
+
class type_caster<bool> {
|
| 317 |
+
public:
|
| 318 |
+
bool load(handle src, bool convert) {
|
| 319 |
+
if (!src) {
|
| 320 |
+
return false;
|
| 321 |
+
}
|
| 322 |
+
if (src.ptr() == Py_True) {
|
| 323 |
+
value = true;
|
| 324 |
+
return true;
|
| 325 |
+
}
|
| 326 |
+
if (src.ptr() == Py_False) {
|
| 327 |
+
value = false;
|
| 328 |
+
return true;
|
| 329 |
+
}
|
| 330 |
+
if (convert || is_numpy_bool(src)) {
|
| 331 |
+
// (allow non-implicit conversion for numpy booleans), use strncmp
|
| 332 |
+
// since NumPy 1.x had an additional trailing underscore.
|
| 333 |
+
|
| 334 |
+
Py_ssize_t res = -1;
|
| 335 |
+
if (src.is_none()) {
|
| 336 |
+
res = 0; // None is implicitly converted to False
|
| 337 |
+
}
|
| 338 |
+
#if defined(PYPY_VERSION)
|
| 339 |
+
// On PyPy, check that "__bool__" attr exists
|
| 340 |
+
else if (hasattr(src, PYBIND11_BOOL_ATTR)) {
|
| 341 |
+
res = PyObject_IsTrue(src.ptr());
|
| 342 |
+
}
|
| 343 |
+
#else
|
| 344 |
+
// Alternate approach for CPython: this does the same as the above, but optimized
|
| 345 |
+
// using the CPython API so as to avoid an unneeded attribute lookup.
|
| 346 |
+
else if (auto *tp_as_number = src.ptr()->ob_type->tp_as_number) {
|
| 347 |
+
if (PYBIND11_NB_BOOL(tp_as_number)) {
|
| 348 |
+
res = (*PYBIND11_NB_BOOL(tp_as_number))(src.ptr());
|
| 349 |
+
}
|
| 350 |
+
}
|
| 351 |
+
#endif
|
| 352 |
+
if (res == 0 || res == 1) {
|
| 353 |
+
value = (res != 0);
|
| 354 |
+
return true;
|
| 355 |
+
}
|
| 356 |
+
PyErr_Clear();
|
| 357 |
+
}
|
| 358 |
+
return false;
|
| 359 |
+
}
|
| 360 |
+
static handle cast(bool src, return_value_policy /* policy */, handle /* parent */) {
|
| 361 |
+
return handle(src ? Py_True : Py_False).inc_ref();
|
| 362 |
+
}
|
| 363 |
+
PYBIND11_TYPE_CASTER(bool, const_name("bool"));
|
| 364 |
+
|
| 365 |
+
private:
|
| 366 |
+
// Test if an object is a NumPy boolean (without fetching the type).
|
| 367 |
+
static inline bool is_numpy_bool(handle object) {
|
| 368 |
+
const char *type_name = Py_TYPE(object.ptr())->tp_name;
|
| 369 |
+
// Name changed to `numpy.bool` in NumPy 2, `numpy.bool_` is needed for 1.x support
|
| 370 |
+
return std::strcmp("numpy.bool", type_name) == 0
|
| 371 |
+
|| std::strcmp("numpy.bool_", type_name) == 0;
|
| 372 |
+
}
|
| 373 |
+
};
|
| 374 |
+
|
| 375 |
+
// Helper class for UTF-{8,16,32} C++ stl strings:
|
| 376 |
+
template <typename StringType, bool IsView = false>
|
| 377 |
+
struct string_caster {
|
| 378 |
+
using CharT = typename StringType::value_type;
|
| 379 |
+
|
| 380 |
+
// Simplify life by being able to assume standard char sizes (the standard only guarantees
|
| 381 |
+
// minimums, but Python requires exact sizes)
|
| 382 |
+
static_assert(!std::is_same<CharT, char>::value || sizeof(CharT) == 1,
|
| 383 |
+
"Unsupported char size != 1");
|
| 384 |
+
#if defined(PYBIND11_HAS_U8STRING)
|
| 385 |
+
static_assert(!std::is_same<CharT, char8_t>::value || sizeof(CharT) == 1,
|
| 386 |
+
"Unsupported char8_t size != 1");
|
| 387 |
+
#endif
|
| 388 |
+
static_assert(!std::is_same<CharT, char16_t>::value || sizeof(CharT) == 2,
|
| 389 |
+
"Unsupported char16_t size != 2");
|
| 390 |
+
static_assert(!std::is_same<CharT, char32_t>::value || sizeof(CharT) == 4,
|
| 391 |
+
"Unsupported char32_t size != 4");
|
| 392 |
+
// wchar_t can be either 16 bits (Windows) or 32 (everywhere else)
|
| 393 |
+
static_assert(!std::is_same<CharT, wchar_t>::value || sizeof(CharT) == 2 || sizeof(CharT) == 4,
|
| 394 |
+
"Unsupported wchar_t size != 2/4");
|
| 395 |
+
static constexpr size_t UTF_N = 8 * sizeof(CharT);
|
| 396 |
+
|
| 397 |
+
bool load(handle src, bool) {
|
| 398 |
+
handle load_src = src;
|
| 399 |
+
if (!src) {
|
| 400 |
+
return false;
|
| 401 |
+
}
|
| 402 |
+
if (!PyUnicode_Check(load_src.ptr())) {
|
| 403 |
+
return load_raw(load_src);
|
| 404 |
+
}
|
| 405 |
+
|
| 406 |
+
// For UTF-8 we avoid the need for a temporary `bytes` object by using
|
| 407 |
+
// `PyUnicode_AsUTF8AndSize`.
|
| 408 |
+
if (UTF_N == 8) {
|
| 409 |
+
Py_ssize_t size = -1;
|
| 410 |
+
const auto *buffer
|
| 411 |
+
= reinterpret_cast<const CharT *>(PyUnicode_AsUTF8AndSize(load_src.ptr(), &size));
|
| 412 |
+
if (!buffer) {
|
| 413 |
+
PyErr_Clear();
|
| 414 |
+
return false;
|
| 415 |
+
}
|
| 416 |
+
value = StringType(buffer, static_cast<size_t>(size));
|
| 417 |
+
return true;
|
| 418 |
+
}
|
| 419 |
+
|
| 420 |
+
auto utfNbytes
|
| 421 |
+
= reinterpret_steal<object>(PyUnicode_AsEncodedString(load_src.ptr(),
|
| 422 |
+
UTF_N == 8 ? "utf-8"
|
| 423 |
+
: UTF_N == 16 ? "utf-16"
|
| 424 |
+
: "utf-32",
|
| 425 |
+
nullptr));
|
| 426 |
+
if (!utfNbytes) {
|
| 427 |
+
PyErr_Clear();
|
| 428 |
+
return false;
|
| 429 |
+
}
|
| 430 |
+
|
| 431 |
+
const auto *buffer
|
| 432 |
+
= reinterpret_cast<const CharT *>(PYBIND11_BYTES_AS_STRING(utfNbytes.ptr()));
|
| 433 |
+
size_t length = (size_t) PYBIND11_BYTES_SIZE(utfNbytes.ptr()) / sizeof(CharT);
|
| 434 |
+
// Skip BOM for UTF-16/32
|
| 435 |
+
if (UTF_N > 8) {
|
| 436 |
+
buffer++;
|
| 437 |
+
length--;
|
| 438 |
+
}
|
| 439 |
+
value = StringType(buffer, length);
|
| 440 |
+
|
| 441 |
+
// If we're loading a string_view we need to keep the encoded Python object alive:
|
| 442 |
+
if (IsView) {
|
| 443 |
+
loader_life_support::add_patient(utfNbytes);
|
| 444 |
+
}
|
| 445 |
+
|
| 446 |
+
return true;
|
| 447 |
+
}
|
| 448 |
+
|
| 449 |
+
static handle
|
| 450 |
+
cast(const StringType &src, return_value_policy /* policy */, handle /* parent */) {
|
| 451 |
+
const char *buffer = reinterpret_cast<const char *>(src.data());
|
| 452 |
+
auto nbytes = ssize_t(src.size() * sizeof(CharT));
|
| 453 |
+
handle s = decode_utfN(buffer, nbytes);
|
| 454 |
+
if (!s) {
|
| 455 |
+
throw error_already_set();
|
| 456 |
+
}
|
| 457 |
+
return s;
|
| 458 |
+
}
|
| 459 |
+
|
| 460 |
+
PYBIND11_TYPE_CASTER(StringType, const_name(PYBIND11_STRING_NAME));
|
| 461 |
+
|
| 462 |
+
private:
|
| 463 |
+
static handle decode_utfN(const char *buffer, ssize_t nbytes) {
|
| 464 |
+
#if !defined(PYPY_VERSION)
|
| 465 |
+
return UTF_N == 8 ? PyUnicode_DecodeUTF8(buffer, nbytes, nullptr)
|
| 466 |
+
: UTF_N == 16 ? PyUnicode_DecodeUTF16(buffer, nbytes, nullptr, nullptr)
|
| 467 |
+
: PyUnicode_DecodeUTF32(buffer, nbytes, nullptr, nullptr);
|
| 468 |
+
#else
|
| 469 |
+
// PyPy segfaults when on PyUnicode_DecodeUTF16 (and possibly on PyUnicode_DecodeUTF32 as
|
| 470 |
+
// well), so bypass the whole thing by just passing the encoding as a string value, which
|
| 471 |
+
// works properly:
|
| 472 |
+
return PyUnicode_Decode(buffer,
|
| 473 |
+
nbytes,
|
| 474 |
+
UTF_N == 8 ? "utf-8"
|
| 475 |
+
: UTF_N == 16 ? "utf-16"
|
| 476 |
+
: "utf-32",
|
| 477 |
+
nullptr);
|
| 478 |
+
#endif
|
| 479 |
+
}
|
| 480 |
+
|
| 481 |
+
// When loading into a std::string or char*, accept a bytes/bytearray object as-is (i.e.
|
| 482 |
+
// without any encoding/decoding attempt). For other C++ char sizes this is a no-op.
|
| 483 |
+
// which supports loading a unicode from a str, doesn't take this path.
|
| 484 |
+
template <typename C = CharT>
|
| 485 |
+
bool load_raw(enable_if_t<std::is_same<C, char>::value, handle> src) {
|
| 486 |
+
if (PYBIND11_BYTES_CHECK(src.ptr())) {
|
| 487 |
+
// We were passed raw bytes; accept it into a std::string or char*
|
| 488 |
+
// without any encoding attempt.
|
| 489 |
+
const char *bytes = PYBIND11_BYTES_AS_STRING(src.ptr());
|
| 490 |
+
if (!bytes) {
|
| 491 |
+
pybind11_fail("Unexpected PYBIND11_BYTES_AS_STRING() failure.");
|
| 492 |
+
}
|
| 493 |
+
value = StringType(bytes, (size_t) PYBIND11_BYTES_SIZE(src.ptr()));
|
| 494 |
+
return true;
|
| 495 |
+
}
|
| 496 |
+
if (PyByteArray_Check(src.ptr())) {
|
| 497 |
+
// We were passed a bytearray; accept it into a std::string or char*
|
| 498 |
+
// without any encoding attempt.
|
| 499 |
+
const char *bytearray = PyByteArray_AsString(src.ptr());
|
| 500 |
+
if (!bytearray) {
|
| 501 |
+
pybind11_fail("Unexpected PyByteArray_AsString() failure.");
|
| 502 |
+
}
|
| 503 |
+
value = StringType(bytearray, (size_t) PyByteArray_Size(src.ptr()));
|
| 504 |
+
return true;
|
| 505 |
+
}
|
| 506 |
+
|
| 507 |
+
return false;
|
| 508 |
+
}
|
| 509 |
+
|
| 510 |
+
template <typename C = CharT>
|
| 511 |
+
bool load_raw(enable_if_t<!std::is_same<C, char>::value, handle>) {
|
| 512 |
+
return false;
|
| 513 |
+
}
|
| 514 |
+
};
|
| 515 |
+
|
| 516 |
+
template <typename CharT, class Traits, class Allocator>
|
| 517 |
+
struct type_caster<std::basic_string<CharT, Traits, Allocator>,
|
| 518 |
+
enable_if_t<is_std_char_type<CharT>::value>>
|
| 519 |
+
: string_caster<std::basic_string<CharT, Traits, Allocator>> {};
|
| 520 |
+
|
| 521 |
+
#ifdef PYBIND11_HAS_STRING_VIEW
|
| 522 |
+
template <typename CharT, class Traits>
|
| 523 |
+
struct type_caster<std::basic_string_view<CharT, Traits>,
|
| 524 |
+
enable_if_t<is_std_char_type<CharT>::value>>
|
| 525 |
+
: string_caster<std::basic_string_view<CharT, Traits>, true> {};
|
| 526 |
+
#endif
|
| 527 |
+
|
| 528 |
+
// Type caster for C-style strings. We basically use a std::string type caster, but also add the
|
| 529 |
+
// ability to use None as a nullptr char* (which the string caster doesn't allow).
|
| 530 |
+
template <typename CharT>
|
| 531 |
+
struct type_caster<CharT, enable_if_t<is_std_char_type<CharT>::value>> {
|
| 532 |
+
using StringType = std::basic_string<CharT>;
|
| 533 |
+
using StringCaster = make_caster<StringType>;
|
| 534 |
+
StringCaster str_caster;
|
| 535 |
+
bool none = false;
|
| 536 |
+
CharT one_char = 0;
|
| 537 |
+
|
| 538 |
+
public:
|
| 539 |
+
bool load(handle src, bool convert) {
|
| 540 |
+
if (!src) {
|
| 541 |
+
return false;
|
| 542 |
+
}
|
| 543 |
+
if (src.is_none()) {
|
| 544 |
+
// Defer accepting None to other overloads (if we aren't in convert mode):
|
| 545 |
+
if (!convert) {
|
| 546 |
+
return false;
|
| 547 |
+
}
|
| 548 |
+
none = true;
|
| 549 |
+
return true;
|
| 550 |
+
}
|
| 551 |
+
return str_caster.load(src, convert);
|
| 552 |
+
}
|
| 553 |
+
|
| 554 |
+
static handle cast(const CharT *src, return_value_policy policy, handle parent) {
|
| 555 |
+
if (src == nullptr) {
|
| 556 |
+
return pybind11::none().release();
|
| 557 |
+
}
|
| 558 |
+
return StringCaster::cast(StringType(src), policy, parent);
|
| 559 |
+
}
|
| 560 |
+
|
| 561 |
+
static handle cast(CharT src, return_value_policy policy, handle parent) {
|
| 562 |
+
if (std::is_same<char, CharT>::value) {
|
| 563 |
+
handle s = PyUnicode_DecodeLatin1((const char *) &src, 1, nullptr);
|
| 564 |
+
if (!s) {
|
| 565 |
+
throw error_already_set();
|
| 566 |
+
}
|
| 567 |
+
return s;
|
| 568 |
+
}
|
| 569 |
+
return StringCaster::cast(StringType(1, src), policy, parent);
|
| 570 |
+
}
|
| 571 |
+
|
| 572 |
+
explicit operator CharT *() {
|
| 573 |
+
return none ? nullptr : const_cast<CharT *>(static_cast<StringType &>(str_caster).c_str());
|
| 574 |
+
}
|
| 575 |
+
explicit operator CharT &() {
|
| 576 |
+
if (none) {
|
| 577 |
+
throw value_error("Cannot convert None to a character");
|
| 578 |
+
}
|
| 579 |
+
|
| 580 |
+
auto &value = static_cast<StringType &>(str_caster);
|
| 581 |
+
size_t str_len = value.size();
|
| 582 |
+
if (str_len == 0) {
|
| 583 |
+
throw value_error("Cannot convert empty string to a character");
|
| 584 |
+
}
|
| 585 |
+
|
| 586 |
+
// If we're in UTF-8 mode, we have two possible failures: one for a unicode character that
|
| 587 |
+
// is too high, and one for multiple unicode characters (caught later), so we need to
|
| 588 |
+
// figure out how long the first encoded character is in bytes to distinguish between these
|
| 589 |
+
// two errors. We also allow want to allow unicode characters U+0080 through U+00FF, as
|
| 590 |
+
// those can fit into a single char value.
|
| 591 |
+
if (StringCaster::UTF_N == 8 && str_len > 1 && str_len <= 4) {
|
| 592 |
+
auto v0 = static_cast<unsigned char>(value[0]);
|
| 593 |
+
// low bits only: 0-127
|
| 594 |
+
// 0b110xxxxx - start of 2-byte sequence
|
| 595 |
+
// 0b1110xxxx - start of 3-byte sequence
|
| 596 |
+
// 0b11110xxx - start of 4-byte sequence
|
| 597 |
+
size_t char0_bytes = (v0 & 0x80) == 0 ? 1
|
| 598 |
+
: (v0 & 0xE0) == 0xC0 ? 2
|
| 599 |
+
: (v0 & 0xF0) == 0xE0 ? 3
|
| 600 |
+
: 4;
|
| 601 |
+
|
| 602 |
+
if (char0_bytes == str_len) {
|
| 603 |
+
// If we have a 128-255 value, we can decode it into a single char:
|
| 604 |
+
if (char0_bytes == 2 && (v0 & 0xFC) == 0xC0) { // 0x110000xx 0x10xxxxxx
|
| 605 |
+
one_char = static_cast<CharT>(((v0 & 3) << 6)
|
| 606 |
+
+ (static_cast<unsigned char>(value[1]) & 0x3F));
|
| 607 |
+
return one_char;
|
| 608 |
+
}
|
| 609 |
+
// Otherwise we have a single character, but it's > U+00FF
|
| 610 |
+
throw value_error("Character code point not in range(0x100)");
|
| 611 |
+
}
|
| 612 |
+
}
|
| 613 |
+
|
| 614 |
+
// UTF-16 is much easier: we can only have a surrogate pair for values above U+FFFF, thus a
|
| 615 |
+
// surrogate pair with total length 2 instantly indicates a range error (but not a "your
|
| 616 |
+
// string was too long" error).
|
| 617 |
+
else if (StringCaster::UTF_N == 16 && str_len == 2) {
|
| 618 |
+
one_char = static_cast<CharT>(value[0]);
|
| 619 |
+
if (one_char >= 0xD800 && one_char < 0xE000) {
|
| 620 |
+
throw value_error("Character code point not in range(0x10000)");
|
| 621 |
+
}
|
| 622 |
+
}
|
| 623 |
+
|
| 624 |
+
if (str_len != 1) {
|
| 625 |
+
throw value_error("Expected a character, but multi-character string found");
|
| 626 |
+
}
|
| 627 |
+
|
| 628 |
+
one_char = value[0];
|
| 629 |
+
return one_char;
|
| 630 |
+
}
|
| 631 |
+
|
| 632 |
+
static constexpr auto name = const_name(PYBIND11_STRING_NAME);
|
| 633 |
+
template <typename _T>
|
| 634 |
+
using cast_op_type = pybind11::detail::cast_op_type<_T>;
|
| 635 |
+
};
|
| 636 |
+
|
| 637 |
+
// Base implementation for std::tuple and std::pair
|
| 638 |
+
template <template <typename...> class Tuple, typename... Ts>
|
| 639 |
+
class tuple_caster {
|
| 640 |
+
using type = Tuple<Ts...>;
|
| 641 |
+
static constexpr auto size = sizeof...(Ts);
|
| 642 |
+
using indices = make_index_sequence<size>;
|
| 643 |
+
|
| 644 |
+
public:
|
| 645 |
+
bool load(handle src, bool convert) {
|
| 646 |
+
if (!isinstance<sequence>(src)) {
|
| 647 |
+
return false;
|
| 648 |
+
}
|
| 649 |
+
const auto seq = reinterpret_borrow<sequence>(src);
|
| 650 |
+
if (seq.size() != size) {
|
| 651 |
+
return false;
|
| 652 |
+
}
|
| 653 |
+
return load_impl(seq, convert, indices{});
|
| 654 |
+
}
|
| 655 |
+
|
| 656 |
+
template <typename T>
|
| 657 |
+
static handle cast(T &&src, return_value_policy policy, handle parent) {
|
| 658 |
+
return cast_impl(std::forward<T>(src), policy, parent, indices{});
|
| 659 |
+
}
|
| 660 |
+
|
| 661 |
+
// copied from the PYBIND11_TYPE_CASTER macro
|
| 662 |
+
template <typename T>
|
| 663 |
+
static handle cast(T *src, return_value_policy policy, handle parent) {
|
| 664 |
+
if (!src) {
|
| 665 |
+
return none().release();
|
| 666 |
+
}
|
| 667 |
+
if (policy == return_value_policy::take_ownership) {
|
| 668 |
+
auto h = cast(std::move(*src), policy, parent);
|
| 669 |
+
delete src;
|
| 670 |
+
return h;
|
| 671 |
+
}
|
| 672 |
+
return cast(*src, policy, parent);
|
| 673 |
+
}
|
| 674 |
+
|
| 675 |
+
static constexpr auto name = const_name("tuple[")
|
| 676 |
+
+ ::pybind11::detail::concat(make_caster<Ts>::name...)
|
| 677 |
+
+ const_name("]");
|
| 678 |
+
|
| 679 |
+
template <typename T>
|
| 680 |
+
using cast_op_type = type;
|
| 681 |
+
|
| 682 |
+
explicit operator type() & { return implicit_cast(indices{}); }
|
| 683 |
+
explicit operator type() && { return std::move(*this).implicit_cast(indices{}); }
|
| 684 |
+
|
| 685 |
+
protected:
|
| 686 |
+
template <size_t... Is>
|
| 687 |
+
type implicit_cast(index_sequence<Is...>) & {
|
| 688 |
+
return type(cast_op<Ts>(std::get<Is>(subcasters))...);
|
| 689 |
+
}
|
| 690 |
+
template <size_t... Is>
|
| 691 |
+
type implicit_cast(index_sequence<Is...>) && {
|
| 692 |
+
return type(cast_op<Ts>(std::move(std::get<Is>(subcasters)))...);
|
| 693 |
+
}
|
| 694 |
+
|
| 695 |
+
static constexpr bool load_impl(const sequence &, bool, index_sequence<>) { return true; }
|
| 696 |
+
|
| 697 |
+
template <size_t... Is>
|
| 698 |
+
bool load_impl(const sequence &seq, bool convert, index_sequence<Is...>) {
|
| 699 |
+
#ifdef __cpp_fold_expressions
|
| 700 |
+
if ((... || !std::get<Is>(subcasters).load(seq[Is], convert))) {
|
| 701 |
+
return false;
|
| 702 |
+
}
|
| 703 |
+
#else
|
| 704 |
+
for (bool r : {std::get<Is>(subcasters).load(seq[Is], convert)...}) {
|
| 705 |
+
if (!r) {
|
| 706 |
+
return false;
|
| 707 |
+
}
|
| 708 |
+
}
|
| 709 |
+
#endif
|
| 710 |
+
return true;
|
| 711 |
+
}
|
| 712 |
+
|
| 713 |
+
/* Implementation: Convert a C++ tuple into a Python tuple */
|
| 714 |
+
template <typename T, size_t... Is>
|
| 715 |
+
static handle
|
| 716 |
+
cast_impl(T &&src, return_value_policy policy, handle parent, index_sequence<Is...>) {
|
| 717 |
+
PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(src, policy, parent);
|
| 718 |
+
PYBIND11_WORKAROUND_INCORRECT_GCC_UNUSED_BUT_SET_PARAMETER(policy, parent);
|
| 719 |
+
std::array<object, size> entries{{reinterpret_steal<object>(
|
| 720 |
+
make_caster<Ts>::cast(std::get<Is>(std::forward<T>(src)), policy, parent))...}};
|
| 721 |
+
for (const auto &entry : entries) {
|
| 722 |
+
if (!entry) {
|
| 723 |
+
return handle();
|
| 724 |
+
}
|
| 725 |
+
}
|
| 726 |
+
tuple result(size);
|
| 727 |
+
int counter = 0;
|
| 728 |
+
for (auto &entry : entries) {
|
| 729 |
+
PyTuple_SET_ITEM(result.ptr(), counter++, entry.release().ptr());
|
| 730 |
+
}
|
| 731 |
+
return result.release();
|
| 732 |
+
}
|
| 733 |
+
|
| 734 |
+
Tuple<make_caster<Ts>...> subcasters;
|
| 735 |
+
};
|
| 736 |
+
|
| 737 |
+
template <typename T1, typename T2>
|
| 738 |
+
class type_caster<std::pair<T1, T2>> : public tuple_caster<std::pair, T1, T2> {};
|
| 739 |
+
|
| 740 |
+
template <typename... Ts>
|
| 741 |
+
class type_caster<std::tuple<Ts...>> : public tuple_caster<std::tuple, Ts...> {};
|
| 742 |
+
|
| 743 |
+
template <>
|
| 744 |
+
class type_caster<std::tuple<>> : public tuple_caster<std::tuple> {
|
| 745 |
+
public:
|
| 746 |
+
// PEP 484 specifies this syntax for an empty tuple
|
| 747 |
+
static constexpr auto name = const_name("tuple[()]");
|
| 748 |
+
};
|
| 749 |
+
|
| 750 |
+
/// Helper class which abstracts away certain actions. Users can provide specializations for
|
| 751 |
+
/// custom holders, but it's only necessary if the type has a non-standard interface.
|
| 752 |
+
template <typename T>
|
| 753 |
+
struct holder_helper {
|
| 754 |
+
static auto get(const T &p) -> decltype(p.get()) { return p.get(); }
|
| 755 |
+
};
|
| 756 |
+
|
| 757 |
+
/// Type caster for holder types like std::shared_ptr, etc.
|
| 758 |
+
/// The SFINAE hook is provided to help work around the current lack of support
|
| 759 |
+
/// for smart-pointer interoperability. Please consider it an implementation
|
| 760 |
+
/// detail that may change in the future, as formal support for smart-pointer
|
| 761 |
+
/// interoperability is added into pybind11.
|
| 762 |
+
template <typename type, typename holder_type, typename SFINAE = void>
|
| 763 |
+
struct copyable_holder_caster : public type_caster_base<type> {
|
| 764 |
+
public:
|
| 765 |
+
using base = type_caster_base<type>;
|
| 766 |
+
static_assert(std::is_base_of<base, type_caster<type>>::value,
|
| 767 |
+
"Holder classes are only supported for custom types");
|
| 768 |
+
using base::base;
|
| 769 |
+
using base::cast;
|
| 770 |
+
using base::typeinfo;
|
| 771 |
+
using base::value;
|
| 772 |
+
|
| 773 |
+
bool load(handle src, bool convert) {
|
| 774 |
+
return base::template load_impl<copyable_holder_caster<type, holder_type>>(src, convert);
|
| 775 |
+
}
|
| 776 |
+
|
| 777 |
+
explicit operator type *() { return this->value; }
|
| 778 |
+
// static_cast works around compiler error with MSVC 17 and CUDA 10.2
|
| 779 |
+
// see issue #2180
|
| 780 |
+
explicit operator type &() { return *(static_cast<type *>(this->value)); }
|
| 781 |
+
explicit operator holder_type *() { return std::addressof(holder); }
|
| 782 |
+
explicit operator holder_type &() { return holder; }
|
| 783 |
+
|
| 784 |
+
static handle cast(const holder_type &src, return_value_policy, handle) {
|
| 785 |
+
const auto *ptr = holder_helper<holder_type>::get(src);
|
| 786 |
+
return type_caster_base<type>::cast_holder(ptr, &src);
|
| 787 |
+
}
|
| 788 |
+
|
| 789 |
+
protected:
|
| 790 |
+
friend class type_caster_generic;
|
| 791 |
+
void check_holder_compat() {
|
| 792 |
+
if (typeinfo->default_holder) {
|
| 793 |
+
throw cast_error("Unable to load a custom holder type from a default-holder instance");
|
| 794 |
+
}
|
| 795 |
+
}
|
| 796 |
+
|
| 797 |
+
void load_value(value_and_holder &&v_h) {
|
| 798 |
+
if (v_h.holder_constructed()) {
|
| 799 |
+
value = v_h.value_ptr();
|
| 800 |
+
holder = v_h.template holder<holder_type>();
|
| 801 |
+
return;
|
| 802 |
+
}
|
| 803 |
+
throw cast_error("Unable to cast from non-held to held instance (T& to Holder<T>) "
|
| 804 |
+
#if !defined(PYBIND11_DETAILED_ERROR_MESSAGES)
|
| 805 |
+
"(#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for "
|
| 806 |
+
"type information)");
|
| 807 |
+
#else
|
| 808 |
+
"of type '"
|
| 809 |
+
+ type_id<holder_type>() + "''");
|
| 810 |
+
#endif
|
| 811 |
+
}
|
| 812 |
+
|
| 813 |
+
template <typename T = holder_type,
|
| 814 |
+
detail::enable_if_t<!std::is_constructible<T, const T &, type *>::value, int> = 0>
|
| 815 |
+
bool try_implicit_casts(handle, bool) {
|
| 816 |
+
return false;
|
| 817 |
+
}
|
| 818 |
+
|
| 819 |
+
template <typename T = holder_type,
|
| 820 |
+
detail::enable_if_t<std::is_constructible<T, const T &, type *>::value, int> = 0>
|
| 821 |
+
bool try_implicit_casts(handle src, bool convert) {
|
| 822 |
+
for (auto &cast : typeinfo->implicit_casts) {
|
| 823 |
+
copyable_holder_caster sub_caster(*cast.first);
|
| 824 |
+
if (sub_caster.load(src, convert)) {
|
| 825 |
+
value = cast.second(sub_caster.value);
|
| 826 |
+
holder = holder_type(sub_caster.holder, (type *) value);
|
| 827 |
+
return true;
|
| 828 |
+
}
|
| 829 |
+
}
|
| 830 |
+
return false;
|
| 831 |
+
}
|
| 832 |
+
|
| 833 |
+
static bool try_direct_conversions(handle) { return false; }
|
| 834 |
+
|
| 835 |
+
holder_type holder;
|
| 836 |
+
};
|
| 837 |
+
|
| 838 |
+
/// Specialize for the common std::shared_ptr, so users don't need to
|
| 839 |
+
template <typename T>
|
| 840 |
+
class type_caster<std::shared_ptr<T>> : public copyable_holder_caster<T, std::shared_ptr<T>> {};
|
| 841 |
+
|
| 842 |
+
/// Type caster for holder types like std::unique_ptr.
|
| 843 |
+
/// Please consider the SFINAE hook an implementation detail, as explained
|
| 844 |
+
/// in the comment for the copyable_holder_caster.
|
| 845 |
+
template <typename type, typename holder_type, typename SFINAE = void>
|
| 846 |
+
struct move_only_holder_caster {
|
| 847 |
+
static_assert(std::is_base_of<type_caster_base<type>, type_caster<type>>::value,
|
| 848 |
+
"Holder classes are only supported for custom types");
|
| 849 |
+
|
| 850 |
+
static handle cast(holder_type &&src, return_value_policy, handle) {
|
| 851 |
+
auto *ptr = holder_helper<holder_type>::get(src);
|
| 852 |
+
return type_caster_base<type>::cast_holder(ptr, std::addressof(src));
|
| 853 |
+
}
|
| 854 |
+
static constexpr auto name = type_caster_base<type>::name;
|
| 855 |
+
};
|
| 856 |
+
|
| 857 |
+
template <typename type, typename deleter>
|
| 858 |
+
class type_caster<std::unique_ptr<type, deleter>>
|
| 859 |
+
: public move_only_holder_caster<type, std::unique_ptr<type, deleter>> {};
|
| 860 |
+
|
| 861 |
+
template <typename type, typename holder_type>
|
| 862 |
+
using type_caster_holder = conditional_t<is_copy_constructible<holder_type>::value,
|
| 863 |
+
copyable_holder_caster<type, holder_type>,
|
| 864 |
+
move_only_holder_caster<type, holder_type>>;
|
| 865 |
+
|
| 866 |
+
template <typename T, bool Value = false>
|
| 867 |
+
struct always_construct_holder {
|
| 868 |
+
static constexpr bool value = Value;
|
| 869 |
+
};
|
| 870 |
+
|
| 871 |
+
/// Create a specialization for custom holder types (silently ignores std::shared_ptr)
|
| 872 |
+
#define PYBIND11_DECLARE_HOLDER_TYPE(type, holder_type, ...) \
|
| 873 |
+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) \
|
| 874 |
+
namespace detail { \
|
| 875 |
+
template <typename type> \
|
| 876 |
+
struct always_construct_holder<holder_type> : always_construct_holder<void, ##__VA_ARGS__> { \
|
| 877 |
+
}; \
|
| 878 |
+
template <typename type> \
|
| 879 |
+
class type_caster<holder_type, enable_if_t<!is_shared_ptr<holder_type>::value>> \
|
| 880 |
+
: public type_caster_holder<type, holder_type> {}; \
|
| 881 |
+
} \
|
| 882 |
+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
|
| 883 |
+
|
| 884 |
+
// PYBIND11_DECLARE_HOLDER_TYPE holder types:
|
| 885 |
+
template <typename base, typename holder>
|
| 886 |
+
struct is_holder_type
|
| 887 |
+
: std::is_base_of<detail::type_caster_holder<base, holder>, detail::type_caster<holder>> {};
|
| 888 |
+
// Specialization for always-supported unique_ptr holders:
|
| 889 |
+
template <typename base, typename deleter>
|
| 890 |
+
struct is_holder_type<base, std::unique_ptr<base, deleter>> : std::true_type {};
|
| 891 |
+
|
| 892 |
+
#ifdef PYBIND11_DISABLE_HANDLE_TYPE_NAME_DEFAULT_IMPLEMENTATION // See PR #4888
|
| 893 |
+
|
| 894 |
+
// This leads to compilation errors if a specialization is missing.
|
| 895 |
+
template <typename T>
|
| 896 |
+
struct handle_type_name;
|
| 897 |
+
|
| 898 |
+
#else
|
| 899 |
+
|
| 900 |
+
template <typename T>
|
| 901 |
+
struct handle_type_name {
|
| 902 |
+
static constexpr auto name = const_name<T>();
|
| 903 |
+
};
|
| 904 |
+
|
| 905 |
+
#endif
|
| 906 |
+
|
| 907 |
+
template <>
|
| 908 |
+
struct handle_type_name<object> {
|
| 909 |
+
static constexpr auto name = const_name("object");
|
| 910 |
+
};
|
| 911 |
+
template <>
|
| 912 |
+
struct handle_type_name<list> {
|
| 913 |
+
static constexpr auto name = const_name("list");
|
| 914 |
+
};
|
| 915 |
+
template <>
|
| 916 |
+
struct handle_type_name<dict> {
|
| 917 |
+
static constexpr auto name = const_name("dict");
|
| 918 |
+
};
|
| 919 |
+
template <>
|
| 920 |
+
struct handle_type_name<anyset> {
|
| 921 |
+
static constexpr auto name = const_name("Union[set, frozenset]");
|
| 922 |
+
};
|
| 923 |
+
template <>
|
| 924 |
+
struct handle_type_name<set> {
|
| 925 |
+
static constexpr auto name = const_name("set");
|
| 926 |
+
};
|
| 927 |
+
template <>
|
| 928 |
+
struct handle_type_name<frozenset> {
|
| 929 |
+
static constexpr auto name = const_name("frozenset");
|
| 930 |
+
};
|
| 931 |
+
template <>
|
| 932 |
+
struct handle_type_name<str> {
|
| 933 |
+
static constexpr auto name = const_name("str");
|
| 934 |
+
};
|
| 935 |
+
template <>
|
| 936 |
+
struct handle_type_name<tuple> {
|
| 937 |
+
static constexpr auto name = const_name("tuple");
|
| 938 |
+
};
|
| 939 |
+
template <>
|
| 940 |
+
struct handle_type_name<bool_> {
|
| 941 |
+
static constexpr auto name = const_name("bool");
|
| 942 |
+
};
|
| 943 |
+
template <>
|
| 944 |
+
struct handle_type_name<bytes> {
|
| 945 |
+
static constexpr auto name = const_name(PYBIND11_BYTES_NAME);
|
| 946 |
+
};
|
| 947 |
+
template <>
|
| 948 |
+
struct handle_type_name<buffer> {
|
| 949 |
+
static constexpr auto name = const_name("Buffer");
|
| 950 |
+
};
|
| 951 |
+
template <>
|
| 952 |
+
struct handle_type_name<int_> {
|
| 953 |
+
static constexpr auto name = const_name("int");
|
| 954 |
+
};
|
| 955 |
+
template <>
|
| 956 |
+
struct handle_type_name<iterable> {
|
| 957 |
+
static constexpr auto name = const_name("Iterable");
|
| 958 |
+
};
|
| 959 |
+
template <>
|
| 960 |
+
struct handle_type_name<iterator> {
|
| 961 |
+
static constexpr auto name = const_name("Iterator");
|
| 962 |
+
};
|
| 963 |
+
template <>
|
| 964 |
+
struct handle_type_name<float_> {
|
| 965 |
+
static constexpr auto name = const_name("float");
|
| 966 |
+
};
|
| 967 |
+
template <>
|
| 968 |
+
struct handle_type_name<function> {
|
| 969 |
+
static constexpr auto name = const_name("Callable");
|
| 970 |
+
};
|
| 971 |
+
template <>
|
| 972 |
+
struct handle_type_name<handle> {
|
| 973 |
+
static constexpr auto name = handle_type_name<object>::name;
|
| 974 |
+
};
|
| 975 |
+
template <>
|
| 976 |
+
struct handle_type_name<none> {
|
| 977 |
+
static constexpr auto name = const_name("None");
|
| 978 |
+
};
|
| 979 |
+
template <>
|
| 980 |
+
struct handle_type_name<sequence> {
|
| 981 |
+
static constexpr auto name = const_name("Sequence");
|
| 982 |
+
};
|
| 983 |
+
template <>
|
| 984 |
+
struct handle_type_name<bytearray> {
|
| 985 |
+
static constexpr auto name = const_name("bytearray");
|
| 986 |
+
};
|
| 987 |
+
template <>
|
| 988 |
+
struct handle_type_name<memoryview> {
|
| 989 |
+
static constexpr auto name = const_name("memoryview");
|
| 990 |
+
};
|
| 991 |
+
template <>
|
| 992 |
+
struct handle_type_name<slice> {
|
| 993 |
+
static constexpr auto name = const_name("slice");
|
| 994 |
+
};
|
| 995 |
+
template <>
|
| 996 |
+
struct handle_type_name<type> {
|
| 997 |
+
static constexpr auto name = const_name("type");
|
| 998 |
+
};
|
| 999 |
+
template <>
|
| 1000 |
+
struct handle_type_name<capsule> {
|
| 1001 |
+
static constexpr auto name = const_name("capsule");
|
| 1002 |
+
};
|
| 1003 |
+
template <>
|
| 1004 |
+
struct handle_type_name<ellipsis> {
|
| 1005 |
+
static constexpr auto name = const_name("ellipsis");
|
| 1006 |
+
};
|
| 1007 |
+
template <>
|
| 1008 |
+
struct handle_type_name<weakref> {
|
| 1009 |
+
static constexpr auto name = const_name("weakref");
|
| 1010 |
+
};
|
| 1011 |
+
template <>
|
| 1012 |
+
struct handle_type_name<args> {
|
| 1013 |
+
static constexpr auto name = const_name("*args");
|
| 1014 |
+
};
|
| 1015 |
+
template <>
|
| 1016 |
+
struct handle_type_name<kwargs> {
|
| 1017 |
+
static constexpr auto name = const_name("**kwargs");
|
| 1018 |
+
};
|
| 1019 |
+
template <>
|
| 1020 |
+
struct handle_type_name<obj_attr_accessor> {
|
| 1021 |
+
static constexpr auto name = const_name<obj_attr_accessor>();
|
| 1022 |
+
};
|
| 1023 |
+
template <>
|
| 1024 |
+
struct handle_type_name<str_attr_accessor> {
|
| 1025 |
+
static constexpr auto name = const_name<str_attr_accessor>();
|
| 1026 |
+
};
|
| 1027 |
+
template <>
|
| 1028 |
+
struct handle_type_name<item_accessor> {
|
| 1029 |
+
static constexpr auto name = const_name<item_accessor>();
|
| 1030 |
+
};
|
| 1031 |
+
template <>
|
| 1032 |
+
struct handle_type_name<sequence_accessor> {
|
| 1033 |
+
static constexpr auto name = const_name<sequence_accessor>();
|
| 1034 |
+
};
|
| 1035 |
+
template <>
|
| 1036 |
+
struct handle_type_name<list_accessor> {
|
| 1037 |
+
static constexpr auto name = const_name<list_accessor>();
|
| 1038 |
+
};
|
| 1039 |
+
template <>
|
| 1040 |
+
struct handle_type_name<tuple_accessor> {
|
| 1041 |
+
static constexpr auto name = const_name<tuple_accessor>();
|
| 1042 |
+
};
|
| 1043 |
+
|
| 1044 |
+
template <typename type>
|
| 1045 |
+
struct pyobject_caster {
|
| 1046 |
+
template <typename T = type, enable_if_t<std::is_same<T, handle>::value, int> = 0>
|
| 1047 |
+
pyobject_caster() : value() {}
|
| 1048 |
+
|
| 1049 |
+
// `type` may not be default constructible (e.g. frozenset, anyset). Initializing `value`
|
| 1050 |
+
// to a nil handle is safe since it will only be accessed if `load` succeeds.
|
| 1051 |
+
template <typename T = type, enable_if_t<std::is_base_of<object, T>::value, int> = 0>
|
| 1052 |
+
pyobject_caster() : value(reinterpret_steal<type>(handle())) {}
|
| 1053 |
+
|
| 1054 |
+
template <typename T = type, enable_if_t<std::is_same<T, handle>::value, int> = 0>
|
| 1055 |
+
bool load(handle src, bool /* convert */) {
|
| 1056 |
+
value = src;
|
| 1057 |
+
return static_cast<bool>(value);
|
| 1058 |
+
}
|
| 1059 |
+
|
| 1060 |
+
template <typename T = type, enable_if_t<std::is_base_of<object, T>::value, int> = 0>
|
| 1061 |
+
bool load(handle src, bool /* convert */) {
|
| 1062 |
+
if (!isinstance<type>(src)) {
|
| 1063 |
+
return false;
|
| 1064 |
+
}
|
| 1065 |
+
value = reinterpret_borrow<type>(src);
|
| 1066 |
+
return true;
|
| 1067 |
+
}
|
| 1068 |
+
|
| 1069 |
+
static handle cast(const handle &src, return_value_policy /* policy */, handle /* parent */) {
|
| 1070 |
+
return src.inc_ref();
|
| 1071 |
+
}
|
| 1072 |
+
PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name);
|
| 1073 |
+
};
|
| 1074 |
+
|
| 1075 |
+
template <typename T>
|
| 1076 |
+
class type_caster<T, enable_if_t<is_pyobject<T>::value>> : public pyobject_caster<T> {};
|
| 1077 |
+
|
| 1078 |
+
// Our conditions for enabling moving are quite restrictive:
|
| 1079 |
+
// At compile time:
|
| 1080 |
+
// - T needs to be a non-const, non-pointer, non-reference type
|
| 1081 |
+
// - type_caster<T>::operator T&() must exist
|
| 1082 |
+
// - the type must be move constructible (obviously)
|
| 1083 |
+
// At run-time:
|
| 1084 |
+
// - if the type is non-copy-constructible, the object must be the sole owner of the type (i.e. it
|
| 1085 |
+
// must have ref_count() == 1)h
|
| 1086 |
+
// If any of the above are not satisfied, we fall back to copying.
|
| 1087 |
+
template <typename T>
|
| 1088 |
+
using move_is_plain_type
|
| 1089 |
+
= satisfies_none_of<T, std::is_void, std::is_pointer, std::is_reference, std::is_const>;
|
| 1090 |
+
template <typename T, typename SFINAE = void>
|
| 1091 |
+
struct move_always : std::false_type {};
|
| 1092 |
+
template <typename T>
|
| 1093 |
+
struct move_always<
|
| 1094 |
+
T,
|
| 1095 |
+
enable_if_t<
|
| 1096 |
+
all_of<move_is_plain_type<T>,
|
| 1097 |
+
negation<is_copy_constructible<T>>,
|
| 1098 |
+
is_move_constructible<T>,
|
| 1099 |
+
std::is_same<decltype(std::declval<make_caster<T>>().operator T &()), T &>>::value>>
|
| 1100 |
+
: std::true_type {};
|
| 1101 |
+
template <typename T, typename SFINAE = void>
|
| 1102 |
+
struct move_if_unreferenced : std::false_type {};
|
| 1103 |
+
template <typename T>
|
| 1104 |
+
struct move_if_unreferenced<
|
| 1105 |
+
T,
|
| 1106 |
+
enable_if_t<
|
| 1107 |
+
all_of<move_is_plain_type<T>,
|
| 1108 |
+
negation<move_always<T>>,
|
| 1109 |
+
is_move_constructible<T>,
|
| 1110 |
+
std::is_same<decltype(std::declval<make_caster<T>>().operator T &()), T &>>::value>>
|
| 1111 |
+
: std::true_type {};
|
| 1112 |
+
template <typename T>
|
| 1113 |
+
using move_never = none_of<move_always<T>, move_if_unreferenced<T>>;
|
| 1114 |
+
|
| 1115 |
+
// Detect whether returning a `type` from a cast on type's type_caster is going to result in a
|
| 1116 |
+
// reference or pointer to a local variable of the type_caster. Basically, only
|
| 1117 |
+
// non-reference/pointer `type`s and reference/pointers from a type_caster_generic are safe;
|
| 1118 |
+
// everything else returns a reference/pointer to a local variable.
|
| 1119 |
+
template <typename type>
|
| 1120 |
+
using cast_is_temporary_value_reference
|
| 1121 |
+
= bool_constant<(std::is_reference<type>::value || std::is_pointer<type>::value)
|
| 1122 |
+
&& !std::is_base_of<type_caster_generic, make_caster<type>>::value
|
| 1123 |
+
&& !std::is_same<intrinsic_t<type>, void>::value>;
|
| 1124 |
+
|
| 1125 |
+
// When a value returned from a C++ function is being cast back to Python, we almost always want to
|
| 1126 |
+
// force `policy = move`, regardless of the return value policy the function/method was declared
|
| 1127 |
+
// with.
|
| 1128 |
+
template <typename Return, typename SFINAE = void>
|
| 1129 |
+
struct return_value_policy_override {
|
| 1130 |
+
static return_value_policy policy(return_value_policy p) { return p; }
|
| 1131 |
+
};
|
| 1132 |
+
|
| 1133 |
+
template <typename Return>
|
| 1134 |
+
struct return_value_policy_override<
|
| 1135 |
+
Return,
|
| 1136 |
+
detail::enable_if_t<std::is_base_of<type_caster_generic, make_caster<Return>>::value, void>> {
|
| 1137 |
+
static return_value_policy policy(return_value_policy p) {
|
| 1138 |
+
return !std::is_lvalue_reference<Return>::value && !std::is_pointer<Return>::value
|
| 1139 |
+
? return_value_policy::move
|
| 1140 |
+
: p;
|
| 1141 |
+
}
|
| 1142 |
+
};
|
| 1143 |
+
|
| 1144 |
+
// Basic python -> C++ casting; throws if casting fails
|
| 1145 |
+
template <typename T, typename SFINAE>
|
| 1146 |
+
type_caster<T, SFINAE> &load_type(type_caster<T, SFINAE> &conv, const handle &handle) {
|
| 1147 |
+
static_assert(!detail::is_pyobject<T>::value,
|
| 1148 |
+
"Internal error: type_caster should only be used for C++ types");
|
| 1149 |
+
if (!conv.load(handle, true)) {
|
| 1150 |
+
#if !defined(PYBIND11_DETAILED_ERROR_MESSAGES)
|
| 1151 |
+
throw cast_error(
|
| 1152 |
+
"Unable to cast Python instance of type "
|
| 1153 |
+
+ str(type::handle_of(handle)).cast<std::string>()
|
| 1154 |
+
+ " to C++ type '?' (#define "
|
| 1155 |
+
"PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)");
|
| 1156 |
+
#else
|
| 1157 |
+
throw cast_error("Unable to cast Python instance of type "
|
| 1158 |
+
+ str(type::handle_of(handle)).cast<std::string>() + " to C++ type '"
|
| 1159 |
+
+ type_id<T>() + "'");
|
| 1160 |
+
#endif
|
| 1161 |
+
}
|
| 1162 |
+
return conv;
|
| 1163 |
+
}
|
| 1164 |
+
// Wrapper around the above that also constructs and returns a type_caster
|
| 1165 |
+
template <typename T>
|
| 1166 |
+
make_caster<T> load_type(const handle &handle) {
|
| 1167 |
+
make_caster<T> conv;
|
| 1168 |
+
load_type(conv, handle);
|
| 1169 |
+
return conv;
|
| 1170 |
+
}
|
| 1171 |
+
|
| 1172 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 1173 |
+
|
| 1174 |
+
// pytype -> C++ type
|
| 1175 |
+
template <typename T,
|
| 1176 |
+
detail::enable_if_t<!detail::is_pyobject<T>::value
|
| 1177 |
+
&& !detail::is_same_ignoring_cvref<T, PyObject *>::value,
|
| 1178 |
+
int>
|
| 1179 |
+
= 0>
|
| 1180 |
+
T cast(const handle &handle) {
|
| 1181 |
+
using namespace detail;
|
| 1182 |
+
static_assert(!cast_is_temporary_value_reference<T>::value,
|
| 1183 |
+
"Unable to cast type to reference: value is local to type caster");
|
| 1184 |
+
return cast_op<T>(load_type<T>(handle));
|
| 1185 |
+
}
|
| 1186 |
+
|
| 1187 |
+
// pytype -> pytype (calls converting constructor)
|
| 1188 |
+
template <typename T, detail::enable_if_t<detail::is_pyobject<T>::value, int> = 0>
|
| 1189 |
+
T cast(const handle &handle) {
|
| 1190 |
+
return T(reinterpret_borrow<object>(handle));
|
| 1191 |
+
}
|
| 1192 |
+
|
| 1193 |
+
// Note that `cast<PyObject *>(obj)` increments the reference count of `obj`.
|
| 1194 |
+
// This is necessary for the case that `obj` is a temporary, and could
|
| 1195 |
+
// not possibly be different, given
|
| 1196 |
+
// 1. the established convention that the passed `handle` is borrowed, and
|
| 1197 |
+
// 2. we don't want to force all generic code using `cast<T>()` to special-case
|
| 1198 |
+
// handling of `T` = `PyObject *` (to increment the reference count there).
|
| 1199 |
+
// It is the responsibility of the caller to ensure that the reference count
|
| 1200 |
+
// is decremented.
|
| 1201 |
+
template <typename T,
|
| 1202 |
+
typename Handle,
|
| 1203 |
+
detail::enable_if_t<detail::is_same_ignoring_cvref<T, PyObject *>::value
|
| 1204 |
+
&& detail::is_same_ignoring_cvref<Handle, handle>::value,
|
| 1205 |
+
int>
|
| 1206 |
+
= 0>
|
| 1207 |
+
T cast(Handle &&handle) {
|
| 1208 |
+
return handle.inc_ref().ptr();
|
| 1209 |
+
}
|
| 1210 |
+
// To optimize way an inc_ref/dec_ref cycle:
|
| 1211 |
+
template <typename T,
|
| 1212 |
+
typename Object,
|
| 1213 |
+
detail::enable_if_t<detail::is_same_ignoring_cvref<T, PyObject *>::value
|
| 1214 |
+
&& detail::is_same_ignoring_cvref<Object, object>::value,
|
| 1215 |
+
int>
|
| 1216 |
+
= 0>
|
| 1217 |
+
T cast(Object &&obj) {
|
| 1218 |
+
return obj.release().ptr();
|
| 1219 |
+
}
|
| 1220 |
+
|
| 1221 |
+
// C++ type -> py::object
|
| 1222 |
+
template <typename T, detail::enable_if_t<!detail::is_pyobject<T>::value, int> = 0>
|
| 1223 |
+
object cast(T &&value,
|
| 1224 |
+
return_value_policy policy = return_value_policy::automatic_reference,
|
| 1225 |
+
handle parent = handle()) {
|
| 1226 |
+
using no_ref_T = typename std::remove_reference<T>::type;
|
| 1227 |
+
if (policy == return_value_policy::automatic) {
|
| 1228 |
+
policy = std::is_pointer<no_ref_T>::value ? return_value_policy::take_ownership
|
| 1229 |
+
: std::is_lvalue_reference<T>::value ? return_value_policy::copy
|
| 1230 |
+
: return_value_policy::move;
|
| 1231 |
+
} else if (policy == return_value_policy::automatic_reference) {
|
| 1232 |
+
policy = std::is_pointer<no_ref_T>::value ? return_value_policy::reference
|
| 1233 |
+
: std::is_lvalue_reference<T>::value ? return_value_policy::copy
|
| 1234 |
+
: return_value_policy::move;
|
| 1235 |
+
}
|
| 1236 |
+
return reinterpret_steal<object>(
|
| 1237 |
+
detail::make_caster<T>::cast(std::forward<T>(value), policy, parent));
|
| 1238 |
+
}
|
| 1239 |
+
|
| 1240 |
+
template <typename T>
|
| 1241 |
+
T handle::cast() const {
|
| 1242 |
+
return pybind11::cast<T>(*this);
|
| 1243 |
+
}
|
| 1244 |
+
template <>
|
| 1245 |
+
inline void handle::cast() const {
|
| 1246 |
+
return;
|
| 1247 |
+
}
|
| 1248 |
+
|
| 1249 |
+
template <typename T>
|
| 1250 |
+
detail::enable_if_t<!detail::move_never<T>::value, T> move(object &&obj) {
|
| 1251 |
+
if (obj.ref_count() > 1) {
|
| 1252 |
+
#if !defined(PYBIND11_DETAILED_ERROR_MESSAGES)
|
| 1253 |
+
throw cast_error(
|
| 1254 |
+
"Unable to cast Python " + str(type::handle_of(obj)).cast<std::string>()
|
| 1255 |
+
+ " instance to C++ rvalue: instance has multiple references"
|
| 1256 |
+
" (#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)");
|
| 1257 |
+
#else
|
| 1258 |
+
throw cast_error("Unable to move from Python "
|
| 1259 |
+
+ str(type::handle_of(obj)).cast<std::string>() + " instance to C++ "
|
| 1260 |
+
+ type_id<T>() + " instance: instance has multiple references");
|
| 1261 |
+
#endif
|
| 1262 |
+
}
|
| 1263 |
+
|
| 1264 |
+
// Move into a temporary and return that, because the reference may be a local value of `conv`
|
| 1265 |
+
T ret = std::move(detail::load_type<T>(obj).operator T &());
|
| 1266 |
+
return ret;
|
| 1267 |
+
}
|
| 1268 |
+
|
| 1269 |
+
// Calling cast() on an rvalue calls pybind11::cast with the object rvalue, which does:
|
| 1270 |
+
// - If we have to move (because T has no copy constructor), do it. This will fail if the moved
|
| 1271 |
+
// object has multiple references, but trying to copy will fail to compile.
|
| 1272 |
+
// - If both movable and copyable, check ref count: if 1, move; otherwise copy
|
| 1273 |
+
// - Otherwise (not movable), copy.
|
| 1274 |
+
template <typename T>
|
| 1275 |
+
detail::enable_if_t<!detail::is_pyobject<T>::value && detail::move_always<T>::value, T>
|
| 1276 |
+
cast(object &&object) {
|
| 1277 |
+
return move<T>(std::move(object));
|
| 1278 |
+
}
|
| 1279 |
+
template <typename T>
|
| 1280 |
+
detail::enable_if_t<!detail::is_pyobject<T>::value && detail::move_if_unreferenced<T>::value, T>
|
| 1281 |
+
cast(object &&object) {
|
| 1282 |
+
if (object.ref_count() > 1) {
|
| 1283 |
+
return cast<T>(object);
|
| 1284 |
+
}
|
| 1285 |
+
return move<T>(std::move(object));
|
| 1286 |
+
}
|
| 1287 |
+
template <typename T>
|
| 1288 |
+
detail::enable_if_t<!detail::is_pyobject<T>::value && detail::move_never<T>::value, T>
|
| 1289 |
+
cast(object &&object) {
|
| 1290 |
+
return cast<T>(object);
|
| 1291 |
+
}
|
| 1292 |
+
|
| 1293 |
+
// pytype rvalue -> pytype (calls converting constructor)
|
| 1294 |
+
template <typename T>
|
| 1295 |
+
detail::enable_if_t<detail::is_pyobject<T>::value, T> cast(object &&object) {
|
| 1296 |
+
return T(std::move(object));
|
| 1297 |
+
}
|
| 1298 |
+
|
| 1299 |
+
template <typename T>
|
| 1300 |
+
T object::cast() const & {
|
| 1301 |
+
return pybind11::cast<T>(*this);
|
| 1302 |
+
}
|
| 1303 |
+
template <typename T>
|
| 1304 |
+
T object::cast() && {
|
| 1305 |
+
return pybind11::cast<T>(std::move(*this));
|
| 1306 |
+
}
|
| 1307 |
+
template <>
|
| 1308 |
+
inline void object::cast() const & {
|
| 1309 |
+
return;
|
| 1310 |
+
}
|
| 1311 |
+
template <>
|
| 1312 |
+
inline void object::cast() && {
|
| 1313 |
+
return;
|
| 1314 |
+
}
|
| 1315 |
+
|
| 1316 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 1317 |
+
|
| 1318 |
+
// Declared in pytypes.h:
|
| 1319 |
+
template <typename T, enable_if_t<!is_pyobject<T>::value, int>>
|
| 1320 |
+
object object_or_cast(T &&o) {
|
| 1321 |
+
return pybind11::cast(std::forward<T>(o));
|
| 1322 |
+
}
|
| 1323 |
+
|
| 1324 |
+
// Placeholder type for the unneeded (and dead code) static variable in the
|
| 1325 |
+
// PYBIND11_OVERRIDE_OVERRIDE macro
|
| 1326 |
+
struct override_unused {};
|
| 1327 |
+
template <typename ret_type>
|
| 1328 |
+
using override_caster_t = conditional_t<cast_is_temporary_value_reference<ret_type>::value,
|
| 1329 |
+
make_caster<ret_type>,
|
| 1330 |
+
override_unused>;
|
| 1331 |
+
|
| 1332 |
+
// Trampoline use: for reference/pointer types to value-converted values, we do a value cast, then
|
| 1333 |
+
// store the result in the given variable. For other types, this is a no-op.
|
| 1334 |
+
template <typename T>
|
| 1335 |
+
enable_if_t<cast_is_temporary_value_reference<T>::value, T> cast_ref(object &&o,
|
| 1336 |
+
make_caster<T> &caster) {
|
| 1337 |
+
return cast_op<T>(load_type(caster, o));
|
| 1338 |
+
}
|
| 1339 |
+
template <typename T>
|
| 1340 |
+
enable_if_t<!cast_is_temporary_value_reference<T>::value, T> cast_ref(object &&,
|
| 1341 |
+
override_unused &) {
|
| 1342 |
+
pybind11_fail("Internal error: cast_ref fallback invoked");
|
| 1343 |
+
}
|
| 1344 |
+
|
| 1345 |
+
// Trampoline use: Having a pybind11::cast with an invalid reference type is going to
|
| 1346 |
+
// static_assert, even though if it's in dead code, so we provide a "trampoline" to pybind11::cast
|
| 1347 |
+
// that only does anything in cases where pybind11::cast is valid.
|
| 1348 |
+
template <typename T>
|
| 1349 |
+
enable_if_t<cast_is_temporary_value_reference<T>::value
|
| 1350 |
+
&& !detail::is_same_ignoring_cvref<T, PyObject *>::value,
|
| 1351 |
+
T>
|
| 1352 |
+
cast_safe(object &&) {
|
| 1353 |
+
pybind11_fail("Internal error: cast_safe fallback invoked");
|
| 1354 |
+
}
|
| 1355 |
+
template <typename T>
|
| 1356 |
+
enable_if_t<std::is_void<T>::value, void> cast_safe(object &&) {}
|
| 1357 |
+
template <typename T>
|
| 1358 |
+
enable_if_t<detail::is_same_ignoring_cvref<T, PyObject *>::value, PyObject *>
|
| 1359 |
+
cast_safe(object &&o) {
|
| 1360 |
+
return o.release().ptr();
|
| 1361 |
+
}
|
| 1362 |
+
template <typename T>
|
| 1363 |
+
enable_if_t<detail::none_of<cast_is_temporary_value_reference<T>,
|
| 1364 |
+
detail::is_same_ignoring_cvref<T, PyObject *>,
|
| 1365 |
+
std::is_void<T>>::value,
|
| 1366 |
+
T>
|
| 1367 |
+
cast_safe(object &&o) {
|
| 1368 |
+
return pybind11::cast<T>(std::move(o));
|
| 1369 |
+
}
|
| 1370 |
+
|
| 1371 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 1372 |
+
|
| 1373 |
+
// The overloads could coexist, i.e. the #if is not strictly speaking needed,
|
| 1374 |
+
// but it is an easy minor optimization.
|
| 1375 |
+
#if !defined(PYBIND11_DETAILED_ERROR_MESSAGES)
|
| 1376 |
+
inline cast_error cast_error_unable_to_convert_call_arg(const std::string &name) {
|
| 1377 |
+
return cast_error("Unable to convert call argument '" + name
|
| 1378 |
+
+ "' to Python object (#define "
|
| 1379 |
+
"PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)");
|
| 1380 |
+
}
|
| 1381 |
+
#else
|
| 1382 |
+
inline cast_error cast_error_unable_to_convert_call_arg(const std::string &name,
|
| 1383 |
+
const std::string &type) {
|
| 1384 |
+
return cast_error("Unable to convert call argument '" + name + "' of type '" + type
|
| 1385 |
+
+ "' to Python object");
|
| 1386 |
+
}
|
| 1387 |
+
#endif
|
| 1388 |
+
|
| 1389 |
+
template <return_value_policy policy = return_value_policy::automatic_reference>
|
| 1390 |
+
tuple make_tuple() {
|
| 1391 |
+
return tuple(0);
|
| 1392 |
+
}
|
| 1393 |
+
|
| 1394 |
+
template <return_value_policy policy = return_value_policy::automatic_reference, typename... Args>
|
| 1395 |
+
tuple make_tuple(Args &&...args_) {
|
| 1396 |
+
constexpr size_t size = sizeof...(Args);
|
| 1397 |
+
std::array<object, size> args{{reinterpret_steal<object>(
|
| 1398 |
+
detail::make_caster<Args>::cast(std::forward<Args>(args_), policy, nullptr))...}};
|
| 1399 |
+
for (size_t i = 0; i < args.size(); i++) {
|
| 1400 |
+
if (!args[i]) {
|
| 1401 |
+
#if !defined(PYBIND11_DETAILED_ERROR_MESSAGES)
|
| 1402 |
+
throw cast_error_unable_to_convert_call_arg(std::to_string(i));
|
| 1403 |
+
#else
|
| 1404 |
+
std::array<std::string, size> argtypes{{type_id<Args>()...}};
|
| 1405 |
+
throw cast_error_unable_to_convert_call_arg(std::to_string(i), argtypes[i]);
|
| 1406 |
+
#endif
|
| 1407 |
+
}
|
| 1408 |
+
}
|
| 1409 |
+
tuple result(size);
|
| 1410 |
+
int counter = 0;
|
| 1411 |
+
for (auto &arg_value : args) {
|
| 1412 |
+
PyTuple_SET_ITEM(result.ptr(), counter++, arg_value.release().ptr());
|
| 1413 |
+
}
|
| 1414 |
+
return result;
|
| 1415 |
+
}
|
| 1416 |
+
|
| 1417 |
+
/// \ingroup annotations
|
| 1418 |
+
/// Annotation for arguments
|
| 1419 |
+
struct arg {
|
| 1420 |
+
/// Constructs an argument with the name of the argument; if null or omitted, this is a
|
| 1421 |
+
/// positional argument.
|
| 1422 |
+
constexpr explicit arg(const char *name = nullptr)
|
| 1423 |
+
: name(name), flag_noconvert(false), flag_none(true) {}
|
| 1424 |
+
/// Assign a value to this argument
|
| 1425 |
+
template <typename T>
|
| 1426 |
+
arg_v operator=(T &&value) const;
|
| 1427 |
+
/// Indicate that the type should not be converted in the type caster
|
| 1428 |
+
arg &noconvert(bool flag = true) {
|
| 1429 |
+
flag_noconvert = flag;
|
| 1430 |
+
return *this;
|
| 1431 |
+
}
|
| 1432 |
+
/// Indicates that the argument should/shouldn't allow None (e.g. for nullable pointer args)
|
| 1433 |
+
arg &none(bool flag = true) {
|
| 1434 |
+
flag_none = flag;
|
| 1435 |
+
return *this;
|
| 1436 |
+
}
|
| 1437 |
+
|
| 1438 |
+
const char *name; ///< If non-null, this is a named kwargs argument
|
| 1439 |
+
bool flag_noconvert : 1; ///< If set, do not allow conversion (requires a supporting type
|
| 1440 |
+
///< caster!)
|
| 1441 |
+
bool flag_none : 1; ///< If set (the default), allow None to be passed to this argument
|
| 1442 |
+
};
|
| 1443 |
+
|
| 1444 |
+
/// \ingroup annotations
|
| 1445 |
+
/// Annotation for arguments with values
|
| 1446 |
+
struct arg_v : arg {
|
| 1447 |
+
private:
|
| 1448 |
+
template <typename T>
|
| 1449 |
+
arg_v(arg &&base, T &&x, const char *descr = nullptr)
|
| 1450 |
+
: arg(base), value(reinterpret_steal<object>(detail::make_caster<T>::cast(
|
| 1451 |
+
std::forward<T>(x), return_value_policy::automatic, {}))),
|
| 1452 |
+
descr(descr)
|
| 1453 |
+
#if defined(PYBIND11_DETAILED_ERROR_MESSAGES)
|
| 1454 |
+
,
|
| 1455 |
+
type(type_id<T>())
|
| 1456 |
+
#endif
|
| 1457 |
+
{
|
| 1458 |
+
// Workaround! See:
|
| 1459 |
+
// https://github.com/pybind/pybind11/issues/2336
|
| 1460 |
+
// https://github.com/pybind/pybind11/pull/2685#issuecomment-731286700
|
| 1461 |
+
if (PyErr_Occurred()) {
|
| 1462 |
+
PyErr_Clear();
|
| 1463 |
+
}
|
| 1464 |
+
}
|
| 1465 |
+
|
| 1466 |
+
public:
|
| 1467 |
+
/// Direct construction with name, default, and description
|
| 1468 |
+
template <typename T>
|
| 1469 |
+
arg_v(const char *name, T &&x, const char *descr = nullptr)
|
| 1470 |
+
: arg_v(arg(name), std::forward<T>(x), descr) {}
|
| 1471 |
+
|
| 1472 |
+
/// Called internally when invoking `py::arg("a") = value`
|
| 1473 |
+
template <typename T>
|
| 1474 |
+
arg_v(const arg &base, T &&x, const char *descr = nullptr)
|
| 1475 |
+
: arg_v(arg(base), std::forward<T>(x), descr) {}
|
| 1476 |
+
|
| 1477 |
+
/// Same as `arg::noconvert()`, but returns *this as arg_v&, not arg&
|
| 1478 |
+
arg_v &noconvert(bool flag = true) {
|
| 1479 |
+
arg::noconvert(flag);
|
| 1480 |
+
return *this;
|
| 1481 |
+
}
|
| 1482 |
+
|
| 1483 |
+
/// Same as `arg::nonone()`, but returns *this as arg_v&, not arg&
|
| 1484 |
+
arg_v &none(bool flag = true) {
|
| 1485 |
+
arg::none(flag);
|
| 1486 |
+
return *this;
|
| 1487 |
+
}
|
| 1488 |
+
|
| 1489 |
+
/// The default value
|
| 1490 |
+
object value;
|
| 1491 |
+
/// The (optional) description of the default value
|
| 1492 |
+
const char *descr;
|
| 1493 |
+
#if defined(PYBIND11_DETAILED_ERROR_MESSAGES)
|
| 1494 |
+
/// The C++ type name of the default value (only available when compiled in debug mode)
|
| 1495 |
+
std::string type;
|
| 1496 |
+
#endif
|
| 1497 |
+
};
|
| 1498 |
+
|
| 1499 |
+
/// \ingroup annotations
|
| 1500 |
+
/// Annotation indicating that all following arguments are keyword-only; the is the equivalent of
|
| 1501 |
+
/// an unnamed '*' argument
|
| 1502 |
+
struct kw_only {};
|
| 1503 |
+
|
| 1504 |
+
/// \ingroup annotations
|
| 1505 |
+
/// Annotation indicating that all previous arguments are positional-only; the is the equivalent of
|
| 1506 |
+
/// an unnamed '/' argument (in Python 3.8)
|
| 1507 |
+
struct pos_only {};
|
| 1508 |
+
|
| 1509 |
+
template <typename T>
|
| 1510 |
+
arg_v arg::operator=(T &&value) const {
|
| 1511 |
+
return {*this, std::forward<T>(value)};
|
| 1512 |
+
}
|
| 1513 |
+
|
| 1514 |
+
/// Alias for backward compatibility -- to be removed in version 2.0
|
| 1515 |
+
template <typename /*unused*/>
|
| 1516 |
+
using arg_t = arg_v;
|
| 1517 |
+
|
| 1518 |
+
inline namespace literals {
|
| 1519 |
+
/** \rst
|
| 1520 |
+
String literal version of `arg`
|
| 1521 |
+
\endrst */
|
| 1522 |
+
constexpr arg
|
| 1523 |
+
#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ < 5
|
| 1524 |
+
operator"" _a // gcc 4.8.5 insists on having a space (hard error).
|
| 1525 |
+
#else
|
| 1526 |
+
operator""_a // clang 17 generates a deprecation warning if there is a space.
|
| 1527 |
+
#endif
|
| 1528 |
+
(const char *name, size_t) {
|
| 1529 |
+
return arg(name);
|
| 1530 |
+
}
|
| 1531 |
+
} // namespace literals
|
| 1532 |
+
|
| 1533 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 1534 |
+
|
| 1535 |
+
template <typename T>
|
| 1536 |
+
using is_kw_only = std::is_same<intrinsic_t<T>, kw_only>;
|
| 1537 |
+
template <typename T>
|
| 1538 |
+
using is_pos_only = std::is_same<intrinsic_t<T>, pos_only>;
|
| 1539 |
+
|
| 1540 |
+
// forward declaration (definition in attr.h)
|
| 1541 |
+
struct function_record;
|
| 1542 |
+
|
| 1543 |
+
/// Internal data associated with a single function call
|
| 1544 |
+
struct function_call {
|
| 1545 |
+
function_call(const function_record &f, handle p); // Implementation in attr.h
|
| 1546 |
+
|
| 1547 |
+
/// The function data:
|
| 1548 |
+
const function_record &func;
|
| 1549 |
+
|
| 1550 |
+
/// Arguments passed to the function:
|
| 1551 |
+
std::vector<handle> args;
|
| 1552 |
+
|
| 1553 |
+
/// The `convert` value the arguments should be loaded with
|
| 1554 |
+
std::vector<bool> args_convert;
|
| 1555 |
+
|
| 1556 |
+
/// Extra references for the optional `py::args` and/or `py::kwargs` arguments (which, if
|
| 1557 |
+
/// present, are also in `args` but without a reference).
|
| 1558 |
+
object args_ref, kwargs_ref;
|
| 1559 |
+
|
| 1560 |
+
/// The parent, if any
|
| 1561 |
+
handle parent;
|
| 1562 |
+
|
| 1563 |
+
/// If this is a call to an initializer, this argument contains `self`
|
| 1564 |
+
handle init_self;
|
| 1565 |
+
};
|
| 1566 |
+
|
| 1567 |
+
/// Helper class which loads arguments for C++ functions called from Python
|
| 1568 |
+
template <typename... Args>
|
| 1569 |
+
class argument_loader {
|
| 1570 |
+
using indices = make_index_sequence<sizeof...(Args)>;
|
| 1571 |
+
|
| 1572 |
+
template <typename Arg>
|
| 1573 |
+
using argument_is_args = std::is_same<intrinsic_t<Arg>, args>;
|
| 1574 |
+
template <typename Arg>
|
| 1575 |
+
using argument_is_kwargs = std::is_same<intrinsic_t<Arg>, kwargs>;
|
| 1576 |
+
// Get kwargs argument position, or -1 if not present:
|
| 1577 |
+
static constexpr auto kwargs_pos = constexpr_last<argument_is_kwargs, Args...>();
|
| 1578 |
+
|
| 1579 |
+
static_assert(kwargs_pos == -1 || kwargs_pos == (int) sizeof...(Args) - 1,
|
| 1580 |
+
"py::kwargs is only permitted as the last argument of a function");
|
| 1581 |
+
|
| 1582 |
+
public:
|
| 1583 |
+
static constexpr bool has_kwargs = kwargs_pos != -1;
|
| 1584 |
+
|
| 1585 |
+
// py::args argument position; -1 if not present.
|
| 1586 |
+
static constexpr int args_pos = constexpr_last<argument_is_args, Args...>();
|
| 1587 |
+
|
| 1588 |
+
static_assert(args_pos == -1 || args_pos == constexpr_first<argument_is_args, Args...>(),
|
| 1589 |
+
"py::args cannot be specified more than once");
|
| 1590 |
+
|
| 1591 |
+
static constexpr auto arg_names
|
| 1592 |
+
= ::pybind11::detail::concat(type_descr(make_caster<Args>::name)...);
|
| 1593 |
+
|
| 1594 |
+
bool load_args(function_call &call) { return load_impl_sequence(call, indices{}); }
|
| 1595 |
+
|
| 1596 |
+
template <typename Return, typename Guard, typename Func>
|
| 1597 |
+
// NOLINTNEXTLINE(readability-const-return-type)
|
| 1598 |
+
enable_if_t<!std::is_void<Return>::value, Return> call(Func &&f) && {
|
| 1599 |
+
return std::move(*this).template call_impl<remove_cv_t<Return>>(
|
| 1600 |
+
std::forward<Func>(f), indices{}, Guard{});
|
| 1601 |
+
}
|
| 1602 |
+
|
| 1603 |
+
template <typename Return, typename Guard, typename Func>
|
| 1604 |
+
enable_if_t<std::is_void<Return>::value, void_type> call(Func &&f) && {
|
| 1605 |
+
std::move(*this).template call_impl<remove_cv_t<Return>>(
|
| 1606 |
+
std::forward<Func>(f), indices{}, Guard{});
|
| 1607 |
+
return void_type();
|
| 1608 |
+
}
|
| 1609 |
+
|
| 1610 |
+
private:
|
| 1611 |
+
static bool load_impl_sequence(function_call &, index_sequence<>) { return true; }
|
| 1612 |
+
|
| 1613 |
+
template <size_t... Is>
|
| 1614 |
+
bool load_impl_sequence(function_call &call, index_sequence<Is...>) {
|
| 1615 |
+
#ifdef __cpp_fold_expressions
|
| 1616 |
+
if ((... || !std::get<Is>(argcasters).load(call.args[Is], call.args_convert[Is]))) {
|
| 1617 |
+
return false;
|
| 1618 |
+
}
|
| 1619 |
+
#else
|
| 1620 |
+
for (bool r : {std::get<Is>(argcasters).load(call.args[Is], call.args_convert[Is])...}) {
|
| 1621 |
+
if (!r) {
|
| 1622 |
+
return false;
|
| 1623 |
+
}
|
| 1624 |
+
}
|
| 1625 |
+
#endif
|
| 1626 |
+
return true;
|
| 1627 |
+
}
|
| 1628 |
+
|
| 1629 |
+
template <typename Return, typename Func, size_t... Is, typename Guard>
|
| 1630 |
+
Return call_impl(Func &&f, index_sequence<Is...>, Guard &&) && {
|
| 1631 |
+
return std::forward<Func>(f)(cast_op<Args>(std::move(std::get<Is>(argcasters)))...);
|
| 1632 |
+
}
|
| 1633 |
+
|
| 1634 |
+
std::tuple<make_caster<Args>...> argcasters;
|
| 1635 |
+
};
|
| 1636 |
+
|
| 1637 |
+
/// Helper class which collects only positional arguments for a Python function call.
|
| 1638 |
+
/// A fancier version below can collect any argument, but this one is optimal for simple calls.
|
| 1639 |
+
template <return_value_policy policy>
|
| 1640 |
+
class simple_collector {
|
| 1641 |
+
public:
|
| 1642 |
+
template <typename... Ts>
|
| 1643 |
+
explicit simple_collector(Ts &&...values)
|
| 1644 |
+
: m_args(pybind11::make_tuple<policy>(std::forward<Ts>(values)...)) {}
|
| 1645 |
+
|
| 1646 |
+
const tuple &args() const & { return m_args; }
|
| 1647 |
+
dict kwargs() const { return {}; }
|
| 1648 |
+
|
| 1649 |
+
tuple args() && { return std::move(m_args); }
|
| 1650 |
+
|
| 1651 |
+
/// Call a Python function and pass the collected arguments
|
| 1652 |
+
object call(PyObject *ptr) const {
|
| 1653 |
+
PyObject *result = PyObject_CallObject(ptr, m_args.ptr());
|
| 1654 |
+
if (!result) {
|
| 1655 |
+
throw error_already_set();
|
| 1656 |
+
}
|
| 1657 |
+
return reinterpret_steal<object>(result);
|
| 1658 |
+
}
|
| 1659 |
+
|
| 1660 |
+
private:
|
| 1661 |
+
tuple m_args;
|
| 1662 |
+
};
|
| 1663 |
+
|
| 1664 |
+
/// Helper class which collects positional, keyword, * and ** arguments for a Python function call
|
| 1665 |
+
template <return_value_policy policy>
|
| 1666 |
+
class unpacking_collector {
|
| 1667 |
+
public:
|
| 1668 |
+
template <typename... Ts>
|
| 1669 |
+
explicit unpacking_collector(Ts &&...values) {
|
| 1670 |
+
// Tuples aren't (easily) resizable so a list is needed for collection,
|
| 1671 |
+
// but the actual function call strictly requires a tuple.
|
| 1672 |
+
auto args_list = list();
|
| 1673 |
+
using expander = int[];
|
| 1674 |
+
(void) expander{0, (process(args_list, std::forward<Ts>(values)), 0)...};
|
| 1675 |
+
|
| 1676 |
+
m_args = std::move(args_list);
|
| 1677 |
+
}
|
| 1678 |
+
|
| 1679 |
+
const tuple &args() const & { return m_args; }
|
| 1680 |
+
const dict &kwargs() const & { return m_kwargs; }
|
| 1681 |
+
|
| 1682 |
+
tuple args() && { return std::move(m_args); }
|
| 1683 |
+
dict kwargs() && { return std::move(m_kwargs); }
|
| 1684 |
+
|
| 1685 |
+
/// Call a Python function and pass the collected arguments
|
| 1686 |
+
object call(PyObject *ptr) const {
|
| 1687 |
+
PyObject *result = PyObject_Call(ptr, m_args.ptr(), m_kwargs.ptr());
|
| 1688 |
+
if (!result) {
|
| 1689 |
+
throw error_already_set();
|
| 1690 |
+
}
|
| 1691 |
+
return reinterpret_steal<object>(result);
|
| 1692 |
+
}
|
| 1693 |
+
|
| 1694 |
+
private:
|
| 1695 |
+
template <typename T>
|
| 1696 |
+
void process(list &args_list, T &&x) {
|
| 1697 |
+
auto o = reinterpret_steal<object>(
|
| 1698 |
+
detail::make_caster<T>::cast(std::forward<T>(x), policy, {}));
|
| 1699 |
+
if (!o) {
|
| 1700 |
+
#if !defined(PYBIND11_DETAILED_ERROR_MESSAGES)
|
| 1701 |
+
throw cast_error_unable_to_convert_call_arg(std::to_string(args_list.size()));
|
| 1702 |
+
#else
|
| 1703 |
+
throw cast_error_unable_to_convert_call_arg(std::to_string(args_list.size()),
|
| 1704 |
+
type_id<T>());
|
| 1705 |
+
#endif
|
| 1706 |
+
}
|
| 1707 |
+
args_list.append(std::move(o));
|
| 1708 |
+
}
|
| 1709 |
+
|
| 1710 |
+
void process(list &args_list, detail::args_proxy ap) {
|
| 1711 |
+
for (auto a : ap) {
|
| 1712 |
+
args_list.append(a);
|
| 1713 |
+
}
|
| 1714 |
+
}
|
| 1715 |
+
|
| 1716 |
+
void process(list & /*args_list*/, arg_v a) {
|
| 1717 |
+
if (!a.name) {
|
| 1718 |
+
#if !defined(PYBIND11_DETAILED_ERROR_MESSAGES)
|
| 1719 |
+
nameless_argument_error();
|
| 1720 |
+
#else
|
| 1721 |
+
nameless_argument_error(a.type);
|
| 1722 |
+
#endif
|
| 1723 |
+
}
|
| 1724 |
+
if (m_kwargs.contains(a.name)) {
|
| 1725 |
+
#if !defined(PYBIND11_DETAILED_ERROR_MESSAGES)
|
| 1726 |
+
multiple_values_error();
|
| 1727 |
+
#else
|
| 1728 |
+
multiple_values_error(a.name);
|
| 1729 |
+
#endif
|
| 1730 |
+
}
|
| 1731 |
+
if (!a.value) {
|
| 1732 |
+
#if !defined(PYBIND11_DETAILED_ERROR_MESSAGES)
|
| 1733 |
+
throw cast_error_unable_to_convert_call_arg(a.name);
|
| 1734 |
+
#else
|
| 1735 |
+
throw cast_error_unable_to_convert_call_arg(a.name, a.type);
|
| 1736 |
+
#endif
|
| 1737 |
+
}
|
| 1738 |
+
m_kwargs[a.name] = std::move(a.value);
|
| 1739 |
+
}
|
| 1740 |
+
|
| 1741 |
+
void process(list & /*args_list*/, detail::kwargs_proxy kp) {
|
| 1742 |
+
if (!kp) {
|
| 1743 |
+
return;
|
| 1744 |
+
}
|
| 1745 |
+
for (auto k : reinterpret_borrow<dict>(kp)) {
|
| 1746 |
+
if (m_kwargs.contains(k.first)) {
|
| 1747 |
+
#if !defined(PYBIND11_DETAILED_ERROR_MESSAGES)
|
| 1748 |
+
multiple_values_error();
|
| 1749 |
+
#else
|
| 1750 |
+
multiple_values_error(str(k.first));
|
| 1751 |
+
#endif
|
| 1752 |
+
}
|
| 1753 |
+
m_kwargs[k.first] = k.second;
|
| 1754 |
+
}
|
| 1755 |
+
}
|
| 1756 |
+
|
| 1757 |
+
[[noreturn]] static void nameless_argument_error() {
|
| 1758 |
+
throw type_error(
|
| 1759 |
+
"Got kwargs without a name; only named arguments "
|
| 1760 |
+
"may be passed via py::arg() to a python function call. "
|
| 1761 |
+
"(#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)");
|
| 1762 |
+
}
|
| 1763 |
+
[[noreturn]] static void nameless_argument_error(const std::string &type) {
|
| 1764 |
+
throw type_error("Got kwargs without a name of type '" + type
|
| 1765 |
+
+ "'; only named "
|
| 1766 |
+
"arguments may be passed via py::arg() to a python function call. ");
|
| 1767 |
+
}
|
| 1768 |
+
[[noreturn]] static void multiple_values_error() {
|
| 1769 |
+
throw type_error(
|
| 1770 |
+
"Got multiple values for keyword argument "
|
| 1771 |
+
"(#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)");
|
| 1772 |
+
}
|
| 1773 |
+
|
| 1774 |
+
[[noreturn]] static void multiple_values_error(const std::string &name) {
|
| 1775 |
+
throw type_error("Got multiple values for keyword argument '" + name + "'");
|
| 1776 |
+
}
|
| 1777 |
+
|
| 1778 |
+
private:
|
| 1779 |
+
tuple m_args;
|
| 1780 |
+
dict m_kwargs;
|
| 1781 |
+
};
|
| 1782 |
+
|
| 1783 |
+
// [workaround(intel)] Separate function required here
|
| 1784 |
+
// We need to put this into a separate function because the Intel compiler
|
| 1785 |
+
// fails to compile enable_if_t<!all_of<is_positional<Args>...>::value>
|
| 1786 |
+
// (tested with ICC 2021.1 Beta 20200827).
|
| 1787 |
+
template <typename... Args>
|
| 1788 |
+
constexpr bool args_are_all_positional() {
|
| 1789 |
+
return all_of<is_positional<Args>...>::value;
|
| 1790 |
+
}
|
| 1791 |
+
|
| 1792 |
+
/// Collect only positional arguments for a Python function call
|
| 1793 |
+
template <return_value_policy policy,
|
| 1794 |
+
typename... Args,
|
| 1795 |
+
typename = enable_if_t<args_are_all_positional<Args...>()>>
|
| 1796 |
+
simple_collector<policy> collect_arguments(Args &&...args) {
|
| 1797 |
+
return simple_collector<policy>(std::forward<Args>(args)...);
|
| 1798 |
+
}
|
| 1799 |
+
|
| 1800 |
+
/// Collect all arguments, including keywords and unpacking (only instantiated when needed)
|
| 1801 |
+
template <return_value_policy policy,
|
| 1802 |
+
typename... Args,
|
| 1803 |
+
typename = enable_if_t<!args_are_all_positional<Args...>()>>
|
| 1804 |
+
unpacking_collector<policy> collect_arguments(Args &&...args) {
|
| 1805 |
+
// Following argument order rules for generalized unpacking according to PEP 448
|
| 1806 |
+
static_assert(constexpr_last<is_positional, Args...>()
|
| 1807 |
+
< constexpr_first<is_keyword_or_ds, Args...>()
|
| 1808 |
+
&& constexpr_last<is_s_unpacking, Args...>()
|
| 1809 |
+
< constexpr_first<is_ds_unpacking, Args...>(),
|
| 1810 |
+
"Invalid function call: positional args must precede keywords and ** unpacking; "
|
| 1811 |
+
"* unpacking must precede ** unpacking");
|
| 1812 |
+
return unpacking_collector<policy>(std::forward<Args>(args)...);
|
| 1813 |
+
}
|
| 1814 |
+
|
| 1815 |
+
template <typename Derived>
|
| 1816 |
+
template <return_value_policy policy, typename... Args>
|
| 1817 |
+
object object_api<Derived>::operator()(Args &&...args) const {
|
| 1818 |
+
#ifndef NDEBUG
|
| 1819 |
+
if (!PyGILState_Check()) {
|
| 1820 |
+
pybind11_fail("pybind11::object_api<>::operator() PyGILState_Check() failure.");
|
| 1821 |
+
}
|
| 1822 |
+
#endif
|
| 1823 |
+
return detail::collect_arguments<policy>(std::forward<Args>(args)...).call(derived().ptr());
|
| 1824 |
+
}
|
| 1825 |
+
|
| 1826 |
+
template <typename Derived>
|
| 1827 |
+
template <return_value_policy policy, typename... Args>
|
| 1828 |
+
object object_api<Derived>::call(Args &&...args) const {
|
| 1829 |
+
return operator()<policy>(std::forward<Args>(args)...);
|
| 1830 |
+
}
|
| 1831 |
+
|
| 1832 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 1833 |
+
|
| 1834 |
+
template <typename T>
|
| 1835 |
+
handle type::handle_of() {
|
| 1836 |
+
static_assert(std::is_base_of<detail::type_caster_generic, detail::make_caster<T>>::value,
|
| 1837 |
+
"py::type::of<T> only supports the case where T is a registered C++ types.");
|
| 1838 |
+
|
| 1839 |
+
return detail::get_type_handle(typeid(T), true);
|
| 1840 |
+
}
|
| 1841 |
+
|
| 1842 |
+
#define PYBIND11_MAKE_OPAQUE(...) \
|
| 1843 |
+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) \
|
| 1844 |
+
namespace detail { \
|
| 1845 |
+
template <> \
|
| 1846 |
+
class type_caster<__VA_ARGS__> : public type_caster_base<__VA_ARGS__> {}; \
|
| 1847 |
+
} \
|
| 1848 |
+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
|
| 1849 |
+
|
| 1850 |
+
/// Lets you pass a type containing a `,` through a macro parameter without needing a separate
|
| 1851 |
+
/// typedef, e.g.:
|
| 1852 |
+
/// `PYBIND11_OVERRIDE(PYBIND11_TYPE(ReturnType<A, B>), PYBIND11_TYPE(Parent<C, D>), f, arg)`
|
| 1853 |
+
#define PYBIND11_TYPE(...) __VA_ARGS__
|
| 1854 |
+
|
| 1855 |
+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
|
phivenv/Lib/site-packages/torch/include/pybind11/chrono.h
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
pybind11/chrono.h: Transparent conversion between std::chrono and python's datetime
|
| 3 |
+
|
| 4 |
+
Copyright (c) 2016 Trent Houliston <trent@houliston.me> and
|
| 5 |
+
Wenzel Jakob <wenzel.jakob@epfl.ch>
|
| 6 |
+
|
| 7 |
+
All rights reserved. Use of this source code is governed by a
|
| 8 |
+
BSD-style license that can be found in the LICENSE file.
|
| 9 |
+
*/
|
| 10 |
+
|
| 11 |
+
#pragma once
|
| 12 |
+
|
| 13 |
+
#include "pybind11.h"
|
| 14 |
+
|
| 15 |
+
#include <chrono>
|
| 16 |
+
#include <cmath>
|
| 17 |
+
#include <ctime>
|
| 18 |
+
#include <datetime.h>
|
| 19 |
+
#include <mutex>
|
| 20 |
+
|
| 21 |
+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
| 22 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 23 |
+
|
| 24 |
+
template <typename type>
|
| 25 |
+
class duration_caster {
|
| 26 |
+
public:
|
| 27 |
+
using rep = typename type::rep;
|
| 28 |
+
using period = typename type::period;
|
| 29 |
+
|
| 30 |
+
// signed 25 bits required by the standard.
|
| 31 |
+
using days = std::chrono::duration<int_least32_t, std::ratio<86400>>;
|
| 32 |
+
|
| 33 |
+
bool load(handle src, bool) {
|
| 34 |
+
using namespace std::chrono;
|
| 35 |
+
|
| 36 |
+
// Lazy initialise the PyDateTime import
|
| 37 |
+
if (!PyDateTimeAPI) {
|
| 38 |
+
PyDateTime_IMPORT;
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
if (!src) {
|
| 42 |
+
return false;
|
| 43 |
+
}
|
| 44 |
+
// If invoked with datetime.delta object
|
| 45 |
+
if (PyDelta_Check(src.ptr())) {
|
| 46 |
+
value = type(duration_cast<duration<rep, period>>(
|
| 47 |
+
days(PyDateTime_DELTA_GET_DAYS(src.ptr()))
|
| 48 |
+
+ seconds(PyDateTime_DELTA_GET_SECONDS(src.ptr()))
|
| 49 |
+
+ microseconds(PyDateTime_DELTA_GET_MICROSECONDS(src.ptr()))));
|
| 50 |
+
return true;
|
| 51 |
+
}
|
| 52 |
+
// If invoked with a float we assume it is seconds and convert
|
| 53 |
+
if (PyFloat_Check(src.ptr())) {
|
| 54 |
+
value = type(duration_cast<duration<rep, period>>(
|
| 55 |
+
duration<double>(PyFloat_AsDouble(src.ptr()))));
|
| 56 |
+
return true;
|
| 57 |
+
}
|
| 58 |
+
return false;
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
// If this is a duration just return it back
|
| 62 |
+
static const std::chrono::duration<rep, period> &
|
| 63 |
+
get_duration(const std::chrono::duration<rep, period> &src) {
|
| 64 |
+
return src;
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
// If this is a time_point get the time_since_epoch
|
| 68 |
+
template <typename Clock>
|
| 69 |
+
static std::chrono::duration<rep, period>
|
| 70 |
+
get_duration(const std::chrono::time_point<Clock, std::chrono::duration<rep, period>> &src) {
|
| 71 |
+
return src.time_since_epoch();
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
static handle cast(const type &src, return_value_policy /* policy */, handle /* parent */) {
|
| 75 |
+
using namespace std::chrono;
|
| 76 |
+
|
| 77 |
+
// Use overloaded function to get our duration from our source
|
| 78 |
+
// Works out if it is a duration or time_point and get the duration
|
| 79 |
+
auto d = get_duration(src);
|
| 80 |
+
|
| 81 |
+
// Lazy initialise the PyDateTime import
|
| 82 |
+
if (!PyDateTimeAPI) {
|
| 83 |
+
PyDateTime_IMPORT;
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
// Declare these special duration types so the conversions happen with the correct
|
| 87 |
+
// primitive types (int)
|
| 88 |
+
using dd_t = duration<int, std::ratio<86400>>;
|
| 89 |
+
using ss_t = duration<int, std::ratio<1>>;
|
| 90 |
+
using us_t = duration<int, std::micro>;
|
| 91 |
+
|
| 92 |
+
auto dd = duration_cast<dd_t>(d);
|
| 93 |
+
auto subd = d - dd;
|
| 94 |
+
auto ss = duration_cast<ss_t>(subd);
|
| 95 |
+
auto us = duration_cast<us_t>(subd - ss);
|
| 96 |
+
return PyDelta_FromDSU(dd.count(), ss.count(), us.count());
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
PYBIND11_TYPE_CASTER(type, const_name("datetime.timedelta"));
|
| 100 |
+
};
|
| 101 |
+
|
| 102 |
+
inline std::tm *localtime_thread_safe(const std::time_t *time, std::tm *buf) {
|
| 103 |
+
#if (defined(__STDC_LIB_EXT1__) && defined(__STDC_WANT_LIB_EXT1__)) || defined(_MSC_VER)
|
| 104 |
+
if (localtime_s(buf, time))
|
| 105 |
+
return nullptr;
|
| 106 |
+
return buf;
|
| 107 |
+
#else
|
| 108 |
+
static std::mutex mtx;
|
| 109 |
+
std::lock_guard<std::mutex> lock(mtx);
|
| 110 |
+
std::tm *tm_ptr = std::localtime(time);
|
| 111 |
+
if (tm_ptr != nullptr) {
|
| 112 |
+
*buf = *tm_ptr;
|
| 113 |
+
}
|
| 114 |
+
return tm_ptr;
|
| 115 |
+
#endif
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
// This is for casting times on the system clock into datetime.datetime instances
|
| 119 |
+
template <typename Duration>
|
| 120 |
+
class type_caster<std::chrono::time_point<std::chrono::system_clock, Duration>> {
|
| 121 |
+
public:
|
| 122 |
+
using type = std::chrono::time_point<std::chrono::system_clock, Duration>;
|
| 123 |
+
bool load(handle src, bool) {
|
| 124 |
+
using namespace std::chrono;
|
| 125 |
+
|
| 126 |
+
// Lazy initialise the PyDateTime import
|
| 127 |
+
if (!PyDateTimeAPI) {
|
| 128 |
+
PyDateTime_IMPORT;
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
if (!src) {
|
| 132 |
+
return false;
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
std::tm cal;
|
| 136 |
+
microseconds msecs;
|
| 137 |
+
|
| 138 |
+
if (PyDateTime_Check(src.ptr())) {
|
| 139 |
+
cal.tm_sec = PyDateTime_DATE_GET_SECOND(src.ptr());
|
| 140 |
+
cal.tm_min = PyDateTime_DATE_GET_MINUTE(src.ptr());
|
| 141 |
+
cal.tm_hour = PyDateTime_DATE_GET_HOUR(src.ptr());
|
| 142 |
+
cal.tm_mday = PyDateTime_GET_DAY(src.ptr());
|
| 143 |
+
cal.tm_mon = PyDateTime_GET_MONTH(src.ptr()) - 1;
|
| 144 |
+
cal.tm_year = PyDateTime_GET_YEAR(src.ptr()) - 1900;
|
| 145 |
+
cal.tm_isdst = -1;
|
| 146 |
+
msecs = microseconds(PyDateTime_DATE_GET_MICROSECOND(src.ptr()));
|
| 147 |
+
} else if (PyDate_Check(src.ptr())) {
|
| 148 |
+
cal.tm_sec = 0;
|
| 149 |
+
cal.tm_min = 0;
|
| 150 |
+
cal.tm_hour = 0;
|
| 151 |
+
cal.tm_mday = PyDateTime_GET_DAY(src.ptr());
|
| 152 |
+
cal.tm_mon = PyDateTime_GET_MONTH(src.ptr()) - 1;
|
| 153 |
+
cal.tm_year = PyDateTime_GET_YEAR(src.ptr()) - 1900;
|
| 154 |
+
cal.tm_isdst = -1;
|
| 155 |
+
msecs = microseconds(0);
|
| 156 |
+
} else if (PyTime_Check(src.ptr())) {
|
| 157 |
+
cal.tm_sec = PyDateTime_TIME_GET_SECOND(src.ptr());
|
| 158 |
+
cal.tm_min = PyDateTime_TIME_GET_MINUTE(src.ptr());
|
| 159 |
+
cal.tm_hour = PyDateTime_TIME_GET_HOUR(src.ptr());
|
| 160 |
+
cal.tm_mday = 1; // This date (day, month, year) = (1, 0, 70)
|
| 161 |
+
cal.tm_mon = 0; // represents 1-Jan-1970, which is the first
|
| 162 |
+
cal.tm_year = 70; // earliest available date for Python's datetime
|
| 163 |
+
cal.tm_isdst = -1;
|
| 164 |
+
msecs = microseconds(PyDateTime_TIME_GET_MICROSECOND(src.ptr()));
|
| 165 |
+
} else {
|
| 166 |
+
return false;
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
value = time_point_cast<Duration>(system_clock::from_time_t(std::mktime(&cal)) + msecs);
|
| 170 |
+
return true;
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
static handle cast(const std::chrono::time_point<std::chrono::system_clock, Duration> &src,
|
| 174 |
+
return_value_policy /* policy */,
|
| 175 |
+
handle /* parent */) {
|
| 176 |
+
using namespace std::chrono;
|
| 177 |
+
|
| 178 |
+
// Lazy initialise the PyDateTime import
|
| 179 |
+
if (!PyDateTimeAPI) {
|
| 180 |
+
PyDateTime_IMPORT;
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
// Get out microseconds, and make sure they are positive, to avoid bug in eastern
|
| 184 |
+
// hemisphere time zones (cfr. https://github.com/pybind/pybind11/issues/2417)
|
| 185 |
+
using us_t = duration<int, std::micro>;
|
| 186 |
+
auto us = duration_cast<us_t>(src.time_since_epoch() % seconds(1));
|
| 187 |
+
if (us.count() < 0) {
|
| 188 |
+
us += seconds(1);
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
// Subtract microseconds BEFORE `system_clock::to_time_t`, because:
|
| 192 |
+
// > If std::time_t has lower precision, it is implementation-defined whether the value is
|
| 193 |
+
// rounded or truncated. (https://en.cppreference.com/w/cpp/chrono/system_clock/to_time_t)
|
| 194 |
+
std::time_t tt
|
| 195 |
+
= system_clock::to_time_t(time_point_cast<system_clock::duration>(src - us));
|
| 196 |
+
|
| 197 |
+
std::tm localtime;
|
| 198 |
+
std::tm *localtime_ptr = localtime_thread_safe(&tt, &localtime);
|
| 199 |
+
if (!localtime_ptr) {
|
| 200 |
+
throw cast_error("Unable to represent system_clock in local time");
|
| 201 |
+
}
|
| 202 |
+
return PyDateTime_FromDateAndTime(localtime.tm_year + 1900,
|
| 203 |
+
localtime.tm_mon + 1,
|
| 204 |
+
localtime.tm_mday,
|
| 205 |
+
localtime.tm_hour,
|
| 206 |
+
localtime.tm_min,
|
| 207 |
+
localtime.tm_sec,
|
| 208 |
+
us.count());
|
| 209 |
+
}
|
| 210 |
+
PYBIND11_TYPE_CASTER(type, const_name("datetime.datetime"));
|
| 211 |
+
};
|
| 212 |
+
|
| 213 |
+
// Other clocks that are not the system clock are not measured as datetime.datetime objects
|
| 214 |
+
// since they are not measured on calendar time. So instead we just make them timedeltas
|
| 215 |
+
// Or if they have passed us a time as a float we convert that
|
| 216 |
+
template <typename Clock, typename Duration>
|
| 217 |
+
class type_caster<std::chrono::time_point<Clock, Duration>>
|
| 218 |
+
: public duration_caster<std::chrono::time_point<Clock, Duration>> {};
|
| 219 |
+
|
| 220 |
+
template <typename Rep, typename Period>
|
| 221 |
+
class type_caster<std::chrono::duration<Rep, Period>>
|
| 222 |
+
: public duration_caster<std::chrono::duration<Rep, Period>> {};
|
| 223 |
+
|
| 224 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 225 |
+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
|
phivenv/Lib/site-packages/torch/include/pybind11/common.h
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "detail/common.h"
|
| 2 |
+
#warning "Including 'common.h' is deprecated. It will be removed in v3.0. Use 'pybind11.h'."
|
phivenv/Lib/site-packages/torch/include/pybind11/complex.h
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
pybind11/complex.h: Complex number support
|
| 3 |
+
|
| 4 |
+
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
| 5 |
+
|
| 6 |
+
All rights reserved. Use of this source code is governed by a
|
| 7 |
+
BSD-style license that can be found in the LICENSE file.
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include "pybind11.h"
|
| 13 |
+
|
| 14 |
+
#include <complex>
|
| 15 |
+
|
| 16 |
+
/// glibc defines I as a macro which breaks things, e.g., boost template names
|
| 17 |
+
#ifdef I
|
| 18 |
+
# undef I
|
| 19 |
+
#endif
|
| 20 |
+
|
| 21 |
+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
| 22 |
+
|
| 23 |
+
template <typename T>
|
| 24 |
+
struct format_descriptor<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
|
| 25 |
+
static constexpr const char c = format_descriptor<T>::c;
|
| 26 |
+
static constexpr const char value[3] = {'Z', c, '\0'};
|
| 27 |
+
static std::string format() { return std::string(value); }
|
| 28 |
+
};
|
| 29 |
+
|
| 30 |
+
#ifndef PYBIND11_CPP17
|
| 31 |
+
|
| 32 |
+
template <typename T>
|
| 33 |
+
constexpr const char
|
| 34 |
+
format_descriptor<std::complex<T>,
|
| 35 |
+
detail::enable_if_t<std::is_floating_point<T>::value>>::value[3];
|
| 36 |
+
|
| 37 |
+
#endif
|
| 38 |
+
|
| 39 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 40 |
+
|
| 41 |
+
template <typename T>
|
| 42 |
+
struct is_fmt_numeric<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
|
| 43 |
+
static constexpr bool value = true;
|
| 44 |
+
static constexpr int index = is_fmt_numeric<T>::index + 3;
|
| 45 |
+
};
|
| 46 |
+
|
| 47 |
+
template <typename T>
|
| 48 |
+
class type_caster<std::complex<T>> {
|
| 49 |
+
public:
|
| 50 |
+
bool load(handle src, bool convert) {
|
| 51 |
+
if (!src) {
|
| 52 |
+
return false;
|
| 53 |
+
}
|
| 54 |
+
if (!convert && !PyComplex_Check(src.ptr())) {
|
| 55 |
+
return false;
|
| 56 |
+
}
|
| 57 |
+
Py_complex result = PyComplex_AsCComplex(src.ptr());
|
| 58 |
+
if (result.real == -1.0 && PyErr_Occurred()) {
|
| 59 |
+
PyErr_Clear();
|
| 60 |
+
return false;
|
| 61 |
+
}
|
| 62 |
+
value = std::complex<T>((T) result.real, (T) result.imag);
|
| 63 |
+
return true;
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
static handle
|
| 67 |
+
cast(const std::complex<T> &src, return_value_policy /* policy */, handle /* parent */) {
|
| 68 |
+
return PyComplex_FromDoubles((double) src.real(), (double) src.imag());
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
PYBIND11_TYPE_CASTER(std::complex<T>, const_name("complex"));
|
| 72 |
+
};
|
| 73 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 74 |
+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
|
phivenv/Lib/site-packages/torch/include/pybind11/detail/class.h
ADDED
|
@@ -0,0 +1,767 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
pybind11/detail/class.h: Python C API implementation details for py::class_
|
| 3 |
+
|
| 4 |
+
Copyright (c) 2017 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
| 5 |
+
|
| 6 |
+
All rights reserved. Use of this source code is governed by a
|
| 7 |
+
BSD-style license that can be found in the LICENSE file.
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include <pybind11/attr.h>
|
| 13 |
+
#include <pybind11/options.h>
|
| 14 |
+
|
| 15 |
+
#include "exception_translation.h"
|
| 16 |
+
|
| 17 |
+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
| 18 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 19 |
+
|
| 20 |
+
#if !defined(PYPY_VERSION)
|
| 21 |
+
# define PYBIND11_BUILTIN_QUALNAME
|
| 22 |
+
# define PYBIND11_SET_OLDPY_QUALNAME(obj, nameobj)
|
| 23 |
+
#else
|
| 24 |
+
// In PyPy, we still set __qualname__ so that we can produce reliable function type
|
| 25 |
+
// signatures; in CPython this macro expands to nothing:
|
| 26 |
+
# define PYBIND11_SET_OLDPY_QUALNAME(obj, nameobj) \
|
| 27 |
+
setattr((PyObject *) obj, "__qualname__", nameobj)
|
| 28 |
+
#endif
|
| 29 |
+
|
| 30 |
+
inline std::string get_fully_qualified_tp_name(PyTypeObject *type) {
|
| 31 |
+
#if !defined(PYPY_VERSION)
|
| 32 |
+
return type->tp_name;
|
| 33 |
+
#else
|
| 34 |
+
auto module_name = handle((PyObject *) type).attr("__module__").cast<std::string>();
|
| 35 |
+
if (module_name == PYBIND11_BUILTINS_MODULE)
|
| 36 |
+
return type->tp_name;
|
| 37 |
+
else
|
| 38 |
+
return std::move(module_name) + "." + type->tp_name;
|
| 39 |
+
#endif
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
inline PyTypeObject *type_incref(PyTypeObject *type) {
|
| 43 |
+
Py_INCREF(type);
|
| 44 |
+
return type;
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
#if !defined(PYPY_VERSION)
|
| 48 |
+
|
| 49 |
+
/// `pybind11_static_property.__get__()`: Always pass the class instead of the instance.
|
| 50 |
+
extern "C" inline PyObject *pybind11_static_get(PyObject *self, PyObject * /*ob*/, PyObject *cls) {
|
| 51 |
+
return PyProperty_Type.tp_descr_get(self, cls, cls);
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
/// `pybind11_static_property.__set__()`: Just like the above `__get__()`.
|
| 55 |
+
extern "C" inline int pybind11_static_set(PyObject *self, PyObject *obj, PyObject *value) {
|
| 56 |
+
PyObject *cls = PyType_Check(obj) ? obj : (PyObject *) Py_TYPE(obj);
|
| 57 |
+
return PyProperty_Type.tp_descr_set(self, cls, value);
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
// Forward declaration to use in `make_static_property_type()`
|
| 61 |
+
inline void enable_dynamic_attributes(PyHeapTypeObject *heap_type);
|
| 62 |
+
|
| 63 |
+
/** A `static_property` is the same as a `property` but the `__get__()` and `__set__()`
|
| 64 |
+
methods are modified to always use the object type instead of a concrete instance.
|
| 65 |
+
Return value: New reference. */
|
| 66 |
+
inline PyTypeObject *make_static_property_type() {
|
| 67 |
+
constexpr auto *name = "pybind11_static_property";
|
| 68 |
+
auto name_obj = reinterpret_steal<object>(PYBIND11_FROM_STRING(name));
|
| 69 |
+
|
| 70 |
+
/* Danger zone: from now (and until PyType_Ready), make sure to
|
| 71 |
+
issue no Python C API calls which could potentially invoke the
|
| 72 |
+
garbage collector (the GC will call type_traverse(), which will in
|
| 73 |
+
turn find the newly constructed type in an invalid state) */
|
| 74 |
+
auto *heap_type = (PyHeapTypeObject *) PyType_Type.tp_alloc(&PyType_Type, 0);
|
| 75 |
+
if (!heap_type) {
|
| 76 |
+
pybind11_fail("make_static_property_type(): error allocating type!");
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
heap_type->ht_name = name_obj.inc_ref().ptr();
|
| 80 |
+
# ifdef PYBIND11_BUILTIN_QUALNAME
|
| 81 |
+
heap_type->ht_qualname = name_obj.inc_ref().ptr();
|
| 82 |
+
# endif
|
| 83 |
+
|
| 84 |
+
auto *type = &heap_type->ht_type;
|
| 85 |
+
type->tp_name = name;
|
| 86 |
+
type->tp_base = type_incref(&PyProperty_Type);
|
| 87 |
+
type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
|
| 88 |
+
type->tp_descr_get = pybind11_static_get;
|
| 89 |
+
type->tp_descr_set = pybind11_static_set;
|
| 90 |
+
|
| 91 |
+
# if PY_VERSION_HEX >= 0x030C0000
|
| 92 |
+
// Since Python-3.12 property-derived types are required to
|
| 93 |
+
// have dynamic attributes (to set `__doc__`)
|
| 94 |
+
enable_dynamic_attributes(heap_type);
|
| 95 |
+
# endif
|
| 96 |
+
|
| 97 |
+
if (PyType_Ready(type) < 0) {
|
| 98 |
+
pybind11_fail("make_static_property_type(): failure in PyType_Ready()!");
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
setattr((PyObject *) type, "__module__", str("pybind11_builtins"));
|
| 102 |
+
PYBIND11_SET_OLDPY_QUALNAME(type, name_obj);
|
| 103 |
+
|
| 104 |
+
return type;
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
#else // PYPY
|
| 108 |
+
|
| 109 |
+
/** PyPy has some issues with the above C API, so we evaluate Python code instead.
|
| 110 |
+
This function will only be called once so performance isn't really a concern.
|
| 111 |
+
Return value: New reference. */
|
| 112 |
+
inline PyTypeObject *make_static_property_type() {
|
| 113 |
+
auto d = dict();
|
| 114 |
+
PyObject *result = PyRun_String(R"(\
|
| 115 |
+
class pybind11_static_property(property):
|
| 116 |
+
def __get__(self, obj, cls):
|
| 117 |
+
return property.__get__(self, cls, cls)
|
| 118 |
+
|
| 119 |
+
def __set__(self, obj, value):
|
| 120 |
+
cls = obj if isinstance(obj, type) else type(obj)
|
| 121 |
+
property.__set__(self, cls, value)
|
| 122 |
+
)",
|
| 123 |
+
Py_file_input,
|
| 124 |
+
d.ptr(),
|
| 125 |
+
d.ptr());
|
| 126 |
+
if (result == nullptr)
|
| 127 |
+
throw error_already_set();
|
| 128 |
+
Py_DECREF(result);
|
| 129 |
+
return (PyTypeObject *) d["pybind11_static_property"].cast<object>().release().ptr();
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
#endif // PYPY
|
| 133 |
+
|
| 134 |
+
/** Types with static properties need to handle `Type.static_prop = x` in a specific way.
|
| 135 |
+
By default, Python replaces the `static_property` itself, but for wrapped C++ types
|
| 136 |
+
we need to call `static_property.__set__()` in order to propagate the new value to
|
| 137 |
+
the underlying C++ data structure. */
|
| 138 |
+
extern "C" inline int pybind11_meta_setattro(PyObject *obj, PyObject *name, PyObject *value) {
|
| 139 |
+
// Use `_PyType_Lookup()` instead of `PyObject_GetAttr()` in order to get the raw
|
| 140 |
+
// descriptor (`property`) instead of calling `tp_descr_get` (`property.__get__()`).
|
| 141 |
+
PyObject *descr = _PyType_Lookup((PyTypeObject *) obj, name);
|
| 142 |
+
|
| 143 |
+
// The following assignment combinations are possible:
|
| 144 |
+
// 1. `Type.static_prop = value` --> descr_set: `Type.static_prop.__set__(value)`
|
| 145 |
+
// 2. `Type.static_prop = other_static_prop` --> setattro: replace existing `static_prop`
|
| 146 |
+
// 3. `Type.regular_attribute = value` --> setattro: regular attribute assignment
|
| 147 |
+
auto *const static_prop = (PyObject *) get_internals().static_property_type;
|
| 148 |
+
const auto call_descr_set = (descr != nullptr) && (value != nullptr)
|
| 149 |
+
&& (PyObject_IsInstance(descr, static_prop) != 0)
|
| 150 |
+
&& (PyObject_IsInstance(value, static_prop) == 0);
|
| 151 |
+
if (call_descr_set) {
|
| 152 |
+
// Call `static_property.__set__()` instead of replacing the `static_property`.
|
| 153 |
+
#if !defined(PYPY_VERSION)
|
| 154 |
+
return Py_TYPE(descr)->tp_descr_set(descr, obj, value);
|
| 155 |
+
#else
|
| 156 |
+
if (PyObject *result = PyObject_CallMethod(descr, "__set__", "OO", obj, value)) {
|
| 157 |
+
Py_DECREF(result);
|
| 158 |
+
return 0;
|
| 159 |
+
} else {
|
| 160 |
+
return -1;
|
| 161 |
+
}
|
| 162 |
+
#endif
|
| 163 |
+
} else {
|
| 164 |
+
// Replace existing attribute.
|
| 165 |
+
return PyType_Type.tp_setattro(obj, name, value);
|
| 166 |
+
}
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
/**
|
| 170 |
+
* Python 3's PyInstanceMethod_Type hides itself via its tp_descr_get, which prevents aliasing
|
| 171 |
+
* methods via cls.attr("m2") = cls.attr("m1"): instead the tp_descr_get returns a plain function,
|
| 172 |
+
* when called on a class, or a PyMethod, when called on an instance. Override that behaviour here
|
| 173 |
+
* to do a special case bypass for PyInstanceMethod_Types.
|
| 174 |
+
*/
|
| 175 |
+
extern "C" inline PyObject *pybind11_meta_getattro(PyObject *obj, PyObject *name) {
|
| 176 |
+
PyObject *descr = _PyType_Lookup((PyTypeObject *) obj, name);
|
| 177 |
+
if (descr && PyInstanceMethod_Check(descr)) {
|
| 178 |
+
Py_INCREF(descr);
|
| 179 |
+
return descr;
|
| 180 |
+
}
|
| 181 |
+
return PyType_Type.tp_getattro(obj, name);
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
/// metaclass `__call__` function that is used to create all pybind11 objects.
|
| 185 |
+
extern "C" inline PyObject *pybind11_meta_call(PyObject *type, PyObject *args, PyObject *kwargs) {
|
| 186 |
+
|
| 187 |
+
// use the default metaclass call to create/initialize the object
|
| 188 |
+
PyObject *self = PyType_Type.tp_call(type, args, kwargs);
|
| 189 |
+
if (self == nullptr) {
|
| 190 |
+
return nullptr;
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
// Ensure that the base __init__ function(s) were called
|
| 194 |
+
values_and_holders vhs(self);
|
| 195 |
+
for (const auto &vh : vhs) {
|
| 196 |
+
if (!vh.holder_constructed() && !vhs.is_redundant_value_and_holder(vh)) {
|
| 197 |
+
PyErr_Format(PyExc_TypeError,
|
| 198 |
+
"%.200s.__init__() must be called when overriding __init__",
|
| 199 |
+
get_fully_qualified_tp_name(vh.type->type).c_str());
|
| 200 |
+
Py_DECREF(self);
|
| 201 |
+
return nullptr;
|
| 202 |
+
}
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
return self;
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
/// Cleanup the type-info for a pybind11-registered type.
|
| 209 |
+
extern "C" inline void pybind11_meta_dealloc(PyObject *obj) {
|
| 210 |
+
with_internals([obj](internals &internals) {
|
| 211 |
+
auto *type = (PyTypeObject *) obj;
|
| 212 |
+
|
| 213 |
+
// A pybind11-registered type will:
|
| 214 |
+
// 1) be found in internals.registered_types_py
|
| 215 |
+
// 2) have exactly one associated `detail::type_info`
|
| 216 |
+
auto found_type = internals.registered_types_py.find(type);
|
| 217 |
+
if (found_type != internals.registered_types_py.end() && found_type->second.size() == 1
|
| 218 |
+
&& found_type->second[0]->type == type) {
|
| 219 |
+
|
| 220 |
+
auto *tinfo = found_type->second[0];
|
| 221 |
+
auto tindex = std::type_index(*tinfo->cpptype);
|
| 222 |
+
internals.direct_conversions.erase(tindex);
|
| 223 |
+
|
| 224 |
+
if (tinfo->module_local) {
|
| 225 |
+
get_local_internals().registered_types_cpp.erase(tindex);
|
| 226 |
+
} else {
|
| 227 |
+
internals.registered_types_cpp.erase(tindex);
|
| 228 |
+
}
|
| 229 |
+
internals.registered_types_py.erase(tinfo->type);
|
| 230 |
+
|
| 231 |
+
// Actually just `std::erase_if`, but that's only available in C++20
|
| 232 |
+
auto &cache = internals.inactive_override_cache;
|
| 233 |
+
for (auto it = cache.begin(), last = cache.end(); it != last;) {
|
| 234 |
+
if (it->first == (PyObject *) tinfo->type) {
|
| 235 |
+
it = cache.erase(it);
|
| 236 |
+
} else {
|
| 237 |
+
++it;
|
| 238 |
+
}
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
delete tinfo;
|
| 242 |
+
}
|
| 243 |
+
});
|
| 244 |
+
|
| 245 |
+
PyType_Type.tp_dealloc(obj);
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
/** This metaclass is assigned by default to all pybind11 types and is required in order
|
| 249 |
+
for static properties to function correctly. Users may override this using `py::metaclass`.
|
| 250 |
+
Return value: New reference. */
|
| 251 |
+
inline PyTypeObject *make_default_metaclass() {
|
| 252 |
+
constexpr auto *name = "pybind11_type";
|
| 253 |
+
auto name_obj = reinterpret_steal<object>(PYBIND11_FROM_STRING(name));
|
| 254 |
+
|
| 255 |
+
/* Danger zone: from now (and until PyType_Ready), make sure to
|
| 256 |
+
issue no Python C API calls which could potentially invoke the
|
| 257 |
+
garbage collector (the GC will call type_traverse(), which will in
|
| 258 |
+
turn find the newly constructed type in an invalid state) */
|
| 259 |
+
auto *heap_type = (PyHeapTypeObject *) PyType_Type.tp_alloc(&PyType_Type, 0);
|
| 260 |
+
if (!heap_type) {
|
| 261 |
+
pybind11_fail("make_default_metaclass(): error allocating metaclass!");
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
heap_type->ht_name = name_obj.inc_ref().ptr();
|
| 265 |
+
#ifdef PYBIND11_BUILTIN_QUALNAME
|
| 266 |
+
heap_type->ht_qualname = name_obj.inc_ref().ptr();
|
| 267 |
+
#endif
|
| 268 |
+
|
| 269 |
+
auto *type = &heap_type->ht_type;
|
| 270 |
+
type->tp_name = name;
|
| 271 |
+
type->tp_base = type_incref(&PyType_Type);
|
| 272 |
+
type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
|
| 273 |
+
|
| 274 |
+
type->tp_call = pybind11_meta_call;
|
| 275 |
+
|
| 276 |
+
type->tp_setattro = pybind11_meta_setattro;
|
| 277 |
+
type->tp_getattro = pybind11_meta_getattro;
|
| 278 |
+
|
| 279 |
+
type->tp_dealloc = pybind11_meta_dealloc;
|
| 280 |
+
|
| 281 |
+
if (PyType_Ready(type) < 0) {
|
| 282 |
+
pybind11_fail("make_default_metaclass(): failure in PyType_Ready()!");
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
setattr((PyObject *) type, "__module__", str("pybind11_builtins"));
|
| 286 |
+
PYBIND11_SET_OLDPY_QUALNAME(type, name_obj);
|
| 287 |
+
|
| 288 |
+
return type;
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
/// For multiple inheritance types we need to recursively register/deregister base pointers for any
|
| 292 |
+
/// base classes with pointers that are difference from the instance value pointer so that we can
|
| 293 |
+
/// correctly recognize an offset base class pointer. This calls a function with any offset base
|
| 294 |
+
/// ptrs.
|
| 295 |
+
inline void traverse_offset_bases(void *valueptr,
|
| 296 |
+
const detail::type_info *tinfo,
|
| 297 |
+
instance *self,
|
| 298 |
+
bool (*f)(void * /*parentptr*/, instance * /*self*/)) {
|
| 299 |
+
for (handle h : reinterpret_borrow<tuple>(tinfo->type->tp_bases)) {
|
| 300 |
+
if (auto *parent_tinfo = get_type_info((PyTypeObject *) h.ptr())) {
|
| 301 |
+
for (auto &c : parent_tinfo->implicit_casts) {
|
| 302 |
+
if (c.first == tinfo->cpptype) {
|
| 303 |
+
auto *parentptr = c.second(valueptr);
|
| 304 |
+
if (parentptr != valueptr) {
|
| 305 |
+
f(parentptr, self);
|
| 306 |
+
}
|
| 307 |
+
traverse_offset_bases(parentptr, parent_tinfo, self, f);
|
| 308 |
+
break;
|
| 309 |
+
}
|
| 310 |
+
}
|
| 311 |
+
}
|
| 312 |
+
}
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
inline bool register_instance_impl(void *ptr, instance *self) {
|
| 316 |
+
with_instance_map(ptr, [&](instance_map &instances) { instances.emplace(ptr, self); });
|
| 317 |
+
return true; // unused, but gives the same signature as the deregister func
|
| 318 |
+
}
|
| 319 |
+
inline bool deregister_instance_impl(void *ptr, instance *self) {
|
| 320 |
+
return with_instance_map(ptr, [&](instance_map &instances) {
|
| 321 |
+
auto range = instances.equal_range(ptr);
|
| 322 |
+
for (auto it = range.first; it != range.second; ++it) {
|
| 323 |
+
if (self == it->second) {
|
| 324 |
+
instances.erase(it);
|
| 325 |
+
return true;
|
| 326 |
+
}
|
| 327 |
+
}
|
| 328 |
+
return false;
|
| 329 |
+
});
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
inline void register_instance(instance *self, void *valptr, const type_info *tinfo) {
|
| 333 |
+
register_instance_impl(valptr, self);
|
| 334 |
+
if (!tinfo->simple_ancestors) {
|
| 335 |
+
traverse_offset_bases(valptr, tinfo, self, register_instance_impl);
|
| 336 |
+
}
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
inline bool deregister_instance(instance *self, void *valptr, const type_info *tinfo) {
|
| 340 |
+
bool ret = deregister_instance_impl(valptr, self);
|
| 341 |
+
if (!tinfo->simple_ancestors) {
|
| 342 |
+
traverse_offset_bases(valptr, tinfo, self, deregister_instance_impl);
|
| 343 |
+
}
|
| 344 |
+
return ret;
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
/// Instance creation function for all pybind11 types. It allocates the internal instance layout
|
| 348 |
+
/// for holding C++ objects and holders. Allocation is done lazily (the first time the instance is
|
| 349 |
+
/// cast to a reference or pointer), and initialization is done by an `__init__` function.
|
| 350 |
+
inline PyObject *make_new_instance(PyTypeObject *type) {
|
| 351 |
+
#if defined(PYPY_VERSION)
|
| 352 |
+
// PyPy gets tp_basicsize wrong (issue 2482) under multiple inheritance when the first
|
| 353 |
+
// inherited object is a plain Python type (i.e. not derived from an extension type). Fix it.
|
| 354 |
+
ssize_t instance_size = static_cast<ssize_t>(sizeof(instance));
|
| 355 |
+
if (type->tp_basicsize < instance_size) {
|
| 356 |
+
type->tp_basicsize = instance_size;
|
| 357 |
+
}
|
| 358 |
+
#endif
|
| 359 |
+
PyObject *self = type->tp_alloc(type, 0);
|
| 360 |
+
auto *inst = reinterpret_cast<instance *>(self);
|
| 361 |
+
// Allocate the value/holder internals:
|
| 362 |
+
inst->allocate_layout();
|
| 363 |
+
|
| 364 |
+
return self;
|
| 365 |
+
}
|
| 366 |
+
|
| 367 |
+
/// Instance creation function for all pybind11 types. It only allocates space for the
|
| 368 |
+
/// C++ object, but doesn't call the constructor -- an `__init__` function must do that.
|
| 369 |
+
extern "C" inline PyObject *pybind11_object_new(PyTypeObject *type, PyObject *, PyObject *) {
|
| 370 |
+
return make_new_instance(type);
|
| 371 |
+
}
|
| 372 |
+
|
| 373 |
+
/// An `__init__` function constructs the C++ object. Users should provide at least one
|
| 374 |
+
/// of these using `py::init` or directly with `.def(__init__, ...)`. Otherwise, the
|
| 375 |
+
/// following default function will be used which simply throws an exception.
|
| 376 |
+
extern "C" inline int pybind11_object_init(PyObject *self, PyObject *, PyObject *) {
|
| 377 |
+
PyTypeObject *type = Py_TYPE(self);
|
| 378 |
+
std::string msg = get_fully_qualified_tp_name(type) + ": No constructor defined!";
|
| 379 |
+
set_error(PyExc_TypeError, msg.c_str());
|
| 380 |
+
return -1;
|
| 381 |
+
}
|
| 382 |
+
|
| 383 |
+
inline void add_patient(PyObject *nurse, PyObject *patient) {
|
| 384 |
+
auto *instance = reinterpret_cast<detail::instance *>(nurse);
|
| 385 |
+
instance->has_patients = true;
|
| 386 |
+
Py_INCREF(patient);
|
| 387 |
+
|
| 388 |
+
with_internals([&](internals &internals) { internals.patients[nurse].push_back(patient); });
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
inline void clear_patients(PyObject *self) {
|
| 392 |
+
auto *instance = reinterpret_cast<detail::instance *>(self);
|
| 393 |
+
std::vector<PyObject *> patients;
|
| 394 |
+
|
| 395 |
+
with_internals([&](internals &internals) {
|
| 396 |
+
auto pos = internals.patients.find(self);
|
| 397 |
+
|
| 398 |
+
if (pos == internals.patients.end()) {
|
| 399 |
+
pybind11_fail(
|
| 400 |
+
"FATAL: Internal consistency check failed: Invalid clear_patients() call.");
|
| 401 |
+
}
|
| 402 |
+
|
| 403 |
+
// Clearing the patients can cause more Python code to run, which
|
| 404 |
+
// can invalidate the iterator. Extract the vector of patients
|
| 405 |
+
// from the unordered_map first.
|
| 406 |
+
patients = std::move(pos->second);
|
| 407 |
+
internals.patients.erase(pos);
|
| 408 |
+
});
|
| 409 |
+
|
| 410 |
+
instance->has_patients = false;
|
| 411 |
+
for (PyObject *&patient : patients) {
|
| 412 |
+
Py_CLEAR(patient);
|
| 413 |
+
}
|
| 414 |
+
}
|
| 415 |
+
|
| 416 |
+
/// Clears all internal data from the instance and removes it from registered instances in
|
| 417 |
+
/// preparation for deallocation.
|
| 418 |
+
inline void clear_instance(PyObject *self) {
|
| 419 |
+
auto *instance = reinterpret_cast<detail::instance *>(self);
|
| 420 |
+
|
| 421 |
+
// Deallocate any values/holders, if present:
|
| 422 |
+
for (auto &v_h : values_and_holders(instance)) {
|
| 423 |
+
if (v_h) {
|
| 424 |
+
|
| 425 |
+
// We have to deregister before we call dealloc because, for virtual MI types, we still
|
| 426 |
+
// need to be able to get the parent pointers.
|
| 427 |
+
if (v_h.instance_registered()
|
| 428 |
+
&& !deregister_instance(instance, v_h.value_ptr(), v_h.type)) {
|
| 429 |
+
pybind11_fail(
|
| 430 |
+
"pybind11_object_dealloc(): Tried to deallocate unregistered instance!");
|
| 431 |
+
}
|
| 432 |
+
|
| 433 |
+
if (instance->owned || v_h.holder_constructed()) {
|
| 434 |
+
v_h.type->dealloc(v_h);
|
| 435 |
+
}
|
| 436 |
+
}
|
| 437 |
+
}
|
| 438 |
+
// Deallocate the value/holder layout internals:
|
| 439 |
+
instance->deallocate_layout();
|
| 440 |
+
|
| 441 |
+
if (instance->weakrefs) {
|
| 442 |
+
PyObject_ClearWeakRefs(self);
|
| 443 |
+
}
|
| 444 |
+
|
| 445 |
+
PyObject **dict_ptr = _PyObject_GetDictPtr(self);
|
| 446 |
+
if (dict_ptr) {
|
| 447 |
+
Py_CLEAR(*dict_ptr);
|
| 448 |
+
}
|
| 449 |
+
|
| 450 |
+
if (instance->has_patients) {
|
| 451 |
+
clear_patients(self);
|
| 452 |
+
}
|
| 453 |
+
}
|
| 454 |
+
|
| 455 |
+
/// Instance destructor function for all pybind11 types. It calls `type_info.dealloc`
|
| 456 |
+
/// to destroy the C++ object itself, while the rest is Python bookkeeping.
|
| 457 |
+
extern "C" inline void pybind11_object_dealloc(PyObject *self) {
|
| 458 |
+
auto *type = Py_TYPE(self);
|
| 459 |
+
|
| 460 |
+
// If this is a GC tracked object, untrack it first
|
| 461 |
+
// Note that the track call is implicitly done by the
|
| 462 |
+
// default tp_alloc, which we never override.
|
| 463 |
+
if (PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC) != 0) {
|
| 464 |
+
PyObject_GC_UnTrack(self);
|
| 465 |
+
}
|
| 466 |
+
|
| 467 |
+
clear_instance(self);
|
| 468 |
+
|
| 469 |
+
type->tp_free(self);
|
| 470 |
+
|
| 471 |
+
#if PY_VERSION_HEX < 0x03080000
|
| 472 |
+
// `type->tp_dealloc != pybind11_object_dealloc` means that we're being called
|
| 473 |
+
// as part of a derived type's dealloc, in which case we're not allowed to decref
|
| 474 |
+
// the type here. For cross-module compatibility, we shouldn't compare directly
|
| 475 |
+
// with `pybind11_object_dealloc`, but with the common one stashed in internals.
|
| 476 |
+
auto pybind11_object_type = (PyTypeObject *) get_internals().instance_base;
|
| 477 |
+
if (type->tp_dealloc == pybind11_object_type->tp_dealloc)
|
| 478 |
+
Py_DECREF(type);
|
| 479 |
+
#else
|
| 480 |
+
// This was not needed before Python 3.8 (Python issue 35810)
|
| 481 |
+
// https://github.com/pybind/pybind11/issues/1946
|
| 482 |
+
Py_DECREF(type);
|
| 483 |
+
#endif
|
| 484 |
+
}
|
| 485 |
+
|
| 486 |
+
std::string error_string();
|
| 487 |
+
|
| 488 |
+
/** Create the type which can be used as a common base for all classes. This is
|
| 489 |
+
needed in order to satisfy Python's requirements for multiple inheritance.
|
| 490 |
+
Return value: New reference. */
|
| 491 |
+
inline PyObject *make_object_base_type(PyTypeObject *metaclass) {
|
| 492 |
+
constexpr auto *name = "pybind11_object";
|
| 493 |
+
auto name_obj = reinterpret_steal<object>(PYBIND11_FROM_STRING(name));
|
| 494 |
+
|
| 495 |
+
/* Danger zone: from now (and until PyType_Ready), make sure to
|
| 496 |
+
issue no Python C API calls which could potentially invoke the
|
| 497 |
+
garbage collector (the GC will call type_traverse(), which will in
|
| 498 |
+
turn find the newly constructed type in an invalid state) */
|
| 499 |
+
auto *heap_type = (PyHeapTypeObject *) metaclass->tp_alloc(metaclass, 0);
|
| 500 |
+
if (!heap_type) {
|
| 501 |
+
pybind11_fail("make_object_base_type(): error allocating type!");
|
| 502 |
+
}
|
| 503 |
+
|
| 504 |
+
heap_type->ht_name = name_obj.inc_ref().ptr();
|
| 505 |
+
#ifdef PYBIND11_BUILTIN_QUALNAME
|
| 506 |
+
heap_type->ht_qualname = name_obj.inc_ref().ptr();
|
| 507 |
+
#endif
|
| 508 |
+
|
| 509 |
+
auto *type = &heap_type->ht_type;
|
| 510 |
+
type->tp_name = name;
|
| 511 |
+
type->tp_base = type_incref(&PyBaseObject_Type);
|
| 512 |
+
type->tp_basicsize = static_cast<ssize_t>(sizeof(instance));
|
| 513 |
+
type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
|
| 514 |
+
|
| 515 |
+
type->tp_new = pybind11_object_new;
|
| 516 |
+
type->tp_init = pybind11_object_init;
|
| 517 |
+
type->tp_dealloc = pybind11_object_dealloc;
|
| 518 |
+
|
| 519 |
+
/* Support weak references (needed for the keep_alive feature) */
|
| 520 |
+
type->tp_weaklistoffset = offsetof(instance, weakrefs);
|
| 521 |
+
|
| 522 |
+
if (PyType_Ready(type) < 0) {
|
| 523 |
+
pybind11_fail("PyType_Ready failed in make_object_base_type(): " + error_string());
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
+
setattr((PyObject *) type, "__module__", str("pybind11_builtins"));
|
| 527 |
+
PYBIND11_SET_OLDPY_QUALNAME(type, name_obj);
|
| 528 |
+
|
| 529 |
+
assert(!PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC));
|
| 530 |
+
return (PyObject *) heap_type;
|
| 531 |
+
}
|
| 532 |
+
|
| 533 |
+
/// dynamic_attr: Allow the garbage collector to traverse the internal instance `__dict__`.
|
| 534 |
+
extern "C" inline int pybind11_traverse(PyObject *self, visitproc visit, void *arg) {
|
| 535 |
+
#if PY_VERSION_HEX >= 0x030D0000
|
| 536 |
+
PyObject_VisitManagedDict(self, visit, arg);
|
| 537 |
+
#else
|
| 538 |
+
PyObject *&dict = *_PyObject_GetDictPtr(self);
|
| 539 |
+
Py_VISIT(dict);
|
| 540 |
+
#endif
|
| 541 |
+
// https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_traverse
|
| 542 |
+
#if PY_VERSION_HEX >= 0x03090000
|
| 543 |
+
Py_VISIT(Py_TYPE(self));
|
| 544 |
+
#endif
|
| 545 |
+
return 0;
|
| 546 |
+
}
|
| 547 |
+
|
| 548 |
+
/// dynamic_attr: Allow the GC to clear the dictionary.
|
| 549 |
+
extern "C" inline int pybind11_clear(PyObject *self) {
|
| 550 |
+
#if PY_VERSION_HEX >= 0x030D0000
|
| 551 |
+
PyObject_ClearManagedDict(self);
|
| 552 |
+
#else
|
| 553 |
+
PyObject *&dict = *_PyObject_GetDictPtr(self);
|
| 554 |
+
Py_CLEAR(dict);
|
| 555 |
+
#endif
|
| 556 |
+
return 0;
|
| 557 |
+
}
|
| 558 |
+
|
| 559 |
+
/// Give instances of this type a `__dict__` and opt into garbage collection.
|
| 560 |
+
inline void enable_dynamic_attributes(PyHeapTypeObject *heap_type) {
|
| 561 |
+
auto *type = &heap_type->ht_type;
|
| 562 |
+
type->tp_flags |= Py_TPFLAGS_HAVE_GC;
|
| 563 |
+
#if PY_VERSION_HEX < 0x030B0000
|
| 564 |
+
type->tp_dictoffset = type->tp_basicsize; // place dict at the end
|
| 565 |
+
type->tp_basicsize += (ssize_t) sizeof(PyObject *); // and allocate enough space for it
|
| 566 |
+
#else
|
| 567 |
+
type->tp_flags |= Py_TPFLAGS_MANAGED_DICT;
|
| 568 |
+
#endif
|
| 569 |
+
type->tp_traverse = pybind11_traverse;
|
| 570 |
+
type->tp_clear = pybind11_clear;
|
| 571 |
+
|
| 572 |
+
static PyGetSetDef getset[]
|
| 573 |
+
= {{"__dict__", PyObject_GenericGetDict, PyObject_GenericSetDict, nullptr, nullptr},
|
| 574 |
+
{nullptr, nullptr, nullptr, nullptr, nullptr}};
|
| 575 |
+
type->tp_getset = getset;
|
| 576 |
+
}
|
| 577 |
+
|
| 578 |
+
/// buffer_protocol: Fill in the view as specified by flags.
|
| 579 |
+
extern "C" inline int pybind11_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
|
| 580 |
+
// Look for a `get_buffer` implementation in this type's info or any bases (following MRO).
|
| 581 |
+
type_info *tinfo = nullptr;
|
| 582 |
+
for (auto type : reinterpret_borrow<tuple>(Py_TYPE(obj)->tp_mro)) {
|
| 583 |
+
tinfo = get_type_info((PyTypeObject *) type.ptr());
|
| 584 |
+
if (tinfo && tinfo->get_buffer) {
|
| 585 |
+
break;
|
| 586 |
+
}
|
| 587 |
+
}
|
| 588 |
+
if (view == nullptr || !tinfo || !tinfo->get_buffer) {
|
| 589 |
+
if (view) {
|
| 590 |
+
view->obj = nullptr;
|
| 591 |
+
}
|
| 592 |
+
set_error(PyExc_BufferError, "pybind11_getbuffer(): Internal error");
|
| 593 |
+
return -1;
|
| 594 |
+
}
|
| 595 |
+
std::memset(view, 0, sizeof(Py_buffer));
|
| 596 |
+
buffer_info *info = nullptr;
|
| 597 |
+
try {
|
| 598 |
+
info = tinfo->get_buffer(obj, tinfo->get_buffer_data);
|
| 599 |
+
} catch (...) {
|
| 600 |
+
try_translate_exceptions();
|
| 601 |
+
raise_from(PyExc_BufferError, "Error getting buffer");
|
| 602 |
+
return -1;
|
| 603 |
+
}
|
| 604 |
+
if (info == nullptr) {
|
| 605 |
+
pybind11_fail("FATAL UNEXPECTED SITUATION: tinfo->get_buffer() returned nullptr.");
|
| 606 |
+
}
|
| 607 |
+
|
| 608 |
+
if ((flags & PyBUF_WRITABLE) == PyBUF_WRITABLE && info->readonly) {
|
| 609 |
+
delete info;
|
| 610 |
+
// view->obj = nullptr; // Was just memset to 0, so not necessary
|
| 611 |
+
set_error(PyExc_BufferError, "Writable buffer requested for readonly storage");
|
| 612 |
+
return -1;
|
| 613 |
+
}
|
| 614 |
+
view->obj = obj;
|
| 615 |
+
view->ndim = 1;
|
| 616 |
+
view->internal = info;
|
| 617 |
+
view->buf = info->ptr;
|
| 618 |
+
view->itemsize = info->itemsize;
|
| 619 |
+
view->len = view->itemsize;
|
| 620 |
+
for (auto s : info->shape) {
|
| 621 |
+
view->len *= s;
|
| 622 |
+
}
|
| 623 |
+
view->readonly = static_cast<int>(info->readonly);
|
| 624 |
+
if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
|
| 625 |
+
view->format = const_cast<char *>(info->format.c_str());
|
| 626 |
+
}
|
| 627 |
+
if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
|
| 628 |
+
view->ndim = (int) info->ndim;
|
| 629 |
+
view->strides = info->strides.data();
|
| 630 |
+
view->shape = info->shape.data();
|
| 631 |
+
}
|
| 632 |
+
Py_INCREF(view->obj);
|
| 633 |
+
return 0;
|
| 634 |
+
}
|
| 635 |
+
|
| 636 |
+
/// buffer_protocol: Release the resources of the buffer.
|
| 637 |
+
extern "C" inline void pybind11_releasebuffer(PyObject *, Py_buffer *view) {
|
| 638 |
+
delete (buffer_info *) view->internal;
|
| 639 |
+
}
|
| 640 |
+
|
| 641 |
+
/// Give this type a buffer interface.
|
| 642 |
+
inline void enable_buffer_protocol(PyHeapTypeObject *heap_type) {
|
| 643 |
+
heap_type->ht_type.tp_as_buffer = &heap_type->as_buffer;
|
| 644 |
+
|
| 645 |
+
heap_type->as_buffer.bf_getbuffer = pybind11_getbuffer;
|
| 646 |
+
heap_type->as_buffer.bf_releasebuffer = pybind11_releasebuffer;
|
| 647 |
+
}
|
| 648 |
+
|
| 649 |
+
/** Create a brand new Python type according to the `type_record` specification.
|
| 650 |
+
Return value: New reference. */
|
| 651 |
+
inline PyObject *make_new_python_type(const type_record &rec) {
|
| 652 |
+
auto name = reinterpret_steal<object>(PYBIND11_FROM_STRING(rec.name));
|
| 653 |
+
|
| 654 |
+
auto qualname = name;
|
| 655 |
+
if (rec.scope && !PyModule_Check(rec.scope.ptr()) && hasattr(rec.scope, "__qualname__")) {
|
| 656 |
+
qualname = reinterpret_steal<object>(
|
| 657 |
+
PyUnicode_FromFormat("%U.%U", rec.scope.attr("__qualname__").ptr(), name.ptr()));
|
| 658 |
+
}
|
| 659 |
+
|
| 660 |
+
object module_;
|
| 661 |
+
if (rec.scope) {
|
| 662 |
+
if (hasattr(rec.scope, "__module__")) {
|
| 663 |
+
module_ = rec.scope.attr("__module__");
|
| 664 |
+
} else if (hasattr(rec.scope, "__name__")) {
|
| 665 |
+
module_ = rec.scope.attr("__name__");
|
| 666 |
+
}
|
| 667 |
+
}
|
| 668 |
+
|
| 669 |
+
const auto *full_name = c_str(
|
| 670 |
+
#if !defined(PYPY_VERSION)
|
| 671 |
+
module_ ? str(module_).cast<std::string>() + "." + rec.name :
|
| 672 |
+
#endif
|
| 673 |
+
rec.name);
|
| 674 |
+
|
| 675 |
+
char *tp_doc = nullptr;
|
| 676 |
+
if (rec.doc && options::show_user_defined_docstrings()) {
|
| 677 |
+
/* Allocate memory for docstring (Python will free this later on) */
|
| 678 |
+
size_t size = std::strlen(rec.doc) + 1;
|
| 679 |
+
#if PY_VERSION_HEX >= 0x030D0000
|
| 680 |
+
tp_doc = (char *) PyMem_MALLOC(size);
|
| 681 |
+
#else
|
| 682 |
+
tp_doc = (char *) PyObject_MALLOC(size);
|
| 683 |
+
#endif
|
| 684 |
+
std::memcpy((void *) tp_doc, rec.doc, size);
|
| 685 |
+
}
|
| 686 |
+
|
| 687 |
+
auto &internals = get_internals();
|
| 688 |
+
auto bases = tuple(rec.bases);
|
| 689 |
+
auto *base = (bases.empty()) ? internals.instance_base : bases[0].ptr();
|
| 690 |
+
|
| 691 |
+
/* Danger zone: from now (and until PyType_Ready), make sure to
|
| 692 |
+
issue no Python C API calls which could potentially invoke the
|
| 693 |
+
garbage collector (the GC will call type_traverse(), which will in
|
| 694 |
+
turn find the newly constructed type in an invalid state) */
|
| 695 |
+
auto *metaclass
|
| 696 |
+
= rec.metaclass.ptr() ? (PyTypeObject *) rec.metaclass.ptr() : internals.default_metaclass;
|
| 697 |
+
|
| 698 |
+
auto *heap_type = (PyHeapTypeObject *) metaclass->tp_alloc(metaclass, 0);
|
| 699 |
+
if (!heap_type) {
|
| 700 |
+
pybind11_fail(std::string(rec.name) + ": Unable to create type object!");
|
| 701 |
+
}
|
| 702 |
+
|
| 703 |
+
heap_type->ht_name = name.release().ptr();
|
| 704 |
+
#ifdef PYBIND11_BUILTIN_QUALNAME
|
| 705 |
+
heap_type->ht_qualname = qualname.inc_ref().ptr();
|
| 706 |
+
#endif
|
| 707 |
+
|
| 708 |
+
auto *type = &heap_type->ht_type;
|
| 709 |
+
type->tp_name = full_name;
|
| 710 |
+
type->tp_doc = tp_doc;
|
| 711 |
+
type->tp_base = type_incref((PyTypeObject *) base);
|
| 712 |
+
type->tp_basicsize = static_cast<ssize_t>(sizeof(instance));
|
| 713 |
+
if (!bases.empty()) {
|
| 714 |
+
type->tp_bases = bases.release().ptr();
|
| 715 |
+
}
|
| 716 |
+
|
| 717 |
+
/* Don't inherit base __init__ */
|
| 718 |
+
type->tp_init = pybind11_object_init;
|
| 719 |
+
|
| 720 |
+
/* Supported protocols */
|
| 721 |
+
type->tp_as_number = &heap_type->as_number;
|
| 722 |
+
type->tp_as_sequence = &heap_type->as_sequence;
|
| 723 |
+
type->tp_as_mapping = &heap_type->as_mapping;
|
| 724 |
+
type->tp_as_async = &heap_type->as_async;
|
| 725 |
+
|
| 726 |
+
/* Flags */
|
| 727 |
+
type->tp_flags |= Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE;
|
| 728 |
+
if (!rec.is_final) {
|
| 729 |
+
type->tp_flags |= Py_TPFLAGS_BASETYPE;
|
| 730 |
+
}
|
| 731 |
+
|
| 732 |
+
if (rec.dynamic_attr) {
|
| 733 |
+
enable_dynamic_attributes(heap_type);
|
| 734 |
+
}
|
| 735 |
+
|
| 736 |
+
if (rec.buffer_protocol) {
|
| 737 |
+
enable_buffer_protocol(heap_type);
|
| 738 |
+
}
|
| 739 |
+
|
| 740 |
+
if (rec.custom_type_setup_callback) {
|
| 741 |
+
rec.custom_type_setup_callback(heap_type);
|
| 742 |
+
}
|
| 743 |
+
|
| 744 |
+
if (PyType_Ready(type) < 0) {
|
| 745 |
+
pybind11_fail(std::string(rec.name) + ": PyType_Ready failed: " + error_string());
|
| 746 |
+
}
|
| 747 |
+
|
| 748 |
+
assert(!rec.dynamic_attr || PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC));
|
| 749 |
+
|
| 750 |
+
/* Register type with the parent scope */
|
| 751 |
+
if (rec.scope) {
|
| 752 |
+
setattr(rec.scope, rec.name, (PyObject *) type);
|
| 753 |
+
} else {
|
| 754 |
+
Py_INCREF(type); // Keep it alive forever (reference leak)
|
| 755 |
+
}
|
| 756 |
+
|
| 757 |
+
if (module_) { // Needed by pydoc
|
| 758 |
+
setattr((PyObject *) type, "__module__", module_);
|
| 759 |
+
}
|
| 760 |
+
|
| 761 |
+
PYBIND11_SET_OLDPY_QUALNAME(type, qualname);
|
| 762 |
+
|
| 763 |
+
return (PyObject *) type;
|
| 764 |
+
}
|
| 765 |
+
|
| 766 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 767 |
+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
|
phivenv/Lib/site-packages/torch/include/pybind11/detail/common.h
ADDED
|
@@ -0,0 +1,1287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
pybind11/detail/common.h -- Basic macros
|
| 3 |
+
|
| 4 |
+
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
| 5 |
+
|
| 6 |
+
All rights reserved. Use of this source code is governed by a
|
| 7 |
+
BSD-style license that can be found in the LICENSE file.
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#define PYBIND11_VERSION_MAJOR 2
|
| 13 |
+
#define PYBIND11_VERSION_MINOR 13
|
| 14 |
+
#define PYBIND11_VERSION_PATCH 6
|
| 15 |
+
|
| 16 |
+
// Similar to Python's convention: https://docs.python.org/3/c-api/apiabiversion.html
|
| 17 |
+
// Additional convention: 0xD = dev
|
| 18 |
+
#define PYBIND11_VERSION_HEX 0x020D0600
|
| 19 |
+
|
| 20 |
+
// Define some generic pybind11 helper macros for warning management.
|
| 21 |
+
//
|
| 22 |
+
// Note that compiler-specific push/pop pairs are baked into the
|
| 23 |
+
// PYBIND11_NAMESPACE_BEGIN/PYBIND11_NAMESPACE_END pair of macros. Therefore manual
|
| 24 |
+
// PYBIND11_WARNING_PUSH/PYBIND11_WARNING_POP are usually only needed in `#include` sections.
|
| 25 |
+
//
|
| 26 |
+
// If you find you need to suppress a warning, please try to make the suppression as local as
|
| 27 |
+
// possible using these macros. Please also be sure to push/pop with the pybind11 macros. Please
|
| 28 |
+
// only use compiler specifics if you need to check specific versions, e.g. Apple Clang vs. vanilla
|
| 29 |
+
// Clang.
|
| 30 |
+
#if defined(_MSC_VER)
|
| 31 |
+
# define PYBIND11_COMPILER_MSVC
|
| 32 |
+
# define PYBIND11_PRAGMA(...) __pragma(__VA_ARGS__)
|
| 33 |
+
# define PYBIND11_WARNING_PUSH PYBIND11_PRAGMA(warning(push))
|
| 34 |
+
# define PYBIND11_WARNING_POP PYBIND11_PRAGMA(warning(pop))
|
| 35 |
+
#elif defined(__INTEL_COMPILER)
|
| 36 |
+
# define PYBIND11_COMPILER_INTEL
|
| 37 |
+
# define PYBIND11_PRAGMA(...) _Pragma(#__VA_ARGS__)
|
| 38 |
+
# define PYBIND11_WARNING_PUSH PYBIND11_PRAGMA(warning push)
|
| 39 |
+
# define PYBIND11_WARNING_POP PYBIND11_PRAGMA(warning pop)
|
| 40 |
+
#elif defined(__clang__)
|
| 41 |
+
# define PYBIND11_COMPILER_CLANG
|
| 42 |
+
# define PYBIND11_PRAGMA(...) _Pragma(#__VA_ARGS__)
|
| 43 |
+
# define PYBIND11_WARNING_PUSH PYBIND11_PRAGMA(clang diagnostic push)
|
| 44 |
+
# define PYBIND11_WARNING_POP PYBIND11_PRAGMA(clang diagnostic push)
|
| 45 |
+
#elif defined(__GNUC__)
|
| 46 |
+
# define PYBIND11_COMPILER_GCC
|
| 47 |
+
# define PYBIND11_PRAGMA(...) _Pragma(#__VA_ARGS__)
|
| 48 |
+
# define PYBIND11_WARNING_PUSH PYBIND11_PRAGMA(GCC diagnostic push)
|
| 49 |
+
# define PYBIND11_WARNING_POP PYBIND11_PRAGMA(GCC diagnostic pop)
|
| 50 |
+
#endif
|
| 51 |
+
|
| 52 |
+
#ifdef PYBIND11_COMPILER_MSVC
|
| 53 |
+
# define PYBIND11_WARNING_DISABLE_MSVC(name) PYBIND11_PRAGMA(warning(disable : name))
|
| 54 |
+
#else
|
| 55 |
+
# define PYBIND11_WARNING_DISABLE_MSVC(name)
|
| 56 |
+
#endif
|
| 57 |
+
|
| 58 |
+
#ifdef PYBIND11_COMPILER_CLANG
|
| 59 |
+
# define PYBIND11_WARNING_DISABLE_CLANG(name) PYBIND11_PRAGMA(clang diagnostic ignored name)
|
| 60 |
+
#else
|
| 61 |
+
# define PYBIND11_WARNING_DISABLE_CLANG(name)
|
| 62 |
+
#endif
|
| 63 |
+
|
| 64 |
+
#ifdef PYBIND11_COMPILER_GCC
|
| 65 |
+
# define PYBIND11_WARNING_DISABLE_GCC(name) PYBIND11_PRAGMA(GCC diagnostic ignored name)
|
| 66 |
+
#else
|
| 67 |
+
# define PYBIND11_WARNING_DISABLE_GCC(name)
|
| 68 |
+
#endif
|
| 69 |
+
|
| 70 |
+
#ifdef PYBIND11_COMPILER_INTEL
|
| 71 |
+
# define PYBIND11_WARNING_DISABLE_INTEL(name) PYBIND11_PRAGMA(warning disable name)
|
| 72 |
+
#else
|
| 73 |
+
# define PYBIND11_WARNING_DISABLE_INTEL(name)
|
| 74 |
+
#endif
|
| 75 |
+
|
| 76 |
+
#define PYBIND11_NAMESPACE_BEGIN(name) \
|
| 77 |
+
namespace name { \
|
| 78 |
+
PYBIND11_WARNING_PUSH
|
| 79 |
+
|
| 80 |
+
#define PYBIND11_NAMESPACE_END(name) \
|
| 81 |
+
PYBIND11_WARNING_POP \
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
// Robust support for some features and loading modules compiled against different pybind versions
|
| 85 |
+
// requires forcing hidden visibility on pybind code, so we enforce this by setting the attribute
|
| 86 |
+
// on the main `pybind11` namespace.
|
| 87 |
+
#if !defined(PYBIND11_NAMESPACE)
|
| 88 |
+
# ifdef __GNUG__
|
| 89 |
+
# define PYBIND11_NAMESPACE pybind11 __attribute__((visibility("hidden")))
|
| 90 |
+
# else
|
| 91 |
+
# define PYBIND11_NAMESPACE pybind11
|
| 92 |
+
# endif
|
| 93 |
+
#endif
|
| 94 |
+
|
| 95 |
+
#if !(defined(_MSC_VER) && __cplusplus == 199711L)
|
| 96 |
+
# if __cplusplus >= 201402L
|
| 97 |
+
# define PYBIND11_CPP14
|
| 98 |
+
# if __cplusplus >= 201703L
|
| 99 |
+
# define PYBIND11_CPP17
|
| 100 |
+
# if __cplusplus >= 202002L
|
| 101 |
+
# define PYBIND11_CPP20
|
| 102 |
+
// Please update tests/pybind11_tests.cpp `cpp_std()` when adding a macro here.
|
| 103 |
+
# endif
|
| 104 |
+
# endif
|
| 105 |
+
# endif
|
| 106 |
+
#elif defined(_MSC_VER) && __cplusplus == 199711L
|
| 107 |
+
// MSVC sets _MSVC_LANG rather than __cplusplus (supposedly until the standard is fully
|
| 108 |
+
// implemented). Unless you use the /Zc:__cplusplus flag on Visual Studio 2017 15.7 Preview 3
|
| 109 |
+
// or newer.
|
| 110 |
+
# if _MSVC_LANG >= 201402L
|
| 111 |
+
# define PYBIND11_CPP14
|
| 112 |
+
# if _MSVC_LANG > 201402L
|
| 113 |
+
# define PYBIND11_CPP17
|
| 114 |
+
# if _MSVC_LANG >= 202002L
|
| 115 |
+
# define PYBIND11_CPP20
|
| 116 |
+
# endif
|
| 117 |
+
# endif
|
| 118 |
+
# endif
|
| 119 |
+
#endif
|
| 120 |
+
|
| 121 |
+
#if defined(PYBIND11_CPP20)
|
| 122 |
+
# define PYBIND11_CONSTINIT constinit
|
| 123 |
+
# define PYBIND11_DTOR_CONSTEXPR constexpr
|
| 124 |
+
#else
|
| 125 |
+
# define PYBIND11_CONSTINIT
|
| 126 |
+
# define PYBIND11_DTOR_CONSTEXPR
|
| 127 |
+
#endif
|
| 128 |
+
|
| 129 |
+
// Compiler version assertions
|
| 130 |
+
#if defined(__INTEL_COMPILER)
|
| 131 |
+
# if __INTEL_COMPILER < 1800
|
| 132 |
+
# error pybind11 requires Intel C++ compiler v18 or newer
|
| 133 |
+
# elif __INTEL_COMPILER < 1900 && defined(PYBIND11_CPP14)
|
| 134 |
+
# error pybind11 supports only C++11 with Intel C++ compiler v18. Use v19 or newer for C++14.
|
| 135 |
+
# endif
|
| 136 |
+
/* The following pragma cannot be pop'ed:
|
| 137 |
+
https://community.intel.com/t5/Intel-C-Compiler/Inline-and-no-inline-warning/td-p/1216764 */
|
| 138 |
+
# pragma warning disable 2196 // warning #2196: routine is both "inline" and "noinline"
|
| 139 |
+
#elif defined(__clang__) && !defined(__apple_build_version__)
|
| 140 |
+
# if __clang_major__ < 3 || (__clang_major__ == 3 && __clang_minor__ < 3)
|
| 141 |
+
# error pybind11 requires clang 3.3 or newer
|
| 142 |
+
# endif
|
| 143 |
+
#elif defined(__clang__)
|
| 144 |
+
// Apple changes clang version macros to its Xcode version; the first Xcode release based on
|
| 145 |
+
// (upstream) clang 3.3 was Xcode 5:
|
| 146 |
+
# if __clang_major__ < 5
|
| 147 |
+
# error pybind11 requires Xcode/clang 5.0 or newer
|
| 148 |
+
# endif
|
| 149 |
+
#elif defined(__GNUG__)
|
| 150 |
+
# if __GNUC__ < 4 || (__GNUC__ == 4 && __GNUC_MINOR__ < 8)
|
| 151 |
+
# error pybind11 requires gcc 4.8 or newer
|
| 152 |
+
# endif
|
| 153 |
+
#elif defined(_MSC_VER)
|
| 154 |
+
# if _MSC_VER < 1910
|
| 155 |
+
# error pybind11 2.10+ requires MSVC 2017 or newer
|
| 156 |
+
# endif
|
| 157 |
+
#endif
|
| 158 |
+
|
| 159 |
+
#if !defined(PYBIND11_EXPORT)
|
| 160 |
+
# if defined(WIN32) || defined(_WIN32)
|
| 161 |
+
# define PYBIND11_EXPORT __declspec(dllexport)
|
| 162 |
+
# else
|
| 163 |
+
# define PYBIND11_EXPORT __attribute__((visibility("default")))
|
| 164 |
+
# endif
|
| 165 |
+
#endif
|
| 166 |
+
|
| 167 |
+
#if !defined(PYBIND11_EXPORT_EXCEPTION)
|
| 168 |
+
# if defined(__apple_build_version__)
|
| 169 |
+
# define PYBIND11_EXPORT_EXCEPTION PYBIND11_EXPORT
|
| 170 |
+
# else
|
| 171 |
+
# define PYBIND11_EXPORT_EXCEPTION
|
| 172 |
+
# endif
|
| 173 |
+
#endif
|
| 174 |
+
|
| 175 |
+
// For CUDA, GCC7, GCC8:
|
| 176 |
+
// PYBIND11_NOINLINE_FORCED is incompatible with `-Wattributes -Werror`.
|
| 177 |
+
// When defining PYBIND11_NOINLINE_FORCED, it is best to also use `-Wno-attributes`.
|
| 178 |
+
// However, the measured shared-library size saving when using noinline are only
|
| 179 |
+
// 1.7% for CUDA, -0.2% for GCC7, and 0.0% for GCC8 (using -DCMAKE_BUILD_TYPE=MinSizeRel,
|
| 180 |
+
// the default under pybind11/tests).
|
| 181 |
+
#if !defined(PYBIND11_NOINLINE_FORCED) \
|
| 182 |
+
&& (defined(__CUDACC__) || (defined(__GNUC__) && (__GNUC__ == 7 || __GNUC__ == 8)))
|
| 183 |
+
# define PYBIND11_NOINLINE_DISABLED
|
| 184 |
+
#endif
|
| 185 |
+
|
| 186 |
+
// The PYBIND11_NOINLINE macro is for function DEFINITIONS.
|
| 187 |
+
// In contrast, FORWARD DECLARATIONS should never use this macro:
|
| 188 |
+
// https://stackoverflow.com/questions/9317473/forward-declaration-of-inline-functions
|
| 189 |
+
#if defined(PYBIND11_NOINLINE_DISABLED) // Option for maximum portability and experimentation.
|
| 190 |
+
# define PYBIND11_NOINLINE inline
|
| 191 |
+
#elif defined(_MSC_VER)
|
| 192 |
+
# define PYBIND11_NOINLINE __declspec(noinline) inline
|
| 193 |
+
#else
|
| 194 |
+
# define PYBIND11_NOINLINE __attribute__((noinline)) inline
|
| 195 |
+
#endif
|
| 196 |
+
|
| 197 |
+
#if defined(__MINGW32__)
|
| 198 |
+
// For unknown reasons all PYBIND11_DEPRECATED member trigger a warning when declared
|
| 199 |
+
// whether it is used or not
|
| 200 |
+
# define PYBIND11_DEPRECATED(reason)
|
| 201 |
+
#elif defined(PYBIND11_CPP14)
|
| 202 |
+
# define PYBIND11_DEPRECATED(reason) [[deprecated(reason)]]
|
| 203 |
+
#else
|
| 204 |
+
# define PYBIND11_DEPRECATED(reason) __attribute__((deprecated(reason)))
|
| 205 |
+
#endif
|
| 206 |
+
|
| 207 |
+
#if defined(PYBIND11_CPP17)
|
| 208 |
+
# define PYBIND11_MAYBE_UNUSED [[maybe_unused]]
|
| 209 |
+
#elif defined(_MSC_VER) && !defined(__clang__)
|
| 210 |
+
# define PYBIND11_MAYBE_UNUSED
|
| 211 |
+
#else
|
| 212 |
+
# define PYBIND11_MAYBE_UNUSED __attribute__((__unused__))
|
| 213 |
+
#endif
|
| 214 |
+
|
| 215 |
+
/* Don't let Python.h #define (v)snprintf as macro because they are implemented
|
| 216 |
+
properly in Visual Studio since 2015. */
|
| 217 |
+
#if defined(_MSC_VER)
|
| 218 |
+
# define HAVE_SNPRINTF 1
|
| 219 |
+
#endif
|
| 220 |
+
|
| 221 |
+
/// Include Python header, disable linking to pythonX_d.lib on Windows in debug mode
|
| 222 |
+
#if defined(_MSC_VER)
|
| 223 |
+
PYBIND11_WARNING_PUSH
|
| 224 |
+
PYBIND11_WARNING_DISABLE_MSVC(4505)
|
| 225 |
+
// C4505: 'PySlice_GetIndicesEx': unreferenced local function has been removed (PyPy only)
|
| 226 |
+
# if defined(_DEBUG) && !defined(Py_DEBUG)
|
| 227 |
+
// Workaround for a VS 2022 issue.
|
| 228 |
+
// NOTE: This workaround knowingly violates the Python.h include order requirement:
|
| 229 |
+
// https://docs.python.org/3/c-api/intro.html#include-files
|
| 230 |
+
// See https://github.com/pybind/pybind11/pull/3497 for full context.
|
| 231 |
+
# include <yvals.h>
|
| 232 |
+
# if _MSVC_STL_VERSION >= 143
|
| 233 |
+
# include <crtdefs.h>
|
| 234 |
+
# endif
|
| 235 |
+
# define PYBIND11_DEBUG_MARKER
|
| 236 |
+
# undef _DEBUG
|
| 237 |
+
# endif
|
| 238 |
+
#endif
|
| 239 |
+
|
| 240 |
+
// https://en.cppreference.com/w/c/chrono/localtime
|
| 241 |
+
#if defined(__STDC_LIB_EXT1__) && !defined(__STDC_WANT_LIB_EXT1__)
|
| 242 |
+
# define __STDC_WANT_LIB_EXT1__
|
| 243 |
+
#endif
|
| 244 |
+
|
| 245 |
+
#ifdef __has_include
|
| 246 |
+
// std::optional (but including it in c++14 mode isn't allowed)
|
| 247 |
+
# if defined(PYBIND11_CPP17) && __has_include(<optional>)
|
| 248 |
+
# define PYBIND11_HAS_OPTIONAL 1
|
| 249 |
+
# endif
|
| 250 |
+
// std::experimental::optional (but not allowed in c++11 mode)
|
| 251 |
+
# if defined(PYBIND11_CPP14) && (__has_include(<experimental/optional>) && \
|
| 252 |
+
!__has_include(<optional>))
|
| 253 |
+
# define PYBIND11_HAS_EXP_OPTIONAL 1
|
| 254 |
+
# endif
|
| 255 |
+
// std::variant
|
| 256 |
+
# if defined(PYBIND11_CPP17) && __has_include(<variant>)
|
| 257 |
+
# define PYBIND11_HAS_VARIANT 1
|
| 258 |
+
# endif
|
| 259 |
+
#elif defined(_MSC_VER) && defined(PYBIND11_CPP17)
|
| 260 |
+
# define PYBIND11_HAS_OPTIONAL 1
|
| 261 |
+
# define PYBIND11_HAS_VARIANT 1
|
| 262 |
+
#endif
|
| 263 |
+
|
| 264 |
+
#if defined(PYBIND11_CPP17)
|
| 265 |
+
# if defined(__has_include)
|
| 266 |
+
# if __has_include(<string_view>)
|
| 267 |
+
# define PYBIND11_HAS_STRING_VIEW
|
| 268 |
+
# endif
|
| 269 |
+
# elif defined(_MSC_VER)
|
| 270 |
+
# define PYBIND11_HAS_STRING_VIEW
|
| 271 |
+
# endif
|
| 272 |
+
#endif
|
| 273 |
+
|
| 274 |
+
#include <Python.h>
|
| 275 |
+
#if PY_VERSION_HEX < 0x03070000
|
| 276 |
+
# error "PYTHON < 3.7 IS UNSUPPORTED. pybind11 v2.12 was the last to support Python 3.6."
|
| 277 |
+
#endif
|
| 278 |
+
#include <frameobject.h>
|
| 279 |
+
#include <pythread.h>
|
| 280 |
+
|
| 281 |
+
/* Python #defines overrides on all sorts of core functions, which
|
| 282 |
+
tends to weak havok in C++ codebases that expect these to work
|
| 283 |
+
like regular functions (potentially with several overloads) */
|
| 284 |
+
#if defined(isalnum)
|
| 285 |
+
# undef isalnum
|
| 286 |
+
# undef isalpha
|
| 287 |
+
# undef islower
|
| 288 |
+
# undef isspace
|
| 289 |
+
# undef isupper
|
| 290 |
+
# undef tolower
|
| 291 |
+
# undef toupper
|
| 292 |
+
#endif
|
| 293 |
+
|
| 294 |
+
#if defined(copysign)
|
| 295 |
+
# undef copysign
|
| 296 |
+
#endif
|
| 297 |
+
|
| 298 |
+
#if defined(PYBIND11_NUMPY_1_ONLY)
|
| 299 |
+
# define PYBIND11_INTERNAL_NUMPY_1_ONLY_DETECTED
|
| 300 |
+
#endif
|
| 301 |
+
|
| 302 |
+
#if defined(PYPY_VERSION) && !defined(PYBIND11_SIMPLE_GIL_MANAGEMENT)
|
| 303 |
+
# define PYBIND11_SIMPLE_GIL_MANAGEMENT
|
| 304 |
+
#endif
|
| 305 |
+
|
| 306 |
+
#if defined(_MSC_VER)
|
| 307 |
+
# if defined(PYBIND11_DEBUG_MARKER)
|
| 308 |
+
# define _DEBUG
|
| 309 |
+
# undef PYBIND11_DEBUG_MARKER
|
| 310 |
+
# endif
|
| 311 |
+
PYBIND11_WARNING_POP
|
| 312 |
+
#endif
|
| 313 |
+
|
| 314 |
+
#include <cstddef>
|
| 315 |
+
#include <cstring>
|
| 316 |
+
#include <exception>
|
| 317 |
+
#include <forward_list>
|
| 318 |
+
#include <memory>
|
| 319 |
+
#include <stdexcept>
|
| 320 |
+
#include <string>
|
| 321 |
+
#include <type_traits>
|
| 322 |
+
#include <typeindex>
|
| 323 |
+
#include <unordered_map>
|
| 324 |
+
#include <unordered_set>
|
| 325 |
+
#include <vector>
|
| 326 |
+
#if defined(__has_include)
|
| 327 |
+
# if __has_include(<version>)
|
| 328 |
+
# include <version>
|
| 329 |
+
# endif
|
| 330 |
+
#endif
|
| 331 |
+
|
| 332 |
+
// Must be after including <version> or one of the other headers specified by the standard
|
| 333 |
+
#if defined(__cpp_lib_char8_t) && __cpp_lib_char8_t >= 201811L
|
| 334 |
+
# define PYBIND11_HAS_U8STRING
|
| 335 |
+
#endif
|
| 336 |
+
|
| 337 |
+
// See description of PR #4246:
|
| 338 |
+
#if !defined(PYBIND11_NO_ASSERT_GIL_HELD_INCREF_DECREF) && !defined(NDEBUG) \
|
| 339 |
+
&& !defined(PYPY_VERSION) && !defined(PYBIND11_ASSERT_GIL_HELD_INCREF_DECREF)
|
| 340 |
+
# define PYBIND11_ASSERT_GIL_HELD_INCREF_DECREF
|
| 341 |
+
#endif
|
| 342 |
+
|
| 343 |
+
// #define PYBIND11_STR_LEGACY_PERMISSIVE
|
| 344 |
+
// If DEFINED, pybind11::str can hold PyUnicodeObject or PyBytesObject
|
| 345 |
+
// (probably surprising and never documented, but this was the
|
| 346 |
+
// legacy behavior until and including v2.6.x). As a side-effect,
|
| 347 |
+
// pybind11::isinstance<str>() is true for both pybind11::str and
|
| 348 |
+
// pybind11::bytes.
|
| 349 |
+
// If UNDEFINED, pybind11::str can only hold PyUnicodeObject, and
|
| 350 |
+
// pybind11::isinstance<str>() is true only for pybind11::str.
|
| 351 |
+
// However, for Python 2 only (!), the pybind11::str caster
|
| 352 |
+
// implicitly decoded bytes to PyUnicodeObject. This was to ease
|
| 353 |
+
// the transition from the legacy behavior to the non-permissive
|
| 354 |
+
// behavior.
|
| 355 |
+
|
| 356 |
+
/// Compatibility macros for Python 2 / Python 3 versions TODO: remove
|
| 357 |
+
#define PYBIND11_INSTANCE_METHOD_NEW(ptr, class_) PyInstanceMethod_New(ptr)
|
| 358 |
+
#define PYBIND11_INSTANCE_METHOD_CHECK PyInstanceMethod_Check
|
| 359 |
+
#define PYBIND11_INSTANCE_METHOD_GET_FUNCTION PyInstanceMethod_GET_FUNCTION
|
| 360 |
+
#define PYBIND11_BYTES_CHECK PyBytes_Check
|
| 361 |
+
#define PYBIND11_BYTES_FROM_STRING PyBytes_FromString
|
| 362 |
+
#define PYBIND11_BYTES_FROM_STRING_AND_SIZE PyBytes_FromStringAndSize
|
| 363 |
+
#define PYBIND11_BYTES_AS_STRING_AND_SIZE PyBytes_AsStringAndSize
|
| 364 |
+
#define PYBIND11_BYTES_AS_STRING PyBytes_AsString
|
| 365 |
+
#define PYBIND11_BYTES_SIZE PyBytes_Size
|
| 366 |
+
#define PYBIND11_LONG_CHECK(o) PyLong_Check(o)
|
| 367 |
+
#define PYBIND11_LONG_AS_LONGLONG(o) PyLong_AsLongLong(o)
|
| 368 |
+
#define PYBIND11_LONG_FROM_SIGNED(o) PyLong_FromSsize_t((ssize_t) (o))
|
| 369 |
+
#define PYBIND11_LONG_FROM_UNSIGNED(o) PyLong_FromSize_t((size_t) (o))
|
| 370 |
+
#define PYBIND11_BYTES_NAME "bytes"
|
| 371 |
+
#define PYBIND11_STRING_NAME "str"
|
| 372 |
+
#define PYBIND11_SLICE_OBJECT PyObject
|
| 373 |
+
#define PYBIND11_FROM_STRING PyUnicode_FromString
|
| 374 |
+
#define PYBIND11_STR_TYPE ::pybind11::str
|
| 375 |
+
#define PYBIND11_BOOL_ATTR "__bool__"
|
| 376 |
+
#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_bool)
|
| 377 |
+
#define PYBIND11_BUILTINS_MODULE "builtins"
|
| 378 |
+
// Providing a separate declaration to make Clang's -Wmissing-prototypes happy.
|
| 379 |
+
// See comment for PYBIND11_MODULE below for why this is marked "maybe unused".
|
| 380 |
+
#define PYBIND11_PLUGIN_IMPL(name) \
|
| 381 |
+
extern "C" PYBIND11_MAYBE_UNUSED PYBIND11_EXPORT PyObject *PyInit_##name(); \
|
| 382 |
+
extern "C" PYBIND11_EXPORT PyObject *PyInit_##name()
|
| 383 |
+
|
| 384 |
+
#define PYBIND11_TRY_NEXT_OVERLOAD ((PyObject *) 1) // special failure return code
|
| 385 |
+
#define PYBIND11_STRINGIFY(x) #x
|
| 386 |
+
#define PYBIND11_TOSTRING(x) PYBIND11_STRINGIFY(x)
|
| 387 |
+
#define PYBIND11_CONCAT(first, second) first##second
|
| 388 |
+
#define PYBIND11_ENSURE_INTERNALS_READY pybind11::detail::get_internals();
|
| 389 |
+
|
| 390 |
+
#define PYBIND11_CHECK_PYTHON_VERSION \
|
| 391 |
+
{ \
|
| 392 |
+
const char *compiled_ver \
|
| 393 |
+
= PYBIND11_TOSTRING(PY_MAJOR_VERSION) "." PYBIND11_TOSTRING(PY_MINOR_VERSION); \
|
| 394 |
+
const char *runtime_ver = Py_GetVersion(); \
|
| 395 |
+
size_t len = std::strlen(compiled_ver); \
|
| 396 |
+
if (std::strncmp(runtime_ver, compiled_ver, len) != 0 \
|
| 397 |
+
|| (runtime_ver[len] >= '0' && runtime_ver[len] <= '9')) { \
|
| 398 |
+
PyErr_Format(PyExc_ImportError, \
|
| 399 |
+
"Python version mismatch: module was compiled for Python %s, " \
|
| 400 |
+
"but the interpreter version is incompatible: %s.", \
|
| 401 |
+
compiled_ver, \
|
| 402 |
+
runtime_ver); \
|
| 403 |
+
return nullptr; \
|
| 404 |
+
} \
|
| 405 |
+
}
|
| 406 |
+
|
| 407 |
+
#define PYBIND11_CATCH_INIT_EXCEPTIONS \
|
| 408 |
+
catch (pybind11::error_already_set & e) { \
|
| 409 |
+
pybind11::raise_from(e, PyExc_ImportError, "initialization failed"); \
|
| 410 |
+
return nullptr; \
|
| 411 |
+
} \
|
| 412 |
+
catch (const std::exception &e) { \
|
| 413 |
+
::pybind11::set_error(PyExc_ImportError, e.what()); \
|
| 414 |
+
return nullptr; \
|
| 415 |
+
}
|
| 416 |
+
|
| 417 |
+
/** \rst
|
| 418 |
+
***Deprecated in favor of PYBIND11_MODULE***
|
| 419 |
+
|
| 420 |
+
This macro creates the entry point that will be invoked when the Python interpreter
|
| 421 |
+
imports a plugin library. Please create a `module_` in the function body and return
|
| 422 |
+
the pointer to its underlying Python object at the end.
|
| 423 |
+
|
| 424 |
+
.. code-block:: cpp
|
| 425 |
+
|
| 426 |
+
PYBIND11_PLUGIN(example) {
|
| 427 |
+
pybind11::module_ m("example", "pybind11 example plugin");
|
| 428 |
+
/// Set up bindings here
|
| 429 |
+
return m.ptr();
|
| 430 |
+
}
|
| 431 |
+
\endrst */
|
| 432 |
+
#define PYBIND11_PLUGIN(name) \
|
| 433 |
+
PYBIND11_DEPRECATED("PYBIND11_PLUGIN is deprecated, use PYBIND11_MODULE") \
|
| 434 |
+
static PyObject *pybind11_init(); \
|
| 435 |
+
PYBIND11_PLUGIN_IMPL(name) { \
|
| 436 |
+
PYBIND11_CHECK_PYTHON_VERSION \
|
| 437 |
+
PYBIND11_ENSURE_INTERNALS_READY \
|
| 438 |
+
try { \
|
| 439 |
+
return pybind11_init(); \
|
| 440 |
+
} \
|
| 441 |
+
PYBIND11_CATCH_INIT_EXCEPTIONS \
|
| 442 |
+
} \
|
| 443 |
+
PyObject *pybind11_init()
|
| 444 |
+
|
| 445 |
+
/** \rst
|
| 446 |
+
This macro creates the entry point that will be invoked when the Python interpreter
|
| 447 |
+
imports an extension module. The module name is given as the first argument and it
|
| 448 |
+
should not be in quotes. The second macro argument defines a variable of type
|
| 449 |
+
`py::module_` which can be used to initialize the module.
|
| 450 |
+
|
| 451 |
+
The entry point is marked as "maybe unused" to aid dead-code detection analysis:
|
| 452 |
+
since the entry point is typically only looked up at runtime and not referenced
|
| 453 |
+
during translation, it would otherwise appear as unused ("dead") code.
|
| 454 |
+
|
| 455 |
+
.. code-block:: cpp
|
| 456 |
+
|
| 457 |
+
PYBIND11_MODULE(example, m) {
|
| 458 |
+
m.doc() = "pybind11 example module";
|
| 459 |
+
|
| 460 |
+
// Add bindings here
|
| 461 |
+
m.def("foo", []() {
|
| 462 |
+
return "Hello, World!";
|
| 463 |
+
});
|
| 464 |
+
}
|
| 465 |
+
|
| 466 |
+
The third macro argument is optional (available since 2.13.0), and can be used to
|
| 467 |
+
mark the extension module as safe to run without the GIL under a free-threaded CPython
|
| 468 |
+
interpreter. Passing this argument has no effect on other interpreters.
|
| 469 |
+
|
| 470 |
+
.. code-block:: cpp
|
| 471 |
+
|
| 472 |
+
PYBIND11_MODULE(example, m, py::mod_gil_not_used()) {
|
| 473 |
+
m.doc() = "pybind11 example module safe to run without the GIL";
|
| 474 |
+
|
| 475 |
+
// Add bindings here
|
| 476 |
+
m.def("foo", []() {
|
| 477 |
+
return "Hello, Free-threaded World!";
|
| 478 |
+
});
|
| 479 |
+
}
|
| 480 |
+
|
| 481 |
+
\endrst */
|
| 482 |
+
PYBIND11_WARNING_PUSH
|
| 483 |
+
PYBIND11_WARNING_DISABLE_CLANG("-Wgnu-zero-variadic-macro-arguments")
|
| 484 |
+
#define PYBIND11_MODULE(name, variable, ...) \
|
| 485 |
+
static ::pybind11::module_::module_def PYBIND11_CONCAT(pybind11_module_def_, name) \
|
| 486 |
+
PYBIND11_MAYBE_UNUSED; \
|
| 487 |
+
PYBIND11_MAYBE_UNUSED \
|
| 488 |
+
static void PYBIND11_CONCAT(pybind11_init_, name)(::pybind11::module_ &); \
|
| 489 |
+
PYBIND11_PLUGIN_IMPL(name) { \
|
| 490 |
+
PYBIND11_CHECK_PYTHON_VERSION \
|
| 491 |
+
PYBIND11_ENSURE_INTERNALS_READY \
|
| 492 |
+
auto m = ::pybind11::module_::create_extension_module( \
|
| 493 |
+
PYBIND11_TOSTRING(name), \
|
| 494 |
+
nullptr, \
|
| 495 |
+
&PYBIND11_CONCAT(pybind11_module_def_, name), \
|
| 496 |
+
##__VA_ARGS__); \
|
| 497 |
+
try { \
|
| 498 |
+
PYBIND11_CONCAT(pybind11_init_, name)(m); \
|
| 499 |
+
return m.ptr(); \
|
| 500 |
+
} \
|
| 501 |
+
PYBIND11_CATCH_INIT_EXCEPTIONS \
|
| 502 |
+
} \
|
| 503 |
+
void PYBIND11_CONCAT(pybind11_init_, name)(::pybind11::module_ & (variable))
|
| 504 |
+
PYBIND11_WARNING_POP
|
| 505 |
+
|
| 506 |
+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
| 507 |
+
|
| 508 |
+
using ssize_t = Py_ssize_t;
|
| 509 |
+
using size_t = std::size_t;
|
| 510 |
+
|
| 511 |
+
template <typename IntType>
|
| 512 |
+
inline ssize_t ssize_t_cast(const IntType &val) {
|
| 513 |
+
static_assert(sizeof(IntType) <= sizeof(ssize_t), "Implicit narrowing is not permitted.");
|
| 514 |
+
return static_cast<ssize_t>(val);
|
| 515 |
+
}
|
| 516 |
+
|
| 517 |
+
/// Approach used to cast a previously unknown C++ instance into a Python object
|
| 518 |
+
enum class return_value_policy : uint8_t {
|
| 519 |
+
/** This is the default return value policy, which falls back to the policy
|
| 520 |
+
return_value_policy::take_ownership when the return value is a pointer.
|
| 521 |
+
Otherwise, it uses return_value::move or return_value::copy for rvalue
|
| 522 |
+
and lvalue references, respectively. See below for a description of what
|
| 523 |
+
all of these different policies do. */
|
| 524 |
+
automatic = 0,
|
| 525 |
+
|
| 526 |
+
/** As above, but use policy return_value_policy::reference when the return
|
| 527 |
+
value is a pointer. This is the default conversion policy for function
|
| 528 |
+
arguments when calling Python functions manually from C++ code (i.e. via
|
| 529 |
+
handle::operator()). You probably won't need to use this. */
|
| 530 |
+
automatic_reference,
|
| 531 |
+
|
| 532 |
+
/** Reference an existing object (i.e. do not create a new copy) and take
|
| 533 |
+
ownership. Python will call the destructor and delete operator when the
|
| 534 |
+
object's reference count reaches zero. Undefined behavior ensues when
|
| 535 |
+
the C++ side does the same.. */
|
| 536 |
+
take_ownership,
|
| 537 |
+
|
| 538 |
+
/** Create a new copy of the returned object, which will be owned by
|
| 539 |
+
Python. This policy is comparably safe because the lifetimes of the two
|
| 540 |
+
instances are decoupled. */
|
| 541 |
+
copy,
|
| 542 |
+
|
| 543 |
+
/** Use std::move to move the return value contents into a new instance
|
| 544 |
+
that will be owned by Python. This policy is comparably safe because the
|
| 545 |
+
lifetimes of the two instances (move source and destination) are
|
| 546 |
+
decoupled. */
|
| 547 |
+
move,
|
| 548 |
+
|
| 549 |
+
/** Reference an existing object, but do not take ownership. The C++ side
|
| 550 |
+
is responsible for managing the object's lifetime and deallocating it
|
| 551 |
+
when it is no longer used. Warning: undefined behavior will ensue when
|
| 552 |
+
the C++ side deletes an object that is still referenced and used by
|
| 553 |
+
Python. */
|
| 554 |
+
reference,
|
| 555 |
+
|
| 556 |
+
/** This policy only applies to methods and properties. It references the
|
| 557 |
+
object without taking ownership similar to the above
|
| 558 |
+
return_value_policy::reference policy. In contrast to that policy, the
|
| 559 |
+
function or property's implicit this argument (called the parent) is
|
| 560 |
+
considered to be the owner of the return value (the child).
|
| 561 |
+
pybind11 then couples the lifetime of the parent to the child via a
|
| 562 |
+
reference relationship that ensures that the parent cannot be garbage
|
| 563 |
+
collected while Python is still using the child. More advanced
|
| 564 |
+
variations of this scheme are also possible using combinations of
|
| 565 |
+
return_value_policy::reference and the keep_alive call policy */
|
| 566 |
+
reference_internal
|
| 567 |
+
};
|
| 568 |
+
|
| 569 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 570 |
+
|
| 571 |
+
inline static constexpr int log2(size_t n, int k = 0) {
|
| 572 |
+
return (n <= 1) ? k : log2(n >> 1, k + 1);
|
| 573 |
+
}
|
| 574 |
+
|
| 575 |
+
// Returns the size as a multiple of sizeof(void *), rounded up.
|
| 576 |
+
inline static constexpr size_t size_in_ptrs(size_t s) {
|
| 577 |
+
return 1 + ((s - 1) >> log2(sizeof(void *)));
|
| 578 |
+
}
|
| 579 |
+
|
| 580 |
+
/**
|
| 581 |
+
* The space to allocate for simple layout instance holders (see below) in multiple of the size of
|
| 582 |
+
* a pointer (e.g. 2 means 16 bytes on 64-bit architectures). The default is the minimum required
|
| 583 |
+
* to holder either a std::unique_ptr or std::shared_ptr (which is almost always
|
| 584 |
+
* sizeof(std::shared_ptr<T>)).
|
| 585 |
+
*/
|
| 586 |
+
constexpr size_t instance_simple_holder_in_ptrs() {
|
| 587 |
+
static_assert(sizeof(std::shared_ptr<int>) >= sizeof(std::unique_ptr<int>),
|
| 588 |
+
"pybind assumes std::shared_ptrs are at least as big as std::unique_ptrs");
|
| 589 |
+
return size_in_ptrs(sizeof(std::shared_ptr<int>));
|
| 590 |
+
}
|
| 591 |
+
|
| 592 |
+
// Forward declarations
|
| 593 |
+
struct type_info;
|
| 594 |
+
struct value_and_holder;
|
| 595 |
+
|
| 596 |
+
struct nonsimple_values_and_holders {
|
| 597 |
+
void **values_and_holders;
|
| 598 |
+
uint8_t *status;
|
| 599 |
+
};
|
| 600 |
+
|
| 601 |
+
/// The 'instance' type which needs to be standard layout (need to be able to use 'offsetof')
|
| 602 |
+
struct instance {
|
| 603 |
+
PyObject_HEAD
|
| 604 |
+
/// Storage for pointers and holder; see simple_layout, below, for a description
|
| 605 |
+
union {
|
| 606 |
+
void *simple_value_holder[1 + instance_simple_holder_in_ptrs()];
|
| 607 |
+
nonsimple_values_and_holders nonsimple;
|
| 608 |
+
};
|
| 609 |
+
/// Weak references
|
| 610 |
+
PyObject *weakrefs;
|
| 611 |
+
/// If true, the pointer is owned which means we're free to manage it with a holder.
|
| 612 |
+
bool owned : 1;
|
| 613 |
+
/**
|
| 614 |
+
* An instance has two possible value/holder layouts.
|
| 615 |
+
*
|
| 616 |
+
* Simple layout (when this flag is true), means the `simple_value_holder` is set with a
|
| 617 |
+
* pointer and the holder object governing that pointer, i.e. [val1*][holder]. This layout is
|
| 618 |
+
* applied whenever there is no python-side multiple inheritance of bound C++ types *and* the
|
| 619 |
+
* type's holder will fit in the default space (which is large enough to hold either a
|
| 620 |
+
* std::unique_ptr or std::shared_ptr).
|
| 621 |
+
*
|
| 622 |
+
* Non-simple layout applies when using custom holders that require more space than
|
| 623 |
+
* `shared_ptr` (which is typically the size of two pointers), or when multiple inheritance is
|
| 624 |
+
* used on the python side. Non-simple layout allocates the required amount of memory to have
|
| 625 |
+
* multiple bound C++ classes as parents. Under this layout, `nonsimple.values_and_holders` is
|
| 626 |
+
* set to a pointer to allocated space of the required space to hold a sequence of value
|
| 627 |
+
* pointers and holders followed `status`, a set of bit flags (1 byte each), i.e.
|
| 628 |
+
* [val1*][holder1][val2*][holder2]...[bb...] where each [block] is rounded up to a multiple
|
| 629 |
+
* of `sizeof(void *)`. `nonsimple.status` is, for convenience, a pointer to the beginning of
|
| 630 |
+
* the [bb...] block (but not independently allocated).
|
| 631 |
+
*
|
| 632 |
+
* Status bits indicate whether the associated holder is constructed (&
|
| 633 |
+
* status_holder_constructed) and whether the value pointer is registered (&
|
| 634 |
+
* status_instance_registered) in `registered_instances`.
|
| 635 |
+
*/
|
| 636 |
+
bool simple_layout : 1;
|
| 637 |
+
/// For simple layout, tracks whether the holder has been constructed
|
| 638 |
+
bool simple_holder_constructed : 1;
|
| 639 |
+
/// For simple layout, tracks whether the instance is registered in `registered_instances`
|
| 640 |
+
bool simple_instance_registered : 1;
|
| 641 |
+
/// If true, get_internals().patients has an entry for this object
|
| 642 |
+
bool has_patients : 1;
|
| 643 |
+
|
| 644 |
+
/// Initializes all of the above type/values/holders data (but not the instance values
|
| 645 |
+
/// themselves)
|
| 646 |
+
void allocate_layout();
|
| 647 |
+
|
| 648 |
+
/// Destroys/deallocates all of the above
|
| 649 |
+
void deallocate_layout();
|
| 650 |
+
|
| 651 |
+
/// Returns the value_and_holder wrapper for the given type (or the first, if `find_type`
|
| 652 |
+
/// omitted). Returns a default-constructed (with `.inst = nullptr`) object on failure if
|
| 653 |
+
/// `throw_if_missing` is false.
|
| 654 |
+
value_and_holder get_value_and_holder(const type_info *find_type = nullptr,
|
| 655 |
+
bool throw_if_missing = true);
|
| 656 |
+
|
| 657 |
+
/// Bit values for the non-simple status flags
|
| 658 |
+
static constexpr uint8_t status_holder_constructed = 1;
|
| 659 |
+
static constexpr uint8_t status_instance_registered = 2;
|
| 660 |
+
};
|
| 661 |
+
|
| 662 |
+
static_assert(std::is_standard_layout<instance>::value,
|
| 663 |
+
"Internal error: `pybind11::detail::instance` is not standard layout!");
|
| 664 |
+
|
| 665 |
+
/// from __cpp_future__ import (convenient aliases from C++14/17)
|
| 666 |
+
#if defined(PYBIND11_CPP14)
|
| 667 |
+
using std::conditional_t;
|
| 668 |
+
using std::enable_if_t;
|
| 669 |
+
using std::remove_cv_t;
|
| 670 |
+
using std::remove_reference_t;
|
| 671 |
+
#else
|
| 672 |
+
template <bool B, typename T = void>
|
| 673 |
+
using enable_if_t = typename std::enable_if<B, T>::type;
|
| 674 |
+
template <bool B, typename T, typename F>
|
| 675 |
+
using conditional_t = typename std::conditional<B, T, F>::type;
|
| 676 |
+
template <typename T>
|
| 677 |
+
using remove_cv_t = typename std::remove_cv<T>::type;
|
| 678 |
+
template <typename T>
|
| 679 |
+
using remove_reference_t = typename std::remove_reference<T>::type;
|
| 680 |
+
#endif
|
| 681 |
+
|
| 682 |
+
#if defined(PYBIND11_CPP20)
|
| 683 |
+
using std::remove_cvref;
|
| 684 |
+
using std::remove_cvref_t;
|
| 685 |
+
#else
|
| 686 |
+
template <class T>
|
| 687 |
+
struct remove_cvref {
|
| 688 |
+
using type = remove_cv_t<remove_reference_t<T>>;
|
| 689 |
+
};
|
| 690 |
+
template <class T>
|
| 691 |
+
using remove_cvref_t = typename remove_cvref<T>::type;
|
| 692 |
+
#endif
|
| 693 |
+
|
| 694 |
+
/// Example usage: is_same_ignoring_cvref<T, PyObject *>::value
|
| 695 |
+
template <typename T, typename U>
|
| 696 |
+
using is_same_ignoring_cvref = std::is_same<detail::remove_cvref_t<T>, U>;
|
| 697 |
+
|
| 698 |
+
/// Index sequences
|
| 699 |
+
#if defined(PYBIND11_CPP14)
|
| 700 |
+
using std::index_sequence;
|
| 701 |
+
using std::make_index_sequence;
|
| 702 |
+
#else
|
| 703 |
+
template <size_t...>
|
| 704 |
+
struct index_sequence {};
|
| 705 |
+
template <size_t N, size_t... S>
|
| 706 |
+
struct make_index_sequence_impl : make_index_sequence_impl<N - 1, N - 1, S...> {};
|
| 707 |
+
template <size_t... S>
|
| 708 |
+
struct make_index_sequence_impl<0, S...> {
|
| 709 |
+
using type = index_sequence<S...>;
|
| 710 |
+
};
|
| 711 |
+
template <size_t N>
|
| 712 |
+
using make_index_sequence = typename make_index_sequence_impl<N>::type;
|
| 713 |
+
#endif
|
| 714 |
+
|
| 715 |
+
/// Make an index sequence of the indices of true arguments
|
| 716 |
+
template <typename ISeq, size_t, bool...>
|
| 717 |
+
struct select_indices_impl {
|
| 718 |
+
using type = ISeq;
|
| 719 |
+
};
|
| 720 |
+
template <size_t... IPrev, size_t I, bool B, bool... Bs>
|
| 721 |
+
struct select_indices_impl<index_sequence<IPrev...>, I, B, Bs...>
|
| 722 |
+
: select_indices_impl<conditional_t<B, index_sequence<IPrev..., I>, index_sequence<IPrev...>>,
|
| 723 |
+
I + 1,
|
| 724 |
+
Bs...> {};
|
| 725 |
+
template <bool... Bs>
|
| 726 |
+
using select_indices = typename select_indices_impl<index_sequence<>, 0, Bs...>::type;
|
| 727 |
+
|
| 728 |
+
/// Backports of std::bool_constant and std::negation to accommodate older compilers
|
| 729 |
+
template <bool B>
|
| 730 |
+
using bool_constant = std::integral_constant<bool, B>;
|
| 731 |
+
template <typename T>
|
| 732 |
+
struct negation : bool_constant<!T::value> {};
|
| 733 |
+
|
| 734 |
+
// PGI/Intel cannot detect operator delete with the "compatible" void_t impl, so
|
| 735 |
+
// using the new one (C++14 defect, so generally works on newer compilers, even
|
| 736 |
+
// if not in C++17 mode)
|
| 737 |
+
#if defined(__PGIC__) || defined(__INTEL_COMPILER)
|
| 738 |
+
template <typename...>
|
| 739 |
+
using void_t = void;
|
| 740 |
+
#else
|
| 741 |
+
template <typename...>
|
| 742 |
+
struct void_t_impl {
|
| 743 |
+
using type = void;
|
| 744 |
+
};
|
| 745 |
+
template <typename... Ts>
|
| 746 |
+
using void_t = typename void_t_impl<Ts...>::type;
|
| 747 |
+
#endif
|
| 748 |
+
|
| 749 |
+
/// Compile-time all/any/none of that check the boolean value of all template types
|
| 750 |
+
#if defined(__cpp_fold_expressions) && !(defined(_MSC_VER) && (_MSC_VER < 1916))
|
| 751 |
+
template <class... Ts>
|
| 752 |
+
using all_of = bool_constant<(Ts::value && ...)>;
|
| 753 |
+
template <class... Ts>
|
| 754 |
+
using any_of = bool_constant<(Ts::value || ...)>;
|
| 755 |
+
#elif !defined(_MSC_VER)
|
| 756 |
+
template <bool...>
|
| 757 |
+
struct bools {};
|
| 758 |
+
template <class... Ts>
|
| 759 |
+
using all_of = std::is_same<bools<Ts::value..., true>, bools<true, Ts::value...>>;
|
| 760 |
+
template <class... Ts>
|
| 761 |
+
using any_of = negation<all_of<negation<Ts>...>>;
|
| 762 |
+
#else
|
| 763 |
+
// MSVC has trouble with the above, but supports std::conjunction, which we can use instead (albeit
|
| 764 |
+
// at a slight loss of compilation efficiency).
|
| 765 |
+
template <class... Ts>
|
| 766 |
+
using all_of = std::conjunction<Ts...>;
|
| 767 |
+
template <class... Ts>
|
| 768 |
+
using any_of = std::disjunction<Ts...>;
|
| 769 |
+
#endif
|
| 770 |
+
template <class... Ts>
|
| 771 |
+
using none_of = negation<any_of<Ts...>>;
|
| 772 |
+
|
| 773 |
+
template <class T, template <class> class... Predicates>
|
| 774 |
+
using satisfies_all_of = all_of<Predicates<T>...>;
|
| 775 |
+
template <class T, template <class> class... Predicates>
|
| 776 |
+
using satisfies_any_of = any_of<Predicates<T>...>;
|
| 777 |
+
template <class T, template <class> class... Predicates>
|
| 778 |
+
using satisfies_none_of = none_of<Predicates<T>...>;
|
| 779 |
+
|
| 780 |
+
/// Strip the class from a method type
|
| 781 |
+
template <typename T>
|
| 782 |
+
struct remove_class {};
|
| 783 |
+
template <typename C, typename R, typename... A>
|
| 784 |
+
struct remove_class<R (C::*)(A...)> {
|
| 785 |
+
using type = R(A...);
|
| 786 |
+
};
|
| 787 |
+
template <typename C, typename R, typename... A>
|
| 788 |
+
struct remove_class<R (C::*)(A...) const> {
|
| 789 |
+
using type = R(A...);
|
| 790 |
+
};
|
| 791 |
+
#ifdef __cpp_noexcept_function_type
|
| 792 |
+
template <typename C, typename R, typename... A>
|
| 793 |
+
struct remove_class<R (C::*)(A...) noexcept> {
|
| 794 |
+
using type = R(A...);
|
| 795 |
+
};
|
| 796 |
+
template <typename C, typename R, typename... A>
|
| 797 |
+
struct remove_class<R (C::*)(A...) const noexcept> {
|
| 798 |
+
using type = R(A...);
|
| 799 |
+
};
|
| 800 |
+
#endif
|
| 801 |
+
/// Helper template to strip away type modifiers
|
| 802 |
+
template <typename T>
|
| 803 |
+
struct intrinsic_type {
|
| 804 |
+
using type = T;
|
| 805 |
+
};
|
| 806 |
+
template <typename T>
|
| 807 |
+
struct intrinsic_type<const T> {
|
| 808 |
+
using type = typename intrinsic_type<T>::type;
|
| 809 |
+
};
|
| 810 |
+
template <typename T>
|
| 811 |
+
struct intrinsic_type<T *> {
|
| 812 |
+
using type = typename intrinsic_type<T>::type;
|
| 813 |
+
};
|
| 814 |
+
template <typename T>
|
| 815 |
+
struct intrinsic_type<T &> {
|
| 816 |
+
using type = typename intrinsic_type<T>::type;
|
| 817 |
+
};
|
| 818 |
+
template <typename T>
|
| 819 |
+
struct intrinsic_type<T &&> {
|
| 820 |
+
using type = typename intrinsic_type<T>::type;
|
| 821 |
+
};
|
| 822 |
+
template <typename T, size_t N>
|
| 823 |
+
struct intrinsic_type<const T[N]> {
|
| 824 |
+
using type = typename intrinsic_type<T>::type;
|
| 825 |
+
};
|
| 826 |
+
template <typename T, size_t N>
|
| 827 |
+
struct intrinsic_type<T[N]> {
|
| 828 |
+
using type = typename intrinsic_type<T>::type;
|
| 829 |
+
};
|
| 830 |
+
template <typename T>
|
| 831 |
+
using intrinsic_t = typename intrinsic_type<T>::type;
|
| 832 |
+
|
| 833 |
+
/// Helper type to replace 'void' in some expressions
|
| 834 |
+
struct void_type {};
|
| 835 |
+
|
| 836 |
+
/// Helper template which holds a list of types
|
| 837 |
+
template <typename...>
|
| 838 |
+
struct type_list {};
|
| 839 |
+
|
| 840 |
+
/// Compile-time integer sum
|
| 841 |
+
#ifdef __cpp_fold_expressions
|
| 842 |
+
template <typename... Ts>
|
| 843 |
+
constexpr size_t constexpr_sum(Ts... ns) {
|
| 844 |
+
return (0 + ... + size_t{ns});
|
| 845 |
+
}
|
| 846 |
+
#else
|
| 847 |
+
constexpr size_t constexpr_sum() { return 0; }
|
| 848 |
+
template <typename T, typename... Ts>
|
| 849 |
+
constexpr size_t constexpr_sum(T n, Ts... ns) {
|
| 850 |
+
return size_t{n} + constexpr_sum(ns...);
|
| 851 |
+
}
|
| 852 |
+
#endif
|
| 853 |
+
|
| 854 |
+
PYBIND11_NAMESPACE_BEGIN(constexpr_impl)
|
| 855 |
+
/// Implementation details for constexpr functions
|
| 856 |
+
constexpr int first(int i) { return i; }
|
| 857 |
+
template <typename T, typename... Ts>
|
| 858 |
+
constexpr int first(int i, T v, Ts... vs) {
|
| 859 |
+
return v ? i : first(i + 1, vs...);
|
| 860 |
+
}
|
| 861 |
+
|
| 862 |
+
constexpr int last(int /*i*/, int result) { return result; }
|
| 863 |
+
template <typename T, typename... Ts>
|
| 864 |
+
constexpr int last(int i, int result, T v, Ts... vs) {
|
| 865 |
+
return last(i + 1, v ? i : result, vs...);
|
| 866 |
+
}
|
| 867 |
+
PYBIND11_NAMESPACE_END(constexpr_impl)
|
| 868 |
+
|
| 869 |
+
/// Return the index of the first type in Ts which satisfies Predicate<T>.
|
| 870 |
+
/// Returns sizeof...(Ts) if none match.
|
| 871 |
+
template <template <typename> class Predicate, typename... Ts>
|
| 872 |
+
constexpr int constexpr_first() {
|
| 873 |
+
return constexpr_impl::first(0, Predicate<Ts>::value...);
|
| 874 |
+
}
|
| 875 |
+
|
| 876 |
+
/// Return the index of the last type in Ts which satisfies Predicate<T>, or -1 if none match.
|
| 877 |
+
template <template <typename> class Predicate, typename... Ts>
|
| 878 |
+
constexpr int constexpr_last() {
|
| 879 |
+
return constexpr_impl::last(0, -1, Predicate<Ts>::value...);
|
| 880 |
+
}
|
| 881 |
+
|
| 882 |
+
/// Return the Nth element from the parameter pack
|
| 883 |
+
template <size_t N, typename T, typename... Ts>
|
| 884 |
+
struct pack_element {
|
| 885 |
+
using type = typename pack_element<N - 1, Ts...>::type;
|
| 886 |
+
};
|
| 887 |
+
template <typename T, typename... Ts>
|
| 888 |
+
struct pack_element<0, T, Ts...> {
|
| 889 |
+
using type = T;
|
| 890 |
+
};
|
| 891 |
+
|
| 892 |
+
/// Return the one and only type which matches the predicate, or Default if none match.
|
| 893 |
+
/// If more than one type matches the predicate, fail at compile-time.
|
| 894 |
+
template <template <typename> class Predicate, typename Default, typename... Ts>
|
| 895 |
+
struct exactly_one {
|
| 896 |
+
static constexpr auto found = constexpr_sum(Predicate<Ts>::value...);
|
| 897 |
+
static_assert(found <= 1, "Found more than one type matching the predicate");
|
| 898 |
+
|
| 899 |
+
static constexpr auto index = found ? constexpr_first<Predicate, Ts...>() : 0;
|
| 900 |
+
using type = conditional_t<found, typename pack_element<index, Ts...>::type, Default>;
|
| 901 |
+
};
|
| 902 |
+
template <template <typename> class P, typename Default>
|
| 903 |
+
struct exactly_one<P, Default> {
|
| 904 |
+
using type = Default;
|
| 905 |
+
};
|
| 906 |
+
|
| 907 |
+
template <template <typename> class Predicate, typename Default, typename... Ts>
|
| 908 |
+
using exactly_one_t = typename exactly_one<Predicate, Default, Ts...>::type;
|
| 909 |
+
|
| 910 |
+
/// Defer the evaluation of type T until types Us are instantiated
|
| 911 |
+
template <typename T, typename... /*Us*/>
|
| 912 |
+
struct deferred_type {
|
| 913 |
+
using type = T;
|
| 914 |
+
};
|
| 915 |
+
template <typename T, typename... Us>
|
| 916 |
+
using deferred_t = typename deferred_type<T, Us...>::type;
|
| 917 |
+
|
| 918 |
+
/// Like is_base_of, but requires a strict base (i.e. `is_strict_base_of<T, T>::value == false`,
|
| 919 |
+
/// unlike `std::is_base_of`)
|
| 920 |
+
template <typename Base, typename Derived>
|
| 921 |
+
using is_strict_base_of
|
| 922 |
+
= bool_constant<std::is_base_of<Base, Derived>::value && !std::is_same<Base, Derived>::value>;
|
| 923 |
+
|
| 924 |
+
/// Like is_base_of, but also requires that the base type is accessible (i.e. that a Derived
|
| 925 |
+
/// pointer can be converted to a Base pointer) For unions, `is_base_of<T, T>::value` is False, so
|
| 926 |
+
/// we need to check `is_same` as well.
|
| 927 |
+
template <typename Base, typename Derived>
|
| 928 |
+
using is_accessible_base_of
|
| 929 |
+
= bool_constant<(std::is_same<Base, Derived>::value || std::is_base_of<Base, Derived>::value)
|
| 930 |
+
&& std::is_convertible<Derived *, Base *>::value>;
|
| 931 |
+
|
| 932 |
+
template <template <typename...> class Base>
|
| 933 |
+
struct is_template_base_of_impl {
|
| 934 |
+
template <typename... Us>
|
| 935 |
+
static std::true_type check(Base<Us...> *);
|
| 936 |
+
static std::false_type check(...);
|
| 937 |
+
};
|
| 938 |
+
|
| 939 |
+
/// Check if a template is the base of a type. For example:
|
| 940 |
+
/// `is_template_base_of<Base, T>` is true if `struct T : Base<U> {}` where U can be anything
|
| 941 |
+
template <template <typename...> class Base, typename T>
|
| 942 |
+
// Sadly, all MSVC versions incl. 2022 need the workaround, even in C++20 mode.
|
| 943 |
+
// See also: https://github.com/pybind/pybind11/pull/3741
|
| 944 |
+
#if !defined(_MSC_VER)
|
| 945 |
+
using is_template_base_of
|
| 946 |
+
= decltype(is_template_base_of_impl<Base>::check((intrinsic_t<T> *) nullptr));
|
| 947 |
+
#else
|
| 948 |
+
struct is_template_base_of
|
| 949 |
+
: decltype(is_template_base_of_impl<Base>::check((intrinsic_t<T> *) nullptr)){};
|
| 950 |
+
#endif
|
| 951 |
+
|
| 952 |
+
/// Check if T is an instantiation of the template `Class`. For example:
|
| 953 |
+
/// `is_instantiation<shared_ptr, T>` is true if `T == shared_ptr<U>` where U can be anything.
|
| 954 |
+
template <template <typename...> class Class, typename T>
|
| 955 |
+
struct is_instantiation : std::false_type {};
|
| 956 |
+
template <template <typename...> class Class, typename... Us>
|
| 957 |
+
struct is_instantiation<Class, Class<Us...>> : std::true_type {};
|
| 958 |
+
|
| 959 |
+
/// Check if T is std::shared_ptr<U> where U can be anything
|
| 960 |
+
template <typename T>
|
| 961 |
+
using is_shared_ptr = is_instantiation<std::shared_ptr, T>;
|
| 962 |
+
|
| 963 |
+
/// Check if T looks like an input iterator
|
| 964 |
+
template <typename T, typename = void>
|
| 965 |
+
struct is_input_iterator : std::false_type {};
|
| 966 |
+
template <typename T>
|
| 967 |
+
struct is_input_iterator<T,
|
| 968 |
+
void_t<decltype(*std::declval<T &>()), decltype(++std::declval<T &>())>>
|
| 969 |
+
: std::true_type {};
|
| 970 |
+
|
| 971 |
+
template <typename T>
|
| 972 |
+
using is_function_pointer
|
| 973 |
+
= bool_constant<std::is_pointer<T>::value
|
| 974 |
+
&& std::is_function<typename std::remove_pointer<T>::type>::value>;
|
| 975 |
+
|
| 976 |
+
template <typename F>
|
| 977 |
+
struct strip_function_object {
|
| 978 |
+
// If you are encountering an
|
| 979 |
+
// 'error: name followed by "::" must be a class or namespace name'
|
| 980 |
+
// with the Intel compiler and a noexcept function here,
|
| 981 |
+
// try to use noexcept(true) instead of plain noexcept.
|
| 982 |
+
using type = typename remove_class<decltype(&F::operator())>::type;
|
| 983 |
+
};
|
| 984 |
+
|
| 985 |
+
// Extracts the function signature from a function, function pointer or lambda.
|
| 986 |
+
template <typename Function, typename F = remove_reference_t<Function>>
|
| 987 |
+
using function_signature_t = conditional_t<
|
| 988 |
+
std::is_function<F>::value,
|
| 989 |
+
F,
|
| 990 |
+
typename conditional_t<std::is_pointer<F>::value || std::is_member_pointer<F>::value,
|
| 991 |
+
std::remove_pointer<F>,
|
| 992 |
+
strip_function_object<F>>::type>;
|
| 993 |
+
|
| 994 |
+
/// Returns true if the type looks like a lambda: that is, isn't a function, pointer or member
|
| 995 |
+
/// pointer. Note that this can catch all sorts of other things, too; this is intended to be used
|
| 996 |
+
/// in a place where passing a lambda makes sense.
|
| 997 |
+
template <typename T>
|
| 998 |
+
using is_lambda = satisfies_none_of<remove_reference_t<T>,
|
| 999 |
+
std::is_function,
|
| 1000 |
+
std::is_pointer,
|
| 1001 |
+
std::is_member_pointer>;
|
| 1002 |
+
|
| 1003 |
+
// [workaround(intel)] Internal error on fold expression
|
| 1004 |
+
/// Apply a function over each element of a parameter pack
|
| 1005 |
+
#if defined(__cpp_fold_expressions) && !defined(__INTEL_COMPILER)
|
| 1006 |
+
// Intel compiler produces an internal error on this fold expression (tested with ICC 19.0.2)
|
| 1007 |
+
# define PYBIND11_EXPAND_SIDE_EFFECTS(PATTERN) (((PATTERN), void()), ...)
|
| 1008 |
+
#else
|
| 1009 |
+
using expand_side_effects = bool[];
|
| 1010 |
+
# define PYBIND11_EXPAND_SIDE_EFFECTS(PATTERN) \
|
| 1011 |
+
(void) pybind11::detail::expand_side_effects { ((PATTERN), void(), false)..., false }
|
| 1012 |
+
#endif
|
| 1013 |
+
|
| 1014 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 1015 |
+
|
| 1016 |
+
/// C++ bindings of builtin Python exceptions
|
| 1017 |
+
class PYBIND11_EXPORT_EXCEPTION builtin_exception : public std::runtime_error {
|
| 1018 |
+
public:
|
| 1019 |
+
using std::runtime_error::runtime_error;
|
| 1020 |
+
/// Set the error using the Python C API
|
| 1021 |
+
virtual void set_error() const = 0;
|
| 1022 |
+
};
|
| 1023 |
+
|
| 1024 |
+
#define PYBIND11_RUNTIME_EXCEPTION(name, type) \
|
| 1025 |
+
class PYBIND11_EXPORT_EXCEPTION name : public builtin_exception { \
|
| 1026 |
+
public: \
|
| 1027 |
+
using builtin_exception::builtin_exception; \
|
| 1028 |
+
name() : name("") {} \
|
| 1029 |
+
void set_error() const override { PyErr_SetString(type, what()); } \
|
| 1030 |
+
};
|
| 1031 |
+
|
| 1032 |
+
PYBIND11_RUNTIME_EXCEPTION(stop_iteration, PyExc_StopIteration)
|
| 1033 |
+
PYBIND11_RUNTIME_EXCEPTION(index_error, PyExc_IndexError)
|
| 1034 |
+
PYBIND11_RUNTIME_EXCEPTION(key_error, PyExc_KeyError)
|
| 1035 |
+
PYBIND11_RUNTIME_EXCEPTION(value_error, PyExc_ValueError)
|
| 1036 |
+
PYBIND11_RUNTIME_EXCEPTION(type_error, PyExc_TypeError)
|
| 1037 |
+
PYBIND11_RUNTIME_EXCEPTION(buffer_error, PyExc_BufferError)
|
| 1038 |
+
PYBIND11_RUNTIME_EXCEPTION(import_error, PyExc_ImportError)
|
| 1039 |
+
PYBIND11_RUNTIME_EXCEPTION(attribute_error, PyExc_AttributeError)
|
| 1040 |
+
PYBIND11_RUNTIME_EXCEPTION(cast_error, PyExc_RuntimeError) /// Thrown when pybind11::cast or
|
| 1041 |
+
/// handle::call fail due to a type
|
| 1042 |
+
/// casting error
|
| 1043 |
+
PYBIND11_RUNTIME_EXCEPTION(reference_cast_error, PyExc_RuntimeError) /// Used internally
|
| 1044 |
+
|
| 1045 |
+
[[noreturn]] PYBIND11_NOINLINE void pybind11_fail(const char *reason) {
|
| 1046 |
+
assert(!PyErr_Occurred());
|
| 1047 |
+
throw std::runtime_error(reason);
|
| 1048 |
+
}
|
| 1049 |
+
[[noreturn]] PYBIND11_NOINLINE void pybind11_fail(const std::string &reason) {
|
| 1050 |
+
assert(!PyErr_Occurred());
|
| 1051 |
+
throw std::runtime_error(reason);
|
| 1052 |
+
}
|
| 1053 |
+
|
| 1054 |
+
template <typename T, typename SFINAE = void>
|
| 1055 |
+
struct format_descriptor {};
|
| 1056 |
+
|
| 1057 |
+
template <typename T>
|
| 1058 |
+
struct format_descriptor<
|
| 1059 |
+
T,
|
| 1060 |
+
detail::enable_if_t<detail::is_same_ignoring_cvref<T, PyObject *>::value>> {
|
| 1061 |
+
static constexpr const char c = 'O';
|
| 1062 |
+
static constexpr const char value[2] = {c, '\0'};
|
| 1063 |
+
static std::string format() { return std::string(1, c); }
|
| 1064 |
+
};
|
| 1065 |
+
|
| 1066 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 1067 |
+
// Returns the index of the given type in the type char array below, and in the list in numpy.h
|
| 1068 |
+
// The order here is: bool; 8 ints ((signed,unsigned)x(8,16,32,64)bits); float,double,long double;
|
| 1069 |
+
// complex float,double,long double. Note that the long double types only participate when long
|
| 1070 |
+
// double is actually longer than double (it isn't under MSVC).
|
| 1071 |
+
// NB: not only the string below but also complex.h and numpy.h rely on this order.
|
| 1072 |
+
template <typename T, typename SFINAE = void>
|
| 1073 |
+
struct is_fmt_numeric {
|
| 1074 |
+
static constexpr bool value = false;
|
| 1075 |
+
};
|
| 1076 |
+
template <typename T>
|
| 1077 |
+
struct is_fmt_numeric<T, enable_if_t<std::is_arithmetic<T>::value>> {
|
| 1078 |
+
static constexpr bool value = true;
|
| 1079 |
+
static constexpr int index
|
| 1080 |
+
= std::is_same<T, bool>::value
|
| 1081 |
+
? 0
|
| 1082 |
+
: 1
|
| 1083 |
+
+ (std::is_integral<T>::value
|
| 1084 |
+
? detail::log2(sizeof(T)) * 2 + std::is_unsigned<T>::value
|
| 1085 |
+
: 8
|
| 1086 |
+
+ (std::is_same<T, double>::value ? 1
|
| 1087 |
+
: std::is_same<T, long double>::value ? 2
|
| 1088 |
+
: 0));
|
| 1089 |
+
};
|
| 1090 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 1091 |
+
|
| 1092 |
+
template <typename T>
|
| 1093 |
+
struct format_descriptor<T, detail::enable_if_t<std::is_arithmetic<T>::value>> {
|
| 1094 |
+
static constexpr const char c = "?bBhHiIqQfdg"[detail::is_fmt_numeric<T>::index];
|
| 1095 |
+
static constexpr const char value[2] = {c, '\0'};
|
| 1096 |
+
static std::string format() { return std::string(1, c); }
|
| 1097 |
+
};
|
| 1098 |
+
|
| 1099 |
+
#if !defined(PYBIND11_CPP17)
|
| 1100 |
+
|
| 1101 |
+
template <typename T>
|
| 1102 |
+
constexpr const char
|
| 1103 |
+
format_descriptor<T, detail::enable_if_t<std::is_arithmetic<T>::value>>::value[2];
|
| 1104 |
+
|
| 1105 |
+
#endif
|
| 1106 |
+
|
| 1107 |
+
/// RAII wrapper that temporarily clears any Python error state
|
| 1108 |
+
struct error_scope {
|
| 1109 |
+
PyObject *type, *value, *trace;
|
| 1110 |
+
error_scope() { PyErr_Fetch(&type, &value, &trace); }
|
| 1111 |
+
error_scope(const error_scope &) = delete;
|
| 1112 |
+
error_scope &operator=(const error_scope &) = delete;
|
| 1113 |
+
~error_scope() { PyErr_Restore(type, value, trace); }
|
| 1114 |
+
};
|
| 1115 |
+
|
| 1116 |
+
/// Dummy destructor wrapper that can be used to expose classes with a private destructor
|
| 1117 |
+
struct nodelete {
|
| 1118 |
+
template <typename T>
|
| 1119 |
+
void operator()(T *) {}
|
| 1120 |
+
};
|
| 1121 |
+
|
| 1122 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 1123 |
+
template <typename... Args>
|
| 1124 |
+
struct overload_cast_impl {
|
| 1125 |
+
template <typename Return>
|
| 1126 |
+
constexpr auto operator()(Return (*pf)(Args...)) const noexcept -> decltype(pf) {
|
| 1127 |
+
return pf;
|
| 1128 |
+
}
|
| 1129 |
+
|
| 1130 |
+
template <typename Return, typename Class>
|
| 1131 |
+
constexpr auto operator()(Return (Class::*pmf)(Args...),
|
| 1132 |
+
std::false_type = {}) const noexcept -> decltype(pmf) {
|
| 1133 |
+
return pmf;
|
| 1134 |
+
}
|
| 1135 |
+
|
| 1136 |
+
template <typename Return, typename Class>
|
| 1137 |
+
constexpr auto operator()(Return (Class::*pmf)(Args...) const,
|
| 1138 |
+
std::true_type) const noexcept -> decltype(pmf) {
|
| 1139 |
+
return pmf;
|
| 1140 |
+
}
|
| 1141 |
+
};
|
| 1142 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 1143 |
+
|
| 1144 |
+
// overload_cast requires variable templates: C++14
|
| 1145 |
+
#if defined(PYBIND11_CPP14)
|
| 1146 |
+
# define PYBIND11_OVERLOAD_CAST 1
|
| 1147 |
+
/// Syntax sugar for resolving overloaded function pointers:
|
| 1148 |
+
/// - regular: static_cast<Return (Class::*)(Arg0, Arg1, Arg2)>(&Class::func)
|
| 1149 |
+
/// - sweet: overload_cast<Arg0, Arg1, Arg2>(&Class::func)
|
| 1150 |
+
template <typename... Args>
|
| 1151 |
+
static constexpr detail::overload_cast_impl<Args...> overload_cast{};
|
| 1152 |
+
#endif
|
| 1153 |
+
|
| 1154 |
+
/// Const member function selector for overload_cast
|
| 1155 |
+
/// - regular: static_cast<Return (Class::*)(Arg) const>(&Class::func)
|
| 1156 |
+
/// - sweet: overload_cast<Arg>(&Class::func, const_)
|
| 1157 |
+
static constexpr auto const_ = std::true_type{};
|
| 1158 |
+
|
| 1159 |
+
#if !defined(PYBIND11_CPP14) // no overload_cast: providing something that static_assert-fails:
|
| 1160 |
+
template <typename... Args>
|
| 1161 |
+
struct overload_cast {
|
| 1162 |
+
static_assert(detail::deferred_t<std::false_type, Args...>::value,
|
| 1163 |
+
"pybind11::overload_cast<...> requires compiling in C++14 mode");
|
| 1164 |
+
};
|
| 1165 |
+
#endif // overload_cast
|
| 1166 |
+
|
| 1167 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 1168 |
+
|
| 1169 |
+
// Adaptor for converting arbitrary container arguments into a vector; implicitly convertible from
|
| 1170 |
+
// any standard container (or C-style array) supporting std::begin/std::end, any singleton
|
| 1171 |
+
// arithmetic type (if T is arithmetic), or explicitly constructible from an iterator pair.
|
| 1172 |
+
template <typename T>
|
| 1173 |
+
class any_container {
|
| 1174 |
+
std::vector<T> v;
|
| 1175 |
+
|
| 1176 |
+
public:
|
| 1177 |
+
any_container() = default;
|
| 1178 |
+
|
| 1179 |
+
// Can construct from a pair of iterators
|
| 1180 |
+
template <typename It, typename = enable_if_t<is_input_iterator<It>::value>>
|
| 1181 |
+
any_container(It first, It last) : v(first, last) {}
|
| 1182 |
+
|
| 1183 |
+
// Implicit conversion constructor from any arbitrary container type
|
| 1184 |
+
// with values convertible to T
|
| 1185 |
+
template <typename Container,
|
| 1186 |
+
typename = enable_if_t<
|
| 1187 |
+
std::is_convertible<decltype(*std::begin(std::declval<const Container &>())),
|
| 1188 |
+
T>::value>>
|
| 1189 |
+
// NOLINTNEXTLINE(google-explicit-constructor)
|
| 1190 |
+
any_container(const Container &c) : any_container(std::begin(c), std::end(c)) {}
|
| 1191 |
+
|
| 1192 |
+
// initializer_list's aren't deducible, so don't get matched by the above template;
|
| 1193 |
+
// we need this to explicitly allow implicit conversion from one:
|
| 1194 |
+
template <typename TIn, typename = enable_if_t<std::is_convertible<TIn, T>::value>>
|
| 1195 |
+
any_container(const std::initializer_list<TIn> &c) : any_container(c.begin(), c.end()) {}
|
| 1196 |
+
|
| 1197 |
+
// Avoid copying if given an rvalue vector of the correct type.
|
| 1198 |
+
// NOLINTNEXTLINE(google-explicit-constructor)
|
| 1199 |
+
any_container(std::vector<T> &&v) : v(std::move(v)) {}
|
| 1200 |
+
|
| 1201 |
+
// Moves the vector out of an rvalue any_container
|
| 1202 |
+
// NOLINTNEXTLINE(google-explicit-constructor)
|
| 1203 |
+
operator std::vector<T> &&() && { return std::move(v); }
|
| 1204 |
+
|
| 1205 |
+
// Dereferencing obtains a reference to the underlying vector
|
| 1206 |
+
std::vector<T> &operator*() { return v; }
|
| 1207 |
+
const std::vector<T> &operator*() const { return v; }
|
| 1208 |
+
|
| 1209 |
+
// -> lets you call methods on the underlying vector
|
| 1210 |
+
std::vector<T> *operator->() { return &v; }
|
| 1211 |
+
const std::vector<T> *operator->() const { return &v; }
|
| 1212 |
+
};
|
| 1213 |
+
|
| 1214 |
+
// Forward-declaration; see detail/class.h
|
| 1215 |
+
std::string get_fully_qualified_tp_name(PyTypeObject *);
|
| 1216 |
+
|
| 1217 |
+
template <typename T>
|
| 1218 |
+
inline static std::shared_ptr<T>
|
| 1219 |
+
try_get_shared_from_this(std::enable_shared_from_this<T> *holder_value_ptr) {
|
| 1220 |
+
// Pre C++17, this code path exploits undefined behavior, but is known to work on many platforms.
|
| 1221 |
+
// Use at your own risk!
|
| 1222 |
+
// See also https://en.cppreference.com/w/cpp/memory/enable_shared_from_this, and in particular
|
| 1223 |
+
// the `std::shared_ptr<Good> gp1 = not_so_good.getptr();` and `try`-`catch` parts of the example.
|
| 1224 |
+
#if defined(__cpp_lib_enable_shared_from_this) && (!defined(_MSC_VER) || _MSC_VER >= 1912)
|
| 1225 |
+
return holder_value_ptr->weak_from_this().lock();
|
| 1226 |
+
#else
|
| 1227 |
+
try {
|
| 1228 |
+
return holder_value_ptr->shared_from_this();
|
| 1229 |
+
} catch (const std::bad_weak_ptr &) {
|
| 1230 |
+
return nullptr;
|
| 1231 |
+
}
|
| 1232 |
+
#endif
|
| 1233 |
+
}
|
| 1234 |
+
|
| 1235 |
+
// For silencing "unused" compiler warnings in special situations.
|
| 1236 |
+
template <typename... Args>
|
| 1237 |
+
#if defined(_MSC_VER) && _MSC_VER < 1920 // MSVC 2017
|
| 1238 |
+
constexpr
|
| 1239 |
+
#endif
|
| 1240 |
+
inline void
|
| 1241 |
+
silence_unused_warnings(Args &&...) {
|
| 1242 |
+
}
|
| 1243 |
+
|
| 1244 |
+
// MSVC warning C4100: Unreferenced formal parameter
|
| 1245 |
+
#if defined(_MSC_VER) && _MSC_VER <= 1916
|
| 1246 |
+
# define PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(...) \
|
| 1247 |
+
detail::silence_unused_warnings(__VA_ARGS__)
|
| 1248 |
+
#else
|
| 1249 |
+
# define PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(...)
|
| 1250 |
+
#endif
|
| 1251 |
+
|
| 1252 |
+
// GCC -Wunused-but-set-parameter All GCC versions (as of July 2021).
|
| 1253 |
+
#if defined(__GNUG__) && !defined(__clang__) && !defined(__INTEL_COMPILER)
|
| 1254 |
+
# define PYBIND11_WORKAROUND_INCORRECT_GCC_UNUSED_BUT_SET_PARAMETER(...) \
|
| 1255 |
+
detail::silence_unused_warnings(__VA_ARGS__)
|
| 1256 |
+
#else
|
| 1257 |
+
# define PYBIND11_WORKAROUND_INCORRECT_GCC_UNUSED_BUT_SET_PARAMETER(...)
|
| 1258 |
+
#endif
|
| 1259 |
+
|
| 1260 |
+
#if defined(__clang__) \
|
| 1261 |
+
&& (defined(__apple_build_version__) /* AppleClang 13.0.0.13000029 was the only data point \
|
| 1262 |
+
available. */ \
|
| 1263 |
+
|| (__clang_major__ >= 7 \
|
| 1264 |
+
&& __clang_major__ <= 12) /* Clang 3, 5, 13, 14, 15 do not generate the warning. */ \
|
| 1265 |
+
)
|
| 1266 |
+
# define PYBIND11_DETECTED_CLANG_WITH_MISLEADING_CALL_STD_MOVE_EXPLICITLY_WARNING
|
| 1267 |
+
// Example:
|
| 1268 |
+
// tests/test_kwargs_and_defaults.cpp:46:68: error: local variable 'args' will be copied despite
|
| 1269 |
+
// being returned by name [-Werror,-Wreturn-std-move]
|
| 1270 |
+
// m.def("args_function", [](py::args args) -> py::tuple { return args; });
|
| 1271 |
+
// ^~~~
|
| 1272 |
+
// test_kwargs_and_defaults.cpp:46:68: note: call 'std::move' explicitly to avoid copying
|
| 1273 |
+
// m.def("args_function", [](py::args args) -> py::tuple { return args; });
|
| 1274 |
+
// ^~~~
|
| 1275 |
+
// std::move(args)
|
| 1276 |
+
#endif
|
| 1277 |
+
|
| 1278 |
+
// Pybind offers detailed error messages by default for all builts that are debug (through the
|
| 1279 |
+
// negation of NDEBUG). This can also be manually enabled by users, for any builds, through
|
| 1280 |
+
// defining PYBIND11_DETAILED_ERROR_MESSAGES. This information is primarily useful for those
|
| 1281 |
+
// who are writing (as opposed to merely using) libraries that use pybind11.
|
| 1282 |
+
#if !defined(PYBIND11_DETAILED_ERROR_MESSAGES) && !defined(NDEBUG)
|
| 1283 |
+
# define PYBIND11_DETAILED_ERROR_MESSAGES
|
| 1284 |
+
#endif
|
| 1285 |
+
|
| 1286 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 1287 |
+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
|
phivenv/Lib/site-packages/torch/include/pybind11/detail/cpp_conduit.h
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) 2024 The pybind Community.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <pybind11/pytypes.h>
|
| 6 |
+
|
| 7 |
+
#include "common.h"
|
| 8 |
+
#include "internals.h"
|
| 9 |
+
|
| 10 |
+
#include <typeinfo>
|
| 11 |
+
|
| 12 |
+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
| 13 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 14 |
+
|
| 15 |
+
// Forward declaration needed here: Refactoring opportunity.
|
| 16 |
+
extern "C" inline PyObject *pybind11_object_new(PyTypeObject *type, PyObject *, PyObject *);
|
| 17 |
+
|
| 18 |
+
inline bool type_is_managed_by_our_internals(PyTypeObject *type_obj) {
|
| 19 |
+
#if defined(PYPY_VERSION)
|
| 20 |
+
auto &internals = get_internals();
|
| 21 |
+
return bool(internals.registered_types_py.find(type_obj)
|
| 22 |
+
!= internals.registered_types_py.end());
|
| 23 |
+
#else
|
| 24 |
+
return bool(type_obj->tp_new == pybind11_object_new);
|
| 25 |
+
#endif
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
inline bool is_instance_method_of_type(PyTypeObject *type_obj, PyObject *attr_name) {
|
| 29 |
+
PyObject *descr = _PyType_Lookup(type_obj, attr_name);
|
| 30 |
+
return bool((descr != nullptr) && PyInstanceMethod_Check(descr));
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
inline object try_get_cpp_conduit_method(PyObject *obj) {
|
| 34 |
+
if (PyType_Check(obj)) {
|
| 35 |
+
return object();
|
| 36 |
+
}
|
| 37 |
+
PyTypeObject *type_obj = Py_TYPE(obj);
|
| 38 |
+
str attr_name("_pybind11_conduit_v1_");
|
| 39 |
+
bool assumed_to_be_callable = false;
|
| 40 |
+
if (type_is_managed_by_our_internals(type_obj)) {
|
| 41 |
+
if (!is_instance_method_of_type(type_obj, attr_name.ptr())) {
|
| 42 |
+
return object();
|
| 43 |
+
}
|
| 44 |
+
assumed_to_be_callable = true;
|
| 45 |
+
}
|
| 46 |
+
PyObject *method = PyObject_GetAttr(obj, attr_name.ptr());
|
| 47 |
+
if (method == nullptr) {
|
| 48 |
+
PyErr_Clear();
|
| 49 |
+
return object();
|
| 50 |
+
}
|
| 51 |
+
if (!assumed_to_be_callable && PyCallable_Check(method) == 0) {
|
| 52 |
+
Py_DECREF(method);
|
| 53 |
+
return object();
|
| 54 |
+
}
|
| 55 |
+
return reinterpret_steal<object>(method);
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
inline void *try_raw_pointer_ephemeral_from_cpp_conduit(handle src,
|
| 59 |
+
const std::type_info *cpp_type_info) {
|
| 60 |
+
object method = try_get_cpp_conduit_method(src.ptr());
|
| 61 |
+
if (method) {
|
| 62 |
+
capsule cpp_type_info_capsule(const_cast<void *>(static_cast<const void *>(cpp_type_info)),
|
| 63 |
+
typeid(std::type_info).name());
|
| 64 |
+
object cpp_conduit = method(bytes(PYBIND11_PLATFORM_ABI_ID),
|
| 65 |
+
cpp_type_info_capsule,
|
| 66 |
+
bytes("raw_pointer_ephemeral"));
|
| 67 |
+
if (isinstance<capsule>(cpp_conduit)) {
|
| 68 |
+
return reinterpret_borrow<capsule>(cpp_conduit).get_pointer();
|
| 69 |
+
}
|
| 70 |
+
}
|
| 71 |
+
return nullptr;
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
#define PYBIND11_HAS_CPP_CONDUIT 1
|
| 75 |
+
|
| 76 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 77 |
+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
|
phivenv/Lib/site-packages/torch/include/pybind11/detail/descr.h
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
pybind11/detail/descr.h: Helper type for concatenating type signatures at compile time
|
| 3 |
+
|
| 4 |
+
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
| 5 |
+
|
| 6 |
+
All rights reserved. Use of this source code is governed by a
|
| 7 |
+
BSD-style license that can be found in the LICENSE file.
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include "common.h"
|
| 13 |
+
|
| 14 |
+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
| 15 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 16 |
+
|
| 17 |
+
#if !defined(_MSC_VER)
|
| 18 |
+
# define PYBIND11_DESCR_CONSTEXPR static constexpr
|
| 19 |
+
#else
|
| 20 |
+
# define PYBIND11_DESCR_CONSTEXPR const
|
| 21 |
+
#endif
|
| 22 |
+
|
| 23 |
+
/* Concatenate type signatures at compile time */
|
| 24 |
+
template <size_t N, typename... Ts>
|
| 25 |
+
struct descr {
|
| 26 |
+
char text[N + 1]{'\0'};
|
| 27 |
+
|
| 28 |
+
constexpr descr() = default;
|
| 29 |
+
// NOLINTNEXTLINE(google-explicit-constructor)
|
| 30 |
+
constexpr descr(char const (&s)[N + 1]) : descr(s, make_index_sequence<N>()) {}
|
| 31 |
+
|
| 32 |
+
template <size_t... Is>
|
| 33 |
+
constexpr descr(char const (&s)[N + 1], index_sequence<Is...>) : text{s[Is]..., '\0'} {}
|
| 34 |
+
|
| 35 |
+
template <typename... Chars>
|
| 36 |
+
// NOLINTNEXTLINE(google-explicit-constructor)
|
| 37 |
+
constexpr descr(char c, Chars... cs) : text{c, static_cast<char>(cs)..., '\0'} {}
|
| 38 |
+
|
| 39 |
+
static constexpr std::array<const std::type_info *, sizeof...(Ts) + 1> types() {
|
| 40 |
+
return {{&typeid(Ts)..., nullptr}};
|
| 41 |
+
}
|
| 42 |
+
};
|
| 43 |
+
|
| 44 |
+
template <size_t N1, size_t N2, typename... Ts1, typename... Ts2, size_t... Is1, size_t... Is2>
|
| 45 |
+
constexpr descr<N1 + N2, Ts1..., Ts2...> plus_impl(const descr<N1, Ts1...> &a,
|
| 46 |
+
const descr<N2, Ts2...> &b,
|
| 47 |
+
index_sequence<Is1...>,
|
| 48 |
+
index_sequence<Is2...>) {
|
| 49 |
+
PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(b);
|
| 50 |
+
return {a.text[Is1]..., b.text[Is2]...};
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
template <size_t N1, size_t N2, typename... Ts1, typename... Ts2>
|
| 54 |
+
constexpr descr<N1 + N2, Ts1..., Ts2...> operator+(const descr<N1, Ts1...> &a,
|
| 55 |
+
const descr<N2, Ts2...> &b) {
|
| 56 |
+
return plus_impl(a, b, make_index_sequence<N1>(), make_index_sequence<N2>());
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
template <size_t N>
|
| 60 |
+
constexpr descr<N - 1> const_name(char const (&text)[N]) {
|
| 61 |
+
return descr<N - 1>(text);
|
| 62 |
+
}
|
| 63 |
+
constexpr descr<0> const_name(char const (&)[1]) { return {}; }
|
| 64 |
+
|
| 65 |
+
template <size_t Rem, size_t... Digits>
|
| 66 |
+
struct int_to_str : int_to_str<Rem / 10, Rem % 10, Digits...> {};
|
| 67 |
+
template <size_t... Digits>
|
| 68 |
+
struct int_to_str<0, Digits...> {
|
| 69 |
+
// WARNING: This only works with C++17 or higher.
|
| 70 |
+
static constexpr auto digits = descr<sizeof...(Digits)>(('0' + Digits)...);
|
| 71 |
+
};
|
| 72 |
+
|
| 73 |
+
// Ternary description (like std::conditional)
|
| 74 |
+
template <bool B, size_t N1, size_t N2>
|
| 75 |
+
constexpr enable_if_t<B, descr<N1 - 1>> const_name(char const (&text1)[N1], char const (&)[N2]) {
|
| 76 |
+
return const_name(text1);
|
| 77 |
+
}
|
| 78 |
+
template <bool B, size_t N1, size_t N2>
|
| 79 |
+
constexpr enable_if_t<!B, descr<N2 - 1>> const_name(char const (&)[N1], char const (&text2)[N2]) {
|
| 80 |
+
return const_name(text2);
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
template <bool B, typename T1, typename T2>
|
| 84 |
+
constexpr enable_if_t<B, T1> const_name(const T1 &d, const T2 &) {
|
| 85 |
+
return d;
|
| 86 |
+
}
|
| 87 |
+
template <bool B, typename T1, typename T2>
|
| 88 |
+
constexpr enable_if_t<!B, T2> const_name(const T1 &, const T2 &d) {
|
| 89 |
+
return d;
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
template <size_t Size>
|
| 93 |
+
auto constexpr const_name() -> remove_cv_t<decltype(int_to_str<Size / 10, Size % 10>::digits)> {
|
| 94 |
+
return int_to_str<Size / 10, Size % 10>::digits;
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
template <typename Type>
|
| 98 |
+
constexpr descr<1, Type> const_name() {
|
| 99 |
+
return {'%'};
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
// If "_" is defined as a macro, py::detail::_ cannot be provided.
|
| 103 |
+
// It is therefore best to use py::detail::const_name universally.
|
| 104 |
+
// This block is for backward compatibility only.
|
| 105 |
+
// (The const_name code is repeated to avoid introducing a "_" #define ourselves.)
|
| 106 |
+
#ifndef _
|
| 107 |
+
# define PYBIND11_DETAIL_UNDERSCORE_BACKWARD_COMPATIBILITY
|
| 108 |
+
template <size_t N>
|
| 109 |
+
constexpr descr<N - 1> _(char const (&text)[N]) {
|
| 110 |
+
return const_name<N>(text);
|
| 111 |
+
}
|
| 112 |
+
template <bool B, size_t N1, size_t N2>
|
| 113 |
+
constexpr enable_if_t<B, descr<N1 - 1>> _(char const (&text1)[N1], char const (&text2)[N2]) {
|
| 114 |
+
return const_name<B, N1, N2>(text1, text2);
|
| 115 |
+
}
|
| 116 |
+
template <bool B, size_t N1, size_t N2>
|
| 117 |
+
constexpr enable_if_t<!B, descr<N2 - 1>> _(char const (&text1)[N1], char const (&text2)[N2]) {
|
| 118 |
+
return const_name<B, N1, N2>(text1, text2);
|
| 119 |
+
}
|
| 120 |
+
template <bool B, typename T1, typename T2>
|
| 121 |
+
constexpr enable_if_t<B, T1> _(const T1 &d1, const T2 &d2) {
|
| 122 |
+
return const_name<B, T1, T2>(d1, d2);
|
| 123 |
+
}
|
| 124 |
+
template <bool B, typename T1, typename T2>
|
| 125 |
+
constexpr enable_if_t<!B, T2> _(const T1 &d1, const T2 &d2) {
|
| 126 |
+
return const_name<B, T1, T2>(d1, d2);
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
template <size_t Size>
|
| 130 |
+
auto constexpr _() -> remove_cv_t<decltype(int_to_str<Size / 10, Size % 10>::digits)> {
|
| 131 |
+
return const_name<Size>();
|
| 132 |
+
}
|
| 133 |
+
template <typename Type>
|
| 134 |
+
constexpr descr<1, Type> _() {
|
| 135 |
+
return const_name<Type>();
|
| 136 |
+
}
|
| 137 |
+
#endif // #ifndef _
|
| 138 |
+
|
| 139 |
+
constexpr descr<0> concat() { return {}; }
|
| 140 |
+
|
| 141 |
+
template <size_t N, typename... Ts>
|
| 142 |
+
constexpr descr<N, Ts...> concat(const descr<N, Ts...> &descr) {
|
| 143 |
+
return descr;
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
#ifdef __cpp_fold_expressions
|
| 147 |
+
template <size_t N1, size_t N2, typename... Ts1, typename... Ts2>
|
| 148 |
+
constexpr descr<N1 + N2 + 2, Ts1..., Ts2...> operator,(const descr<N1, Ts1...> &a,
|
| 149 |
+
const descr<N2, Ts2...> &b) {
|
| 150 |
+
return a + const_name(", ") + b;
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
template <size_t N, typename... Ts, typename... Args>
|
| 154 |
+
constexpr auto concat(const descr<N, Ts...> &d, const Args &...args) {
|
| 155 |
+
return (d, ..., args);
|
| 156 |
+
}
|
| 157 |
+
#else
|
| 158 |
+
template <size_t N, typename... Ts, typename... Args>
|
| 159 |
+
constexpr auto concat(const descr<N, Ts...> &d,
|
| 160 |
+
const Args &...args) -> decltype(std::declval<descr<N + 2, Ts...>>()
|
| 161 |
+
+ concat(args...)) {
|
| 162 |
+
return d + const_name(", ") + concat(args...);
|
| 163 |
+
}
|
| 164 |
+
#endif
|
| 165 |
+
|
| 166 |
+
template <size_t N, typename... Ts>
|
| 167 |
+
constexpr descr<N + 2, Ts...> type_descr(const descr<N, Ts...> &descr) {
|
| 168 |
+
return const_name("{") + descr + const_name("}");
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 172 |
+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
|
phivenv/Lib/site-packages/torch/include/pybind11/detail/exception_translation.h
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
pybind11/detail/exception_translation.h: means to translate C++ exceptions to Python exceptions
|
| 3 |
+
|
| 4 |
+
Copyright (c) 2024 The Pybind Development Team.
|
| 5 |
+
|
| 6 |
+
All rights reserved. Use of this source code is governed by a
|
| 7 |
+
BSD-style license that can be found in the LICENSE file.
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include "common.h"
|
| 13 |
+
#include "internals.h"
|
| 14 |
+
|
| 15 |
+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
| 16 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 17 |
+
|
| 18 |
+
// Apply all the extensions translators from a list
|
| 19 |
+
// Return true if one of the translators completed without raising an exception
|
| 20 |
+
// itself. Return of false indicates that if there are other translators
|
| 21 |
+
// available, they should be tried.
|
| 22 |
+
inline bool apply_exception_translators(std::forward_list<ExceptionTranslator> &translators) {
|
| 23 |
+
auto last_exception = std::current_exception();
|
| 24 |
+
|
| 25 |
+
for (auto &translator : translators) {
|
| 26 |
+
try {
|
| 27 |
+
translator(last_exception);
|
| 28 |
+
return true;
|
| 29 |
+
} catch (...) {
|
| 30 |
+
last_exception = std::current_exception();
|
| 31 |
+
}
|
| 32 |
+
}
|
| 33 |
+
return false;
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
inline void try_translate_exceptions() {
|
| 37 |
+
/* When an exception is caught, give each registered exception
|
| 38 |
+
translator a chance to translate it to a Python exception. First
|
| 39 |
+
all module-local translators will be tried in reverse order of
|
| 40 |
+
registration. If none of the module-locale translators handle
|
| 41 |
+
the exception (or there are no module-locale translators) then
|
| 42 |
+
the global translators will be tried, also in reverse order of
|
| 43 |
+
registration.
|
| 44 |
+
|
| 45 |
+
A translator may choose to do one of the following:
|
| 46 |
+
|
| 47 |
+
- catch the exception and call py::set_error()
|
| 48 |
+
to set a standard (or custom) Python exception, or
|
| 49 |
+
- do nothing and let the exception fall through to the next translator, or
|
| 50 |
+
- delegate translation to the next translator by throwing a new type of exception.
|
| 51 |
+
*/
|
| 52 |
+
|
| 53 |
+
bool handled = with_internals([&](internals &internals) {
|
| 54 |
+
auto &local_exception_translators = get_local_internals().registered_exception_translators;
|
| 55 |
+
if (detail::apply_exception_translators(local_exception_translators)) {
|
| 56 |
+
return true;
|
| 57 |
+
}
|
| 58 |
+
auto &exception_translators = internals.registered_exception_translators;
|
| 59 |
+
if (detail::apply_exception_translators(exception_translators)) {
|
| 60 |
+
return true;
|
| 61 |
+
}
|
| 62 |
+
return false;
|
| 63 |
+
});
|
| 64 |
+
|
| 65 |
+
if (!handled) {
|
| 66 |
+
set_error(PyExc_SystemError, "Exception escaped from default exception translator!");
|
| 67 |
+
}
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 71 |
+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
|
phivenv/Lib/site-packages/torch/include/pybind11/detail/init.h
ADDED
|
@@ -0,0 +1,436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
pybind11/detail/init.h: init factory function implementation and support code.
|
| 3 |
+
|
| 4 |
+
Copyright (c) 2017 Jason Rhinelander <jason@imaginary.ca>
|
| 5 |
+
|
| 6 |
+
All rights reserved. Use of this source code is governed by a
|
| 7 |
+
BSD-style license that can be found in the LICENSE file.
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include "class.h"
|
| 13 |
+
|
| 14 |
+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
| 15 |
+
|
| 16 |
+
PYBIND11_WARNING_DISABLE_MSVC(4127)
|
| 17 |
+
|
| 18 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 19 |
+
|
| 20 |
+
template <>
|
| 21 |
+
class type_caster<value_and_holder> {
|
| 22 |
+
public:
|
| 23 |
+
bool load(handle h, bool) {
|
| 24 |
+
value = reinterpret_cast<value_and_holder *>(h.ptr());
|
| 25 |
+
return true;
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
template <typename>
|
| 29 |
+
using cast_op_type = value_and_holder &;
|
| 30 |
+
explicit operator value_and_holder &() { return *value; }
|
| 31 |
+
static constexpr auto name = const_name<value_and_holder>();
|
| 32 |
+
|
| 33 |
+
private:
|
| 34 |
+
value_and_holder *value = nullptr;
|
| 35 |
+
};
|
| 36 |
+
|
| 37 |
+
PYBIND11_NAMESPACE_BEGIN(initimpl)
|
| 38 |
+
|
| 39 |
+
inline void no_nullptr(void *ptr) {
|
| 40 |
+
if (!ptr) {
|
| 41 |
+
throw type_error("pybind11::init(): factory function returned nullptr");
|
| 42 |
+
}
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
// Implementing functions for all forms of py::init<...> and py::init(...)
|
| 46 |
+
template <typename Class>
|
| 47 |
+
using Cpp = typename Class::type;
|
| 48 |
+
template <typename Class>
|
| 49 |
+
using Alias = typename Class::type_alias;
|
| 50 |
+
template <typename Class>
|
| 51 |
+
using Holder = typename Class::holder_type;
|
| 52 |
+
|
| 53 |
+
template <typename Class>
|
| 54 |
+
using is_alias_constructible = std::is_constructible<Alias<Class>, Cpp<Class> &&>;
|
| 55 |
+
|
| 56 |
+
// Takes a Cpp pointer and returns true if it actually is a polymorphic Alias instance.
|
| 57 |
+
template <typename Class, enable_if_t<Class::has_alias, int> = 0>
|
| 58 |
+
bool is_alias(Cpp<Class> *ptr) {
|
| 59 |
+
return dynamic_cast<Alias<Class> *>(ptr) != nullptr;
|
| 60 |
+
}
|
| 61 |
+
// Failing fallback version of the above for a no-alias class (always returns false)
|
| 62 |
+
template <typename /*Class*/>
|
| 63 |
+
constexpr bool is_alias(void *) {
|
| 64 |
+
return false;
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
// Constructs and returns a new object; if the given arguments don't map to a constructor, we fall
|
| 68 |
+
// back to brace aggregate initialization so that for aggregate initialization can be used with
|
| 69 |
+
// py::init, e.g. `py::init<int, int>` to initialize a `struct T { int a; int b; }`. For
|
| 70 |
+
// non-aggregate types, we need to use an ordinary T(...) constructor (invoking as `T{...}` usually
|
| 71 |
+
// works, but will not do the expected thing when `T` has an `initializer_list<T>` constructor).
|
| 72 |
+
template <typename Class,
|
| 73 |
+
typename... Args,
|
| 74 |
+
detail::enable_if_t<std::is_constructible<Class, Args...>::value, int> = 0>
|
| 75 |
+
inline Class *construct_or_initialize(Args &&...args) {
|
| 76 |
+
return new Class(std::forward<Args>(args)...);
|
| 77 |
+
}
|
| 78 |
+
template <typename Class,
|
| 79 |
+
typename... Args,
|
| 80 |
+
detail::enable_if_t<!std::is_constructible<Class, Args...>::value, int> = 0>
|
| 81 |
+
inline Class *construct_or_initialize(Args &&...args) {
|
| 82 |
+
return new Class{std::forward<Args>(args)...};
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
// Attempts to constructs an alias using a `Alias(Cpp &&)` constructor. This allows types with
|
| 86 |
+
// an alias to provide only a single Cpp factory function as long as the Alias can be
|
| 87 |
+
// constructed from an rvalue reference of the base Cpp type. This means that Alias classes
|
| 88 |
+
// can, when appropriate, simply define a `Alias(Cpp &&)` constructor rather than needing to
|
| 89 |
+
// inherit all the base class constructors.
|
| 90 |
+
template <typename Class>
|
| 91 |
+
void construct_alias_from_cpp(std::true_type /*is_alias_constructible*/,
|
| 92 |
+
value_and_holder &v_h,
|
| 93 |
+
Cpp<Class> &&base) {
|
| 94 |
+
v_h.value_ptr() = new Alias<Class>(std::move(base));
|
| 95 |
+
}
|
| 96 |
+
template <typename Class>
|
| 97 |
+
[[noreturn]] void construct_alias_from_cpp(std::false_type /*!is_alias_constructible*/,
|
| 98 |
+
value_and_holder &,
|
| 99 |
+
Cpp<Class> &&) {
|
| 100 |
+
throw type_error("pybind11::init(): unable to convert returned instance to required "
|
| 101 |
+
"alias class: no `Alias<Class>(Class &&)` constructor available");
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
// Error-generating fallback for factories that don't match one of the below construction
|
| 105 |
+
// mechanisms.
|
| 106 |
+
template <typename Class>
|
| 107 |
+
void construct(...) {
|
| 108 |
+
static_assert(!std::is_same<Class, Class>::value /* always false */,
|
| 109 |
+
"pybind11::init(): init function must return a compatible pointer, "
|
| 110 |
+
"holder, or value");
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
// Pointer return v1: the factory function returns a class pointer for a registered class.
|
| 114 |
+
// If we don't need an alias (because this class doesn't have one, or because the final type is
|
| 115 |
+
// inherited on the Python side) we can simply take over ownership. Otherwise we need to try to
|
| 116 |
+
// construct an Alias from the returned base instance.
|
| 117 |
+
template <typename Class>
|
| 118 |
+
void construct(value_and_holder &v_h, Cpp<Class> *ptr, bool need_alias) {
|
| 119 |
+
PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(need_alias);
|
| 120 |
+
no_nullptr(ptr);
|
| 121 |
+
if (Class::has_alias && need_alias && !is_alias<Class>(ptr)) {
|
| 122 |
+
// We're going to try to construct an alias by moving the cpp type. Whether or not
|
| 123 |
+
// that succeeds, we still need to destroy the original cpp pointer (either the
|
| 124 |
+
// moved away leftover, if the alias construction works, or the value itself if we
|
| 125 |
+
// throw an error), but we can't just call `delete ptr`: it might have a special
|
| 126 |
+
// deleter, or might be shared_from_this. So we construct a holder around it as if
|
| 127 |
+
// it was a normal instance, then steal the holder away into a local variable; thus
|
| 128 |
+
// the holder and destruction happens when we leave the C++ scope, and the holder
|
| 129 |
+
// class gets to handle the destruction however it likes.
|
| 130 |
+
v_h.value_ptr() = ptr;
|
| 131 |
+
v_h.set_instance_registered(true); // Trick to prevent init_instance from registering it
|
| 132 |
+
// DANGER ZONE BEGIN: exceptions will leave v_h in an invalid state.
|
| 133 |
+
v_h.type->init_instance(v_h.inst, nullptr); // Set up the holder
|
| 134 |
+
Holder<Class> temp_holder(std::move(v_h.holder<Holder<Class>>())); // Steal the holder
|
| 135 |
+
v_h.type->dealloc(v_h); // Destroys the moved-out holder remains, resets value ptr to null
|
| 136 |
+
v_h.set_instance_registered(false);
|
| 137 |
+
// DANGER ZONE END.
|
| 138 |
+
|
| 139 |
+
construct_alias_from_cpp<Class>(is_alias_constructible<Class>{}, v_h, std::move(*ptr));
|
| 140 |
+
} else {
|
| 141 |
+
// Otherwise the type isn't inherited, so we don't need an Alias
|
| 142 |
+
v_h.value_ptr() = ptr;
|
| 143 |
+
}
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
// Pointer return v2: a factory that always returns an alias instance ptr. We simply take over
|
| 147 |
+
// ownership of the pointer.
|
| 148 |
+
template <typename Class, enable_if_t<Class::has_alias, int> = 0>
|
| 149 |
+
void construct(value_and_holder &v_h, Alias<Class> *alias_ptr, bool) {
|
| 150 |
+
no_nullptr(alias_ptr);
|
| 151 |
+
v_h.value_ptr() = static_cast<Cpp<Class> *>(alias_ptr);
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
// Holder return: copy its pointer, and move or copy the returned holder into the new instance's
|
| 155 |
+
// holder. This also handles types like std::shared_ptr<T> and std::unique_ptr<T> where T is a
|
| 156 |
+
// derived type (through those holder's implicit conversion from derived class holder
|
| 157 |
+
// constructors).
|
| 158 |
+
template <typename Class>
|
| 159 |
+
void construct(value_and_holder &v_h, Holder<Class> holder, bool need_alias) {
|
| 160 |
+
PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(need_alias);
|
| 161 |
+
auto *ptr = holder_helper<Holder<Class>>::get(holder);
|
| 162 |
+
no_nullptr(ptr);
|
| 163 |
+
// If we need an alias, check that the held pointer is actually an alias instance
|
| 164 |
+
if (Class::has_alias && need_alias && !is_alias<Class>(ptr)) {
|
| 165 |
+
throw type_error("pybind11::init(): construction failed: returned holder-wrapped instance "
|
| 166 |
+
"is not an alias instance");
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
v_h.value_ptr() = ptr;
|
| 170 |
+
v_h.type->init_instance(v_h.inst, &holder);
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
// return-by-value version 1: returning a cpp class by value. If the class has an alias and an
|
| 174 |
+
// alias is required the alias must have an `Alias(Cpp &&)` constructor so that we can construct
|
| 175 |
+
// the alias from the base when needed (i.e. because of Python-side inheritance). When we don't
|
| 176 |
+
// need it, we simply move-construct the cpp value into a new instance.
|
| 177 |
+
template <typename Class>
|
| 178 |
+
void construct(value_and_holder &v_h, Cpp<Class> &&result, bool need_alias) {
|
| 179 |
+
PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(need_alias);
|
| 180 |
+
static_assert(is_move_constructible<Cpp<Class>>::value,
|
| 181 |
+
"pybind11::init() return-by-value factory function requires a movable class");
|
| 182 |
+
if (Class::has_alias && need_alias) {
|
| 183 |
+
construct_alias_from_cpp<Class>(is_alias_constructible<Class>{}, v_h, std::move(result));
|
| 184 |
+
} else {
|
| 185 |
+
v_h.value_ptr() = new Cpp<Class>(std::move(result));
|
| 186 |
+
}
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
// return-by-value version 2: returning a value of the alias type itself. We move-construct an
|
| 190 |
+
// Alias instance (even if no the python-side inheritance is involved). The is intended for
|
| 191 |
+
// cases where Alias initialization is always desired.
|
| 192 |
+
template <typename Class>
|
| 193 |
+
void construct(value_and_holder &v_h, Alias<Class> &&result, bool) {
|
| 194 |
+
static_assert(
|
| 195 |
+
is_move_constructible<Alias<Class>>::value,
|
| 196 |
+
"pybind11::init() return-by-alias-value factory function requires a movable alias class");
|
| 197 |
+
v_h.value_ptr() = new Alias<Class>(std::move(result));
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
// Implementing class for py::init<...>()
|
| 201 |
+
template <typename... Args>
|
| 202 |
+
struct constructor {
|
| 203 |
+
template <typename Class, typename... Extra, enable_if_t<!Class::has_alias, int> = 0>
|
| 204 |
+
static void execute(Class &cl, const Extra &...extra) {
|
| 205 |
+
cl.def(
|
| 206 |
+
"__init__",
|
| 207 |
+
[](value_and_holder &v_h, Args... args) {
|
| 208 |
+
v_h.value_ptr() = construct_or_initialize<Cpp<Class>>(std::forward<Args>(args)...);
|
| 209 |
+
},
|
| 210 |
+
is_new_style_constructor(),
|
| 211 |
+
extra...);
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
template <
|
| 215 |
+
typename Class,
|
| 216 |
+
typename... Extra,
|
| 217 |
+
enable_if_t<Class::has_alias && std::is_constructible<Cpp<Class>, Args...>::value, int>
|
| 218 |
+
= 0>
|
| 219 |
+
static void execute(Class &cl, const Extra &...extra) {
|
| 220 |
+
cl.def(
|
| 221 |
+
"__init__",
|
| 222 |
+
[](value_and_holder &v_h, Args... args) {
|
| 223 |
+
if (Py_TYPE(v_h.inst) == v_h.type->type) {
|
| 224 |
+
v_h.value_ptr()
|
| 225 |
+
= construct_or_initialize<Cpp<Class>>(std::forward<Args>(args)...);
|
| 226 |
+
} else {
|
| 227 |
+
v_h.value_ptr()
|
| 228 |
+
= construct_or_initialize<Alias<Class>>(std::forward<Args>(args)...);
|
| 229 |
+
}
|
| 230 |
+
},
|
| 231 |
+
is_new_style_constructor(),
|
| 232 |
+
extra...);
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
template <
|
| 236 |
+
typename Class,
|
| 237 |
+
typename... Extra,
|
| 238 |
+
enable_if_t<Class::has_alias && !std::is_constructible<Cpp<Class>, Args...>::value, int>
|
| 239 |
+
= 0>
|
| 240 |
+
static void execute(Class &cl, const Extra &...extra) {
|
| 241 |
+
cl.def(
|
| 242 |
+
"__init__",
|
| 243 |
+
[](value_and_holder &v_h, Args... args) {
|
| 244 |
+
v_h.value_ptr()
|
| 245 |
+
= construct_or_initialize<Alias<Class>>(std::forward<Args>(args)...);
|
| 246 |
+
},
|
| 247 |
+
is_new_style_constructor(),
|
| 248 |
+
extra...);
|
| 249 |
+
}
|
| 250 |
+
};
|
| 251 |
+
|
| 252 |
+
// Implementing class for py::init_alias<...>()
|
| 253 |
+
template <typename... Args>
|
| 254 |
+
struct alias_constructor {
|
| 255 |
+
template <
|
| 256 |
+
typename Class,
|
| 257 |
+
typename... Extra,
|
| 258 |
+
enable_if_t<Class::has_alias && std::is_constructible<Alias<Class>, Args...>::value, int>
|
| 259 |
+
= 0>
|
| 260 |
+
static void execute(Class &cl, const Extra &...extra) {
|
| 261 |
+
cl.def(
|
| 262 |
+
"__init__",
|
| 263 |
+
[](value_and_holder &v_h, Args... args) {
|
| 264 |
+
v_h.value_ptr()
|
| 265 |
+
= construct_or_initialize<Alias<Class>>(std::forward<Args>(args)...);
|
| 266 |
+
},
|
| 267 |
+
is_new_style_constructor(),
|
| 268 |
+
extra...);
|
| 269 |
+
}
|
| 270 |
+
};
|
| 271 |
+
|
| 272 |
+
// Implementation class for py::init(Func) and py::init(Func, AliasFunc)
|
| 273 |
+
template <typename CFunc,
|
| 274 |
+
typename AFunc = void_type (*)(),
|
| 275 |
+
typename = function_signature_t<CFunc>,
|
| 276 |
+
typename = function_signature_t<AFunc>>
|
| 277 |
+
struct factory;
|
| 278 |
+
|
| 279 |
+
// Specialization for py::init(Func)
|
| 280 |
+
template <typename Func, typename Return, typename... Args>
|
| 281 |
+
struct factory<Func, void_type (*)(), Return(Args...)> {
|
| 282 |
+
remove_reference_t<Func> class_factory;
|
| 283 |
+
|
| 284 |
+
// NOLINTNEXTLINE(google-explicit-constructor)
|
| 285 |
+
factory(Func &&f) : class_factory(std::forward<Func>(f)) {}
|
| 286 |
+
|
| 287 |
+
// The given class either has no alias or has no separate alias factory;
|
| 288 |
+
// this always constructs the class itself. If the class is registered with an alias
|
| 289 |
+
// type and an alias instance is needed (i.e. because the final type is a Python class
|
| 290 |
+
// inheriting from the C++ type) the returned value needs to either already be an alias
|
| 291 |
+
// instance, or the alias needs to be constructible from a `Class &&` argument.
|
| 292 |
+
template <typename Class, typename... Extra>
|
| 293 |
+
void execute(Class &cl, const Extra &...extra) && {
|
| 294 |
+
#if defined(PYBIND11_CPP14)
|
| 295 |
+
cl.def(
|
| 296 |
+
"__init__",
|
| 297 |
+
[func = std::move(class_factory)]
|
| 298 |
+
#else
|
| 299 |
+
auto &func = class_factory;
|
| 300 |
+
cl.def(
|
| 301 |
+
"__init__",
|
| 302 |
+
[func]
|
| 303 |
+
#endif
|
| 304 |
+
(value_and_holder &v_h, Args... args) {
|
| 305 |
+
construct<Class>(
|
| 306 |
+
v_h, func(std::forward<Args>(args)...), Py_TYPE(v_h.inst) != v_h.type->type);
|
| 307 |
+
},
|
| 308 |
+
is_new_style_constructor(),
|
| 309 |
+
extra...);
|
| 310 |
+
}
|
| 311 |
+
};
|
| 312 |
+
|
| 313 |
+
// Specialization for py::init(Func, AliasFunc)
|
| 314 |
+
template <typename CFunc,
|
| 315 |
+
typename AFunc,
|
| 316 |
+
typename CReturn,
|
| 317 |
+
typename... CArgs,
|
| 318 |
+
typename AReturn,
|
| 319 |
+
typename... AArgs>
|
| 320 |
+
struct factory<CFunc, AFunc, CReturn(CArgs...), AReturn(AArgs...)> {
|
| 321 |
+
static_assert(sizeof...(CArgs) == sizeof...(AArgs),
|
| 322 |
+
"pybind11::init(class_factory, alias_factory): class and alias factories "
|
| 323 |
+
"must have identical argument signatures");
|
| 324 |
+
static_assert(all_of<std::is_same<CArgs, AArgs>...>::value,
|
| 325 |
+
"pybind11::init(class_factory, alias_factory): class and alias factories "
|
| 326 |
+
"must have identical argument signatures");
|
| 327 |
+
|
| 328 |
+
remove_reference_t<CFunc> class_factory;
|
| 329 |
+
remove_reference_t<AFunc> alias_factory;
|
| 330 |
+
|
| 331 |
+
factory(CFunc &&c, AFunc &&a)
|
| 332 |
+
: class_factory(std::forward<CFunc>(c)), alias_factory(std::forward<AFunc>(a)) {}
|
| 333 |
+
|
| 334 |
+
// The class factory is called when the `self` type passed to `__init__` is the direct
|
| 335 |
+
// class (i.e. not inherited), the alias factory when `self` is a Python-side subtype.
|
| 336 |
+
template <typename Class, typename... Extra>
|
| 337 |
+
void execute(Class &cl, const Extra &...extra) && {
|
| 338 |
+
static_assert(Class::has_alias,
|
| 339 |
+
"The two-argument version of `py::init()` can "
|
| 340 |
+
"only be used if the class has an alias");
|
| 341 |
+
#if defined(PYBIND11_CPP14)
|
| 342 |
+
cl.def(
|
| 343 |
+
"__init__",
|
| 344 |
+
[class_func = std::move(class_factory), alias_func = std::move(alias_factory)]
|
| 345 |
+
#else
|
| 346 |
+
auto &class_func = class_factory;
|
| 347 |
+
auto &alias_func = alias_factory;
|
| 348 |
+
cl.def(
|
| 349 |
+
"__init__",
|
| 350 |
+
[class_func, alias_func]
|
| 351 |
+
#endif
|
| 352 |
+
(value_and_holder &v_h, CArgs... args) {
|
| 353 |
+
if (Py_TYPE(v_h.inst) == v_h.type->type) {
|
| 354 |
+
// If the instance type equals the registered type we don't have inheritance,
|
| 355 |
+
// so don't need the alias and can construct using the class function:
|
| 356 |
+
construct<Class>(v_h, class_func(std::forward<CArgs>(args)...), false);
|
| 357 |
+
} else {
|
| 358 |
+
construct<Class>(v_h, alias_func(std::forward<CArgs>(args)...), true);
|
| 359 |
+
}
|
| 360 |
+
},
|
| 361 |
+
is_new_style_constructor(),
|
| 362 |
+
extra...);
|
| 363 |
+
}
|
| 364 |
+
};
|
| 365 |
+
|
| 366 |
+
/// Set just the C++ state. Same as `__init__`.
|
| 367 |
+
template <typename Class, typename T>
|
| 368 |
+
void setstate(value_and_holder &v_h, T &&result, bool need_alias) {
|
| 369 |
+
construct<Class>(v_h, std::forward<T>(result), need_alias);
|
| 370 |
+
}
|
| 371 |
+
|
| 372 |
+
/// Set both the C++ and Python states
|
| 373 |
+
template <typename Class,
|
| 374 |
+
typename T,
|
| 375 |
+
typename O,
|
| 376 |
+
enable_if_t<std::is_convertible<O, handle>::value, int> = 0>
|
| 377 |
+
void setstate(value_and_holder &v_h, std::pair<T, O> &&result, bool need_alias) {
|
| 378 |
+
construct<Class>(v_h, std::move(result.first), need_alias);
|
| 379 |
+
auto d = handle(result.second);
|
| 380 |
+
if (PyDict_Check(d.ptr()) && PyDict_Size(d.ptr()) == 0) {
|
| 381 |
+
// Skipping setattr below, to not force use of py::dynamic_attr() for Class unnecessarily.
|
| 382 |
+
// See PR #2972 for details.
|
| 383 |
+
return;
|
| 384 |
+
}
|
| 385 |
+
setattr((PyObject *) v_h.inst, "__dict__", d);
|
| 386 |
+
}
|
| 387 |
+
|
| 388 |
+
/// Implementation for py::pickle(GetState, SetState)
|
| 389 |
+
template <typename Get,
|
| 390 |
+
typename Set,
|
| 391 |
+
typename = function_signature_t<Get>,
|
| 392 |
+
typename = function_signature_t<Set>>
|
| 393 |
+
struct pickle_factory;
|
| 394 |
+
|
| 395 |
+
template <typename Get,
|
| 396 |
+
typename Set,
|
| 397 |
+
typename RetState,
|
| 398 |
+
typename Self,
|
| 399 |
+
typename NewInstance,
|
| 400 |
+
typename ArgState>
|
| 401 |
+
struct pickle_factory<Get, Set, RetState(Self), NewInstance(ArgState)> {
|
| 402 |
+
static_assert(std::is_same<intrinsic_t<RetState>, intrinsic_t<ArgState>>::value,
|
| 403 |
+
"The type returned by `__getstate__` must be the same "
|
| 404 |
+
"as the argument accepted by `__setstate__`");
|
| 405 |
+
|
| 406 |
+
remove_reference_t<Get> get;
|
| 407 |
+
remove_reference_t<Set> set;
|
| 408 |
+
|
| 409 |
+
pickle_factory(Get get, Set set) : get(std::forward<Get>(get)), set(std::forward<Set>(set)) {}
|
| 410 |
+
|
| 411 |
+
template <typename Class, typename... Extra>
|
| 412 |
+
void execute(Class &cl, const Extra &...extra) && {
|
| 413 |
+
cl.def("__getstate__", std::move(get));
|
| 414 |
+
|
| 415 |
+
#if defined(PYBIND11_CPP14)
|
| 416 |
+
cl.def(
|
| 417 |
+
"__setstate__",
|
| 418 |
+
[func = std::move(set)]
|
| 419 |
+
#else
|
| 420 |
+
auto &func = set;
|
| 421 |
+
cl.def(
|
| 422 |
+
"__setstate__",
|
| 423 |
+
[func]
|
| 424 |
+
#endif
|
| 425 |
+
(value_and_holder &v_h, ArgState state) {
|
| 426 |
+
setstate<Class>(
|
| 427 |
+
v_h, func(std::forward<ArgState>(state)), Py_TYPE(v_h.inst) != v_h.type->type);
|
| 428 |
+
},
|
| 429 |
+
is_new_style_constructor(),
|
| 430 |
+
extra...);
|
| 431 |
+
}
|
| 432 |
+
};
|
| 433 |
+
|
| 434 |
+
PYBIND11_NAMESPACE_END(initimpl)
|
| 435 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 436 |
+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
|
phivenv/Lib/site-packages/torch/include/pybind11/detail/internals.h
ADDED
|
@@ -0,0 +1,766 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
pybind11/detail/internals.h: Internal data structure and related functions
|
| 3 |
+
|
| 4 |
+
Copyright (c) 2017 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
| 5 |
+
|
| 6 |
+
All rights reserved. Use of this source code is governed by a
|
| 7 |
+
BSD-style license that can be found in the LICENSE file.
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include "common.h"
|
| 13 |
+
|
| 14 |
+
#if defined(PYBIND11_SIMPLE_GIL_MANAGEMENT)
|
| 15 |
+
# include <pybind11/gil.h>
|
| 16 |
+
#endif
|
| 17 |
+
|
| 18 |
+
#include <pybind11/pytypes.h>
|
| 19 |
+
|
| 20 |
+
#include <exception>
|
| 21 |
+
#include <mutex>
|
| 22 |
+
#include <thread>
|
| 23 |
+
|
| 24 |
+
/// Tracks the `internals` and `type_info` ABI version independent of the main library version.
|
| 25 |
+
///
|
| 26 |
+
/// Some portions of the code use an ABI that is conditional depending on this
|
| 27 |
+
/// version number. That allows ABI-breaking changes to be "pre-implemented".
|
| 28 |
+
/// Once the default version number is incremented, the conditional logic that
|
| 29 |
+
/// no longer applies can be removed. Additionally, users that need not
|
| 30 |
+
/// maintain ABI compatibility can increase the version number in order to take
|
| 31 |
+
/// advantage of any functionality/efficiency improvements that depend on the
|
| 32 |
+
/// newer ABI.
|
| 33 |
+
///
|
| 34 |
+
/// WARNING: If you choose to manually increase the ABI version, note that
|
| 35 |
+
/// pybind11 may not be tested as thoroughly with a non-default ABI version, and
|
| 36 |
+
/// further ABI-incompatible changes may be made before the ABI is officially
|
| 37 |
+
/// changed to the new version.
|
| 38 |
+
#ifndef PYBIND11_INTERNALS_VERSION
|
| 39 |
+
# if PY_VERSION_HEX >= 0x030C0000 || defined(_MSC_VER)
|
| 40 |
+
// Version bump for Python 3.12+, before first 3.12 beta release.
|
| 41 |
+
// Version bump for MSVC piggy-backed on PR #4779. See comments there.
|
| 42 |
+
# define PYBIND11_INTERNALS_VERSION 5
|
| 43 |
+
# else
|
| 44 |
+
# define PYBIND11_INTERNALS_VERSION 4
|
| 45 |
+
# endif
|
| 46 |
+
#endif
|
| 47 |
+
|
| 48 |
+
// This requirement is mainly to reduce the support burden (see PR #4570).
|
| 49 |
+
static_assert(PY_VERSION_HEX < 0x030C0000 || PYBIND11_INTERNALS_VERSION >= 5,
|
| 50 |
+
"pybind11 ABI version 5 is the minimum for Python 3.12+");
|
| 51 |
+
|
| 52 |
+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
| 53 |
+
|
| 54 |
+
using ExceptionTranslator = void (*)(std::exception_ptr);
|
| 55 |
+
|
| 56 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 57 |
+
|
| 58 |
+
constexpr const char *internals_function_record_capsule_name = "pybind11_function_record_capsule";
|
| 59 |
+
|
| 60 |
+
// Forward declarations
|
| 61 |
+
inline PyTypeObject *make_static_property_type();
|
| 62 |
+
inline PyTypeObject *make_default_metaclass();
|
| 63 |
+
inline PyObject *make_object_base_type(PyTypeObject *metaclass);
|
| 64 |
+
|
| 65 |
+
// The old Python Thread Local Storage (TLS) API is deprecated in Python 3.7 in favor of the new
|
| 66 |
+
// Thread Specific Storage (TSS) API.
|
| 67 |
+
// Avoid unnecessary allocation of `Py_tss_t`, since we cannot use
|
| 68 |
+
// `Py_LIMITED_API` anyway.
|
| 69 |
+
#if PYBIND11_INTERNALS_VERSION > 4
|
| 70 |
+
# define PYBIND11_TLS_KEY_REF Py_tss_t &
|
| 71 |
+
# if defined(__clang__)
|
| 72 |
+
# define PYBIND11_TLS_KEY_INIT(var) \
|
| 73 |
+
_Pragma("clang diagnostic push") /**/ \
|
| 74 |
+
_Pragma("clang diagnostic ignored \"-Wmissing-field-initializers\"") /**/ \
|
| 75 |
+
Py_tss_t var \
|
| 76 |
+
= Py_tss_NEEDS_INIT; \
|
| 77 |
+
_Pragma("clang diagnostic pop")
|
| 78 |
+
# elif defined(__GNUC__) && !defined(__INTEL_COMPILER)
|
| 79 |
+
# define PYBIND11_TLS_KEY_INIT(var) \
|
| 80 |
+
_Pragma("GCC diagnostic push") /**/ \
|
| 81 |
+
_Pragma("GCC diagnostic ignored \"-Wmissing-field-initializers\"") /**/ \
|
| 82 |
+
Py_tss_t var \
|
| 83 |
+
= Py_tss_NEEDS_INIT; \
|
| 84 |
+
_Pragma("GCC diagnostic pop")
|
| 85 |
+
# else
|
| 86 |
+
# define PYBIND11_TLS_KEY_INIT(var) Py_tss_t var = Py_tss_NEEDS_INIT;
|
| 87 |
+
# endif
|
| 88 |
+
# define PYBIND11_TLS_KEY_CREATE(var) (PyThread_tss_create(&(var)) == 0)
|
| 89 |
+
# define PYBIND11_TLS_GET_VALUE(key) PyThread_tss_get(&(key))
|
| 90 |
+
# define PYBIND11_TLS_REPLACE_VALUE(key, value) PyThread_tss_set(&(key), (value))
|
| 91 |
+
# define PYBIND11_TLS_DELETE_VALUE(key) PyThread_tss_set(&(key), nullptr)
|
| 92 |
+
# define PYBIND11_TLS_FREE(key) PyThread_tss_delete(&(key))
|
| 93 |
+
#else
|
| 94 |
+
# define PYBIND11_TLS_KEY_REF Py_tss_t *
|
| 95 |
+
# define PYBIND11_TLS_KEY_INIT(var) Py_tss_t *var = nullptr;
|
| 96 |
+
# define PYBIND11_TLS_KEY_CREATE(var) \
|
| 97 |
+
(((var) = PyThread_tss_alloc()) != nullptr && (PyThread_tss_create((var)) == 0))
|
| 98 |
+
# define PYBIND11_TLS_GET_VALUE(key) PyThread_tss_get((key))
|
| 99 |
+
# define PYBIND11_TLS_REPLACE_VALUE(key, value) PyThread_tss_set((key), (value))
|
| 100 |
+
# define PYBIND11_TLS_DELETE_VALUE(key) PyThread_tss_set((key), nullptr)
|
| 101 |
+
# define PYBIND11_TLS_FREE(key) PyThread_tss_free(key)
|
| 102 |
+
#endif
|
| 103 |
+
|
| 104 |
+
// Python loads modules by default with dlopen with the RTLD_LOCAL flag; under libc++ and possibly
|
| 105 |
+
// other STLs, this means `typeid(A)` from one module won't equal `typeid(A)` from another module
|
| 106 |
+
// even when `A` is the same, non-hidden-visibility type (e.g. from a common include). Under
|
| 107 |
+
// libstdc++, this doesn't happen: equality and the type_index hash are based on the type name,
|
| 108 |
+
// which works. If not under a known-good stl, provide our own name-based hash and equality
|
| 109 |
+
// functions that use the type name.
|
| 110 |
+
#if (PYBIND11_INTERNALS_VERSION <= 4 && defined(__GLIBCXX__)) \
|
| 111 |
+
|| (PYBIND11_INTERNALS_VERSION >= 5 && !defined(_LIBCPP_VERSION))
|
| 112 |
+
inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) { return lhs == rhs; }
|
| 113 |
+
using type_hash = std::hash<std::type_index>;
|
| 114 |
+
using type_equal_to = std::equal_to<std::type_index>;
|
| 115 |
+
#else
|
| 116 |
+
inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) {
|
| 117 |
+
return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0;
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
struct type_hash {
|
| 121 |
+
size_t operator()(const std::type_index &t) const {
|
| 122 |
+
size_t hash = 5381;
|
| 123 |
+
const char *ptr = t.name();
|
| 124 |
+
while (auto c = static_cast<unsigned char>(*ptr++)) {
|
| 125 |
+
hash = (hash * 33) ^ c;
|
| 126 |
+
}
|
| 127 |
+
return hash;
|
| 128 |
+
}
|
| 129 |
+
};
|
| 130 |
+
|
| 131 |
+
struct type_equal_to {
|
| 132 |
+
bool operator()(const std::type_index &lhs, const std::type_index &rhs) const {
|
| 133 |
+
return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0;
|
| 134 |
+
}
|
| 135 |
+
};
|
| 136 |
+
#endif
|
| 137 |
+
|
| 138 |
+
template <typename value_type>
|
| 139 |
+
using type_map = std::unordered_map<std::type_index, value_type, type_hash, type_equal_to>;
|
| 140 |
+
|
| 141 |
+
struct override_hash {
|
| 142 |
+
inline size_t operator()(const std::pair<const PyObject *, const char *> &v) const {
|
| 143 |
+
size_t value = std::hash<const void *>()(v.first);
|
| 144 |
+
value ^= std::hash<const void *>()(v.second) + 0x9e3779b9 + (value << 6) + (value >> 2);
|
| 145 |
+
return value;
|
| 146 |
+
}
|
| 147 |
+
};
|
| 148 |
+
|
| 149 |
+
using instance_map = std::unordered_multimap<const void *, instance *>;
|
| 150 |
+
|
| 151 |
+
#ifdef Py_GIL_DISABLED
|
| 152 |
+
// Wrapper around PyMutex to provide BasicLockable semantics
|
| 153 |
+
class pymutex {
|
| 154 |
+
PyMutex mutex;
|
| 155 |
+
|
| 156 |
+
public:
|
| 157 |
+
pymutex() : mutex({}) {}
|
| 158 |
+
void lock() { PyMutex_Lock(&mutex); }
|
| 159 |
+
void unlock() { PyMutex_Unlock(&mutex); }
|
| 160 |
+
};
|
| 161 |
+
|
| 162 |
+
// Instance map shards are used to reduce mutex contention in free-threaded Python.
|
| 163 |
+
struct instance_map_shard {
|
| 164 |
+
instance_map registered_instances;
|
| 165 |
+
pymutex mutex;
|
| 166 |
+
// alignas(64) would be better, but causes compile errors in macOS before 10.14 (see #5200)
|
| 167 |
+
char padding[64 - (sizeof(instance_map) + sizeof(pymutex)) % 64];
|
| 168 |
+
};
|
| 169 |
+
|
| 170 |
+
static_assert(sizeof(instance_map_shard) % 64 == 0,
|
| 171 |
+
"instance_map_shard size is not a multiple of 64 bytes");
|
| 172 |
+
#endif
|
| 173 |
+
|
| 174 |
+
/// Internal data structure used to track registered instances and types.
|
| 175 |
+
/// Whenever binary incompatible changes are made to this structure,
|
| 176 |
+
/// `PYBIND11_INTERNALS_VERSION` must be incremented.
|
| 177 |
+
struct internals {
|
| 178 |
+
#ifdef Py_GIL_DISABLED
|
| 179 |
+
pymutex mutex;
|
| 180 |
+
#endif
|
| 181 |
+
// std::type_index -> pybind11's type information
|
| 182 |
+
type_map<type_info *> registered_types_cpp;
|
| 183 |
+
// PyTypeObject* -> base type_info(s)
|
| 184 |
+
std::unordered_map<PyTypeObject *, std::vector<type_info *>> registered_types_py;
|
| 185 |
+
#ifdef Py_GIL_DISABLED
|
| 186 |
+
std::unique_ptr<instance_map_shard[]> instance_shards; // void * -> instance*
|
| 187 |
+
size_t instance_shards_mask;
|
| 188 |
+
#else
|
| 189 |
+
instance_map registered_instances; // void * -> instance*
|
| 190 |
+
#endif
|
| 191 |
+
std::unordered_set<std::pair<const PyObject *, const char *>, override_hash>
|
| 192 |
+
inactive_override_cache;
|
| 193 |
+
type_map<std::vector<bool (*)(PyObject *, void *&)>> direct_conversions;
|
| 194 |
+
std::unordered_map<const PyObject *, std::vector<PyObject *>> patients;
|
| 195 |
+
std::forward_list<ExceptionTranslator> registered_exception_translators;
|
| 196 |
+
std::unordered_map<std::string, void *> shared_data; // Custom data to be shared across
|
| 197 |
+
// extensions
|
| 198 |
+
#if PYBIND11_INTERNALS_VERSION == 4
|
| 199 |
+
std::vector<PyObject *> unused_loader_patient_stack_remove_at_v5;
|
| 200 |
+
#endif
|
| 201 |
+
std::forward_list<std::string> static_strings; // Stores the std::strings backing
|
| 202 |
+
// detail::c_str()
|
| 203 |
+
PyTypeObject *static_property_type;
|
| 204 |
+
PyTypeObject *default_metaclass;
|
| 205 |
+
PyObject *instance_base;
|
| 206 |
+
// Unused if PYBIND11_SIMPLE_GIL_MANAGEMENT is defined:
|
| 207 |
+
PYBIND11_TLS_KEY_INIT(tstate)
|
| 208 |
+
#if PYBIND11_INTERNALS_VERSION > 4
|
| 209 |
+
PYBIND11_TLS_KEY_INIT(loader_life_support_tls_key)
|
| 210 |
+
#endif // PYBIND11_INTERNALS_VERSION > 4
|
| 211 |
+
// Unused if PYBIND11_SIMPLE_GIL_MANAGEMENT is defined:
|
| 212 |
+
PyInterpreterState *istate = nullptr;
|
| 213 |
+
|
| 214 |
+
#if PYBIND11_INTERNALS_VERSION > 4
|
| 215 |
+
// Note that we have to use a std::string to allocate memory to ensure a unique address
|
| 216 |
+
// We want unique addresses since we use pointer equality to compare function records
|
| 217 |
+
std::string function_record_capsule_name = internals_function_record_capsule_name;
|
| 218 |
+
#endif
|
| 219 |
+
|
| 220 |
+
internals() = default;
|
| 221 |
+
internals(const internals &other) = delete;
|
| 222 |
+
internals &operator=(const internals &other) = delete;
|
| 223 |
+
~internals() {
|
| 224 |
+
#if PYBIND11_INTERNALS_VERSION > 4
|
| 225 |
+
PYBIND11_TLS_FREE(loader_life_support_tls_key);
|
| 226 |
+
#endif // PYBIND11_INTERNALS_VERSION > 4
|
| 227 |
+
|
| 228 |
+
// This destructor is called *after* Py_Finalize() in finalize_interpreter().
|
| 229 |
+
// That *SHOULD BE* fine. The following details what happens when PyThread_tss_free is
|
| 230 |
+
// called. PYBIND11_TLS_FREE is PyThread_tss_free on python 3.7+. On older python, it does
|
| 231 |
+
// nothing. PyThread_tss_free calls PyThread_tss_delete and PyMem_RawFree.
|
| 232 |
+
// PyThread_tss_delete just calls TlsFree (on Windows) or pthread_key_delete (on *NIX).
|
| 233 |
+
// Neither of those have anything to do with CPython internals. PyMem_RawFree *requires*
|
| 234 |
+
// that the `tstate` be allocated with the CPython allocator.
|
| 235 |
+
PYBIND11_TLS_FREE(tstate);
|
| 236 |
+
}
|
| 237 |
+
};
|
| 238 |
+
|
| 239 |
+
/// Additional type information which does not fit into the PyTypeObject.
|
| 240 |
+
/// Changes to this struct also require bumping `PYBIND11_INTERNALS_VERSION`.
|
| 241 |
+
struct type_info {
|
| 242 |
+
PyTypeObject *type;
|
| 243 |
+
const std::type_info *cpptype;
|
| 244 |
+
size_t type_size, type_align, holder_size_in_ptrs;
|
| 245 |
+
void *(*operator_new)(size_t);
|
| 246 |
+
void (*init_instance)(instance *, const void *);
|
| 247 |
+
void (*dealloc)(value_and_holder &v_h);
|
| 248 |
+
std::vector<PyObject *(*) (PyObject *, PyTypeObject *)> implicit_conversions;
|
| 249 |
+
std::vector<std::pair<const std::type_info *, void *(*) (void *)>> implicit_casts;
|
| 250 |
+
std::vector<bool (*)(PyObject *, void *&)> *direct_conversions;
|
| 251 |
+
buffer_info *(*get_buffer)(PyObject *, void *) = nullptr;
|
| 252 |
+
void *get_buffer_data = nullptr;
|
| 253 |
+
void *(*module_local_load)(PyObject *, const type_info *) = nullptr;
|
| 254 |
+
/* A simple type never occurs as a (direct or indirect) parent
|
| 255 |
+
* of a class that makes use of multiple inheritance.
|
| 256 |
+
* A type can be simple even if it has non-simple ancestors as long as it has no descendants.
|
| 257 |
+
*/
|
| 258 |
+
bool simple_type : 1;
|
| 259 |
+
/* True if there is no multiple inheritance in this type's inheritance tree */
|
| 260 |
+
bool simple_ancestors : 1;
|
| 261 |
+
/* for base vs derived holder_type checks */
|
| 262 |
+
bool default_holder : 1;
|
| 263 |
+
/* true if this is a type registered with py::module_local */
|
| 264 |
+
bool module_local : 1;
|
| 265 |
+
};
|
| 266 |
+
|
| 267 |
+
/// On MSVC, debug and release builds are not ABI-compatible!
|
| 268 |
+
#if defined(_MSC_VER) && defined(_DEBUG)
|
| 269 |
+
# define PYBIND11_BUILD_TYPE "_debug"
|
| 270 |
+
#else
|
| 271 |
+
# define PYBIND11_BUILD_TYPE ""
|
| 272 |
+
#endif
|
| 273 |
+
|
| 274 |
+
/// Let's assume that different compilers are ABI-incompatible.
|
| 275 |
+
/// A user can manually set this string if they know their
|
| 276 |
+
/// compiler is compatible.
|
| 277 |
+
#ifndef PYBIND11_COMPILER_TYPE
|
| 278 |
+
# if defined(_MSC_VER)
|
| 279 |
+
# define PYBIND11_COMPILER_TYPE "_msvc"
|
| 280 |
+
# elif defined(__INTEL_COMPILER)
|
| 281 |
+
# define PYBIND11_COMPILER_TYPE "_icc"
|
| 282 |
+
# elif defined(__clang__)
|
| 283 |
+
# define PYBIND11_COMPILER_TYPE "_clang"
|
| 284 |
+
# elif defined(__PGI)
|
| 285 |
+
# define PYBIND11_COMPILER_TYPE "_pgi"
|
| 286 |
+
# elif defined(__MINGW32__)
|
| 287 |
+
# define PYBIND11_COMPILER_TYPE "_mingw"
|
| 288 |
+
# elif defined(__CYGWIN__)
|
| 289 |
+
# define PYBIND11_COMPILER_TYPE "_gcc_cygwin"
|
| 290 |
+
# elif defined(__GNUC__)
|
| 291 |
+
# define PYBIND11_COMPILER_TYPE "_gcc"
|
| 292 |
+
# else
|
| 293 |
+
# define PYBIND11_COMPILER_TYPE "_unknown"
|
| 294 |
+
# endif
|
| 295 |
+
#endif
|
| 296 |
+
|
| 297 |
+
/// Also standard libs
|
| 298 |
+
#ifndef PYBIND11_STDLIB
|
| 299 |
+
# if defined(_LIBCPP_VERSION)
|
| 300 |
+
# define PYBIND11_STDLIB "_libcpp"
|
| 301 |
+
# elif defined(__GLIBCXX__) || defined(__GLIBCPP__)
|
| 302 |
+
# define PYBIND11_STDLIB "_libstdcpp"
|
| 303 |
+
# else
|
| 304 |
+
# define PYBIND11_STDLIB ""
|
| 305 |
+
# endif
|
| 306 |
+
#endif
|
| 307 |
+
|
| 308 |
+
/// On Linux/OSX, changes in __GXX_ABI_VERSION__ indicate ABI incompatibility.
|
| 309 |
+
/// On MSVC, changes in _MSC_VER may indicate ABI incompatibility (#2898).
|
| 310 |
+
#ifndef PYBIND11_BUILD_ABI
|
| 311 |
+
# if defined(__GXX_ABI_VERSION)
|
| 312 |
+
# define PYBIND11_BUILD_ABI "_cxxabi" PYBIND11_TOSTRING(__GXX_ABI_VERSION)
|
| 313 |
+
# elif defined(_MSC_VER)
|
| 314 |
+
# define PYBIND11_BUILD_ABI "_mscver" PYBIND11_TOSTRING(_MSC_VER)
|
| 315 |
+
# else
|
| 316 |
+
# define PYBIND11_BUILD_ABI ""
|
| 317 |
+
# endif
|
| 318 |
+
#endif
|
| 319 |
+
|
| 320 |
+
#ifndef PYBIND11_INTERNALS_KIND
|
| 321 |
+
# define PYBIND11_INTERNALS_KIND ""
|
| 322 |
+
#endif
|
| 323 |
+
|
| 324 |
+
#define PYBIND11_PLATFORM_ABI_ID \
|
| 325 |
+
PYBIND11_INTERNALS_KIND PYBIND11_COMPILER_TYPE PYBIND11_STDLIB PYBIND11_BUILD_ABI \
|
| 326 |
+
PYBIND11_BUILD_TYPE
|
| 327 |
+
|
| 328 |
+
#define PYBIND11_INTERNALS_ID \
|
| 329 |
+
"__pybind11_internals_v" PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) \
|
| 330 |
+
PYBIND11_PLATFORM_ABI_ID "__"
|
| 331 |
+
|
| 332 |
+
#define PYBIND11_MODULE_LOCAL_ID \
|
| 333 |
+
"__pybind11_module_local_v" PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) \
|
| 334 |
+
PYBIND11_PLATFORM_ABI_ID "__"
|
| 335 |
+
|
| 336 |
+
/// Each module locally stores a pointer to the `internals` data. The data
|
| 337 |
+
/// itself is shared among modules with the same `PYBIND11_INTERNALS_ID`.
|
| 338 |
+
inline internals **&get_internals_pp() {
|
| 339 |
+
static internals **internals_pp = nullptr;
|
| 340 |
+
return internals_pp;
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
+
// forward decl
|
| 344 |
+
inline void translate_exception(std::exception_ptr);
|
| 345 |
+
|
| 346 |
+
template <class T,
|
| 347 |
+
enable_if_t<std::is_same<std::nested_exception, remove_cvref_t<T>>::value, int> = 0>
|
| 348 |
+
bool handle_nested_exception(const T &exc, const std::exception_ptr &p) {
|
| 349 |
+
std::exception_ptr nested = exc.nested_ptr();
|
| 350 |
+
if (nested != nullptr && nested != p) {
|
| 351 |
+
translate_exception(nested);
|
| 352 |
+
return true;
|
| 353 |
+
}
|
| 354 |
+
return false;
|
| 355 |
+
}
|
| 356 |
+
|
| 357 |
+
template <class T,
|
| 358 |
+
enable_if_t<!std::is_same<std::nested_exception, remove_cvref_t<T>>::value, int> = 0>
|
| 359 |
+
bool handle_nested_exception(const T &exc, const std::exception_ptr &p) {
|
| 360 |
+
if (const auto *nep = dynamic_cast<const std::nested_exception *>(std::addressof(exc))) {
|
| 361 |
+
return handle_nested_exception(*nep, p);
|
| 362 |
+
}
|
| 363 |
+
return false;
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
inline bool raise_err(PyObject *exc_type, const char *msg) {
|
| 367 |
+
if (PyErr_Occurred()) {
|
| 368 |
+
raise_from(exc_type, msg);
|
| 369 |
+
return true;
|
| 370 |
+
}
|
| 371 |
+
set_error(exc_type, msg);
|
| 372 |
+
return false;
|
| 373 |
+
}
|
| 374 |
+
|
| 375 |
+
inline void translate_exception(std::exception_ptr p) {
|
| 376 |
+
if (!p) {
|
| 377 |
+
return;
|
| 378 |
+
}
|
| 379 |
+
try {
|
| 380 |
+
std::rethrow_exception(p);
|
| 381 |
+
} catch (error_already_set &e) {
|
| 382 |
+
handle_nested_exception(e, p);
|
| 383 |
+
e.restore();
|
| 384 |
+
return;
|
| 385 |
+
} catch (const builtin_exception &e) {
|
| 386 |
+
// Could not use template since it's an abstract class.
|
| 387 |
+
if (const auto *nep = dynamic_cast<const std::nested_exception *>(std::addressof(e))) {
|
| 388 |
+
handle_nested_exception(*nep, p);
|
| 389 |
+
}
|
| 390 |
+
e.set_error();
|
| 391 |
+
return;
|
| 392 |
+
} catch (const std::bad_alloc &e) {
|
| 393 |
+
handle_nested_exception(e, p);
|
| 394 |
+
raise_err(PyExc_MemoryError, e.what());
|
| 395 |
+
return;
|
| 396 |
+
} catch (const std::domain_error &e) {
|
| 397 |
+
handle_nested_exception(e, p);
|
| 398 |
+
raise_err(PyExc_ValueError, e.what());
|
| 399 |
+
return;
|
| 400 |
+
} catch (const std::invalid_argument &e) {
|
| 401 |
+
handle_nested_exception(e, p);
|
| 402 |
+
raise_err(PyExc_ValueError, e.what());
|
| 403 |
+
return;
|
| 404 |
+
} catch (const std::length_error &e) {
|
| 405 |
+
handle_nested_exception(e, p);
|
| 406 |
+
raise_err(PyExc_ValueError, e.what());
|
| 407 |
+
return;
|
| 408 |
+
} catch (const std::out_of_range &e) {
|
| 409 |
+
handle_nested_exception(e, p);
|
| 410 |
+
raise_err(PyExc_IndexError, e.what());
|
| 411 |
+
return;
|
| 412 |
+
} catch (const std::range_error &e) {
|
| 413 |
+
handle_nested_exception(e, p);
|
| 414 |
+
raise_err(PyExc_ValueError, e.what());
|
| 415 |
+
return;
|
| 416 |
+
} catch (const std::overflow_error &e) {
|
| 417 |
+
handle_nested_exception(e, p);
|
| 418 |
+
raise_err(PyExc_OverflowError, e.what());
|
| 419 |
+
return;
|
| 420 |
+
} catch (const std::exception &e) {
|
| 421 |
+
handle_nested_exception(e, p);
|
| 422 |
+
raise_err(PyExc_RuntimeError, e.what());
|
| 423 |
+
return;
|
| 424 |
+
} catch (const std::nested_exception &e) {
|
| 425 |
+
handle_nested_exception(e, p);
|
| 426 |
+
raise_err(PyExc_RuntimeError, "Caught an unknown nested exception!");
|
| 427 |
+
return;
|
| 428 |
+
} catch (...) {
|
| 429 |
+
raise_err(PyExc_RuntimeError, "Caught an unknown exception!");
|
| 430 |
+
return;
|
| 431 |
+
}
|
| 432 |
+
}
|
| 433 |
+
|
| 434 |
+
#if !defined(__GLIBCXX__)
|
| 435 |
+
inline void translate_local_exception(std::exception_ptr p) {
|
| 436 |
+
try {
|
| 437 |
+
if (p) {
|
| 438 |
+
std::rethrow_exception(p);
|
| 439 |
+
}
|
| 440 |
+
} catch (error_already_set &e) {
|
| 441 |
+
e.restore();
|
| 442 |
+
return;
|
| 443 |
+
} catch (const builtin_exception &e) {
|
| 444 |
+
e.set_error();
|
| 445 |
+
return;
|
| 446 |
+
}
|
| 447 |
+
}
|
| 448 |
+
#endif
|
| 449 |
+
|
| 450 |
+
inline object get_python_state_dict() {
|
| 451 |
+
object state_dict;
|
| 452 |
+
#if PYBIND11_INTERNALS_VERSION <= 4 || PY_VERSION_HEX < 0x03080000 || defined(PYPY_VERSION)
|
| 453 |
+
state_dict = reinterpret_borrow<object>(PyEval_GetBuiltins());
|
| 454 |
+
#else
|
| 455 |
+
# if PY_VERSION_HEX < 0x03090000
|
| 456 |
+
PyInterpreterState *istate = _PyInterpreterState_Get();
|
| 457 |
+
# else
|
| 458 |
+
PyInterpreterState *istate = PyInterpreterState_Get();
|
| 459 |
+
# endif
|
| 460 |
+
if (istate) {
|
| 461 |
+
state_dict = reinterpret_borrow<object>(PyInterpreterState_GetDict(istate));
|
| 462 |
+
}
|
| 463 |
+
#endif
|
| 464 |
+
if (!state_dict) {
|
| 465 |
+
raise_from(PyExc_SystemError, "pybind11::detail::get_python_state_dict() FAILED");
|
| 466 |
+
throw error_already_set();
|
| 467 |
+
}
|
| 468 |
+
return state_dict;
|
| 469 |
+
}
|
| 470 |
+
|
| 471 |
+
inline object get_internals_obj_from_state_dict(handle state_dict) {
|
| 472 |
+
return reinterpret_steal<object>(
|
| 473 |
+
dict_getitemstringref(state_dict.ptr(), PYBIND11_INTERNALS_ID));
|
| 474 |
+
}
|
| 475 |
+
|
| 476 |
+
inline internals **get_internals_pp_from_capsule(handle obj) {
|
| 477 |
+
void *raw_ptr = PyCapsule_GetPointer(obj.ptr(), /*name=*/nullptr);
|
| 478 |
+
if (raw_ptr == nullptr) {
|
| 479 |
+
raise_from(PyExc_SystemError, "pybind11::detail::get_internals_pp_from_capsule() FAILED");
|
| 480 |
+
throw error_already_set();
|
| 481 |
+
}
|
| 482 |
+
return static_cast<internals **>(raw_ptr);
|
| 483 |
+
}
|
| 484 |
+
|
| 485 |
+
inline uint64_t round_up_to_next_pow2(uint64_t x) {
|
| 486 |
+
// Round-up to the next power of two.
|
| 487 |
+
// See https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
|
| 488 |
+
x--;
|
| 489 |
+
x |= (x >> 1);
|
| 490 |
+
x |= (x >> 2);
|
| 491 |
+
x |= (x >> 4);
|
| 492 |
+
x |= (x >> 8);
|
| 493 |
+
x |= (x >> 16);
|
| 494 |
+
x |= (x >> 32);
|
| 495 |
+
x++;
|
| 496 |
+
return x;
|
| 497 |
+
}
|
| 498 |
+
|
| 499 |
+
/// Return a reference to the current `internals` data
|
| 500 |
+
PYBIND11_NOINLINE internals &get_internals() {
|
| 501 |
+
auto **&internals_pp = get_internals_pp();
|
| 502 |
+
if (internals_pp && *internals_pp) {
|
| 503 |
+
return **internals_pp;
|
| 504 |
+
}
|
| 505 |
+
|
| 506 |
+
#if defined(PYBIND11_SIMPLE_GIL_MANAGEMENT)
|
| 507 |
+
gil_scoped_acquire gil;
|
| 508 |
+
#else
|
| 509 |
+
// Ensure that the GIL is held since we will need to make Python calls.
|
| 510 |
+
// Cannot use py::gil_scoped_acquire here since that constructor calls get_internals.
|
| 511 |
+
struct gil_scoped_acquire_local {
|
| 512 |
+
gil_scoped_acquire_local() : state(PyGILState_Ensure()) {}
|
| 513 |
+
gil_scoped_acquire_local(const gil_scoped_acquire_local &) = delete;
|
| 514 |
+
gil_scoped_acquire_local &operator=(const gil_scoped_acquire_local &) = delete;
|
| 515 |
+
~gil_scoped_acquire_local() { PyGILState_Release(state); }
|
| 516 |
+
const PyGILState_STATE state;
|
| 517 |
+
} gil;
|
| 518 |
+
#endif
|
| 519 |
+
error_scope err_scope;
|
| 520 |
+
|
| 521 |
+
dict state_dict = get_python_state_dict();
|
| 522 |
+
if (object internals_obj = get_internals_obj_from_state_dict(state_dict)) {
|
| 523 |
+
internals_pp = get_internals_pp_from_capsule(internals_obj);
|
| 524 |
+
}
|
| 525 |
+
if (internals_pp && *internals_pp) {
|
| 526 |
+
// We loaded the internals through `state_dict`, which means that our `error_already_set`
|
| 527 |
+
// and `builtin_exception` may be different local classes than the ones set up in the
|
| 528 |
+
// initial exception translator, below, so add another for our local exception classes.
|
| 529 |
+
//
|
| 530 |
+
// libstdc++ doesn't require this (types there are identified only by name)
|
| 531 |
+
// libc++ with CPython doesn't require this (types are explicitly exported)
|
| 532 |
+
// libc++ with PyPy still need it, awaiting further investigation
|
| 533 |
+
#if !defined(__GLIBCXX__)
|
| 534 |
+
(*internals_pp)->registered_exception_translators.push_front(&translate_local_exception);
|
| 535 |
+
#endif
|
| 536 |
+
} else {
|
| 537 |
+
if (!internals_pp) {
|
| 538 |
+
internals_pp = new internals *();
|
| 539 |
+
}
|
| 540 |
+
auto *&internals_ptr = *internals_pp;
|
| 541 |
+
internals_ptr = new internals();
|
| 542 |
+
|
| 543 |
+
PyThreadState *tstate = PyThreadState_Get();
|
| 544 |
+
// NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
|
| 545 |
+
if (!PYBIND11_TLS_KEY_CREATE(internals_ptr->tstate)) {
|
| 546 |
+
pybind11_fail("get_internals: could not successfully initialize the tstate TSS key!");
|
| 547 |
+
}
|
| 548 |
+
PYBIND11_TLS_REPLACE_VALUE(internals_ptr->tstate, tstate);
|
| 549 |
+
|
| 550 |
+
#if PYBIND11_INTERNALS_VERSION > 4
|
| 551 |
+
// NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
|
| 552 |
+
if (!PYBIND11_TLS_KEY_CREATE(internals_ptr->loader_life_support_tls_key)) {
|
| 553 |
+
pybind11_fail("get_internals: could not successfully initialize the "
|
| 554 |
+
"loader_life_support TSS key!");
|
| 555 |
+
}
|
| 556 |
+
#endif
|
| 557 |
+
internals_ptr->istate = tstate->interp;
|
| 558 |
+
state_dict[PYBIND11_INTERNALS_ID] = capsule(reinterpret_cast<void *>(internals_pp));
|
| 559 |
+
internals_ptr->registered_exception_translators.push_front(&translate_exception);
|
| 560 |
+
internals_ptr->static_property_type = make_static_property_type();
|
| 561 |
+
internals_ptr->default_metaclass = make_default_metaclass();
|
| 562 |
+
internals_ptr->instance_base = make_object_base_type(internals_ptr->default_metaclass);
|
| 563 |
+
#ifdef Py_GIL_DISABLED
|
| 564 |
+
// Scale proportional to the number of cores. 2x is a heuristic to reduce contention.
|
| 565 |
+
auto num_shards
|
| 566 |
+
= static_cast<size_t>(round_up_to_next_pow2(2 * std::thread::hardware_concurrency()));
|
| 567 |
+
if (num_shards == 0) {
|
| 568 |
+
num_shards = 1;
|
| 569 |
+
}
|
| 570 |
+
internals_ptr->instance_shards.reset(new instance_map_shard[num_shards]);
|
| 571 |
+
internals_ptr->instance_shards_mask = num_shards - 1;
|
| 572 |
+
#endif // Py_GIL_DISABLED
|
| 573 |
+
}
|
| 574 |
+
return **internals_pp;
|
| 575 |
+
}
|
| 576 |
+
|
| 577 |
+
// the internals struct (above) is shared between all the modules. local_internals are only
|
| 578 |
+
// for a single module. Any changes made to internals may require an update to
|
| 579 |
+
// PYBIND11_INTERNALS_VERSION, breaking backwards compatibility. local_internals is, by design,
|
| 580 |
+
// restricted to a single module. Whether a module has local internals or not should not
|
| 581 |
+
// impact any other modules, because the only things accessing the local internals is the
|
| 582 |
+
// module that contains them.
|
| 583 |
+
struct local_internals {
|
| 584 |
+
type_map<type_info *> registered_types_cpp;
|
| 585 |
+
std::forward_list<ExceptionTranslator> registered_exception_translators;
|
| 586 |
+
#if PYBIND11_INTERNALS_VERSION == 4
|
| 587 |
+
|
| 588 |
+
// For ABI compatibility, we can't store the loader_life_support TLS key in
|
| 589 |
+
// the `internals` struct directly. Instead, we store it in `shared_data` and
|
| 590 |
+
// cache a copy in `local_internals`. If we allocated a separate TLS key for
|
| 591 |
+
// each instance of `local_internals`, we could end up allocating hundreds of
|
| 592 |
+
// TLS keys if hundreds of different pybind11 modules are loaded (which is a
|
| 593 |
+
// plausible number).
|
| 594 |
+
PYBIND11_TLS_KEY_INIT(loader_life_support_tls_key)
|
| 595 |
+
|
| 596 |
+
// Holds the shared TLS key for the loader_life_support stack.
|
| 597 |
+
struct shared_loader_life_support_data {
|
| 598 |
+
PYBIND11_TLS_KEY_INIT(loader_life_support_tls_key)
|
| 599 |
+
shared_loader_life_support_data() {
|
| 600 |
+
// NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
|
| 601 |
+
if (!PYBIND11_TLS_KEY_CREATE(loader_life_support_tls_key)) {
|
| 602 |
+
pybind11_fail("local_internals: could not successfully initialize the "
|
| 603 |
+
"loader_life_support TLS key!");
|
| 604 |
+
}
|
| 605 |
+
}
|
| 606 |
+
// We can't help but leak the TLS key, because Python never unloads extension modules.
|
| 607 |
+
};
|
| 608 |
+
|
| 609 |
+
local_internals() {
|
| 610 |
+
auto &internals = get_internals();
|
| 611 |
+
// Get or create the `loader_life_support_stack_key`.
|
| 612 |
+
auto &ptr = internals.shared_data["_life_support"];
|
| 613 |
+
if (!ptr) {
|
| 614 |
+
ptr = new shared_loader_life_support_data;
|
| 615 |
+
}
|
| 616 |
+
loader_life_support_tls_key
|
| 617 |
+
= static_cast<shared_loader_life_support_data *>(ptr)->loader_life_support_tls_key;
|
| 618 |
+
}
|
| 619 |
+
#endif // PYBIND11_INTERNALS_VERSION == 4
|
| 620 |
+
};
|
| 621 |
+
|
| 622 |
+
/// Works like `get_internals`, but for things which are locally registered.
|
| 623 |
+
inline local_internals &get_local_internals() {
|
| 624 |
+
// Current static can be created in the interpreter finalization routine. If the later will be
|
| 625 |
+
// destroyed in another static variable destructor, creation of this static there will cause
|
| 626 |
+
// static deinitialization fiasco. In order to avoid it we avoid destruction of the
|
| 627 |
+
// local_internals static. One can read more about the problem and current solution here:
|
| 628 |
+
// https://google.github.io/styleguide/cppguide.html#Static_and_Global_Variables
|
| 629 |
+
static auto *locals = new local_internals();
|
| 630 |
+
return *locals;
|
| 631 |
+
}
|
| 632 |
+
|
| 633 |
+
#ifdef Py_GIL_DISABLED
|
| 634 |
+
# define PYBIND11_LOCK_INTERNALS(internals) std::unique_lock<pymutex> lock((internals).mutex)
|
| 635 |
+
#else
|
| 636 |
+
# define PYBIND11_LOCK_INTERNALS(internals)
|
| 637 |
+
#endif
|
| 638 |
+
|
| 639 |
+
template <typename F>
|
| 640 |
+
inline auto with_internals(const F &cb) -> decltype(cb(get_internals())) {
|
| 641 |
+
auto &internals = get_internals();
|
| 642 |
+
PYBIND11_LOCK_INTERNALS(internals);
|
| 643 |
+
return cb(internals);
|
| 644 |
+
}
|
| 645 |
+
|
| 646 |
+
inline std::uint64_t mix64(std::uint64_t z) {
|
| 647 |
+
// David Stafford's variant 13 of the MurmurHash3 finalizer popularized
|
| 648 |
+
// by the SplitMix PRNG.
|
| 649 |
+
// https://zimbry.blogspot.com/2011/09/better-bit-mixing-improving-on.html
|
| 650 |
+
z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9;
|
| 651 |
+
z = (z ^ (z >> 27)) * 0x94d049bb133111eb;
|
| 652 |
+
return z ^ (z >> 31);
|
| 653 |
+
}
|
| 654 |
+
|
| 655 |
+
template <typename F>
|
| 656 |
+
inline auto with_instance_map(const void *ptr,
|
| 657 |
+
const F &cb) -> decltype(cb(std::declval<instance_map &>())) {
|
| 658 |
+
auto &internals = get_internals();
|
| 659 |
+
|
| 660 |
+
#ifdef Py_GIL_DISABLED
|
| 661 |
+
// Hash address to compute shard, but ignore low bits. We'd like allocations
|
| 662 |
+
// from the same thread/core to map to the same shard and allocations from
|
| 663 |
+
// other threads/cores to map to other shards. Using the high bits is a good
|
| 664 |
+
// heuristic because memory allocators often have a per-thread
|
| 665 |
+
// arena/superblock/segment from which smaller allocations are served.
|
| 666 |
+
auto addr = reinterpret_cast<std::uintptr_t>(ptr);
|
| 667 |
+
auto hash = mix64(static_cast<std::uint64_t>(addr >> 20));
|
| 668 |
+
auto idx = static_cast<size_t>(hash & internals.instance_shards_mask);
|
| 669 |
+
|
| 670 |
+
auto &shard = internals.instance_shards[idx];
|
| 671 |
+
std::unique_lock<pymutex> lock(shard.mutex);
|
| 672 |
+
return cb(shard.registered_instances);
|
| 673 |
+
#else
|
| 674 |
+
(void) ptr;
|
| 675 |
+
return cb(internals.registered_instances);
|
| 676 |
+
#endif
|
| 677 |
+
}
|
| 678 |
+
|
| 679 |
+
// Returns the number of registered instances for testing purposes. The result may not be
|
| 680 |
+
// consistent if other threads are registering or unregistering instances concurrently.
|
| 681 |
+
inline size_t num_registered_instances() {
|
| 682 |
+
auto &internals = get_internals();
|
| 683 |
+
#ifdef Py_GIL_DISABLED
|
| 684 |
+
size_t count = 0;
|
| 685 |
+
for (size_t i = 0; i <= internals.instance_shards_mask; ++i) {
|
| 686 |
+
auto &shard = internals.instance_shards[i];
|
| 687 |
+
std::unique_lock<pymutex> lock(shard.mutex);
|
| 688 |
+
count += shard.registered_instances.size();
|
| 689 |
+
}
|
| 690 |
+
return count;
|
| 691 |
+
#else
|
| 692 |
+
return internals.registered_instances.size();
|
| 693 |
+
#endif
|
| 694 |
+
}
|
| 695 |
+
|
| 696 |
+
/// Constructs a std::string with the given arguments, stores it in `internals`, and returns its
|
| 697 |
+
/// `c_str()`. Such strings objects have a long storage duration -- the internal strings are only
|
| 698 |
+
/// cleared when the program exits or after interpreter shutdown (when embedding), and so are
|
| 699 |
+
/// suitable for c-style strings needed by Python internals (such as PyTypeObject's tp_name).
|
| 700 |
+
template <typename... Args>
|
| 701 |
+
const char *c_str(Args &&...args) {
|
| 702 |
+
// GCC 4.8 doesn't like parameter unpack within lambda capture, so use
|
| 703 |
+
// PYBIND11_LOCK_INTERNALS.
|
| 704 |
+
auto &internals = get_internals();
|
| 705 |
+
PYBIND11_LOCK_INTERNALS(internals);
|
| 706 |
+
auto &strings = internals.static_strings;
|
| 707 |
+
strings.emplace_front(std::forward<Args>(args)...);
|
| 708 |
+
return strings.front().c_str();
|
| 709 |
+
}
|
| 710 |
+
|
| 711 |
+
inline const char *get_function_record_capsule_name() {
|
| 712 |
+
#if PYBIND11_INTERNALS_VERSION > 4
|
| 713 |
+
return get_internals().function_record_capsule_name.c_str();
|
| 714 |
+
#else
|
| 715 |
+
return nullptr;
|
| 716 |
+
#endif
|
| 717 |
+
}
|
| 718 |
+
|
| 719 |
+
// Determine whether or not the following capsule contains a pybind11 function record.
|
| 720 |
+
// Note that we use `internals` to make sure that only ABI compatible records are touched.
|
| 721 |
+
//
|
| 722 |
+
// This check is currently used in two places:
|
| 723 |
+
// - An important optimization in functional.h to avoid overhead in C++ -> Python -> C++
|
| 724 |
+
// - The sibling feature of cpp_function to allow overloads
|
| 725 |
+
inline bool is_function_record_capsule(const capsule &cap) {
|
| 726 |
+
// Pointer equality as we rely on internals() to ensure unique pointers
|
| 727 |
+
return cap.name() == get_function_record_capsule_name();
|
| 728 |
+
}
|
| 729 |
+
|
| 730 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 731 |
+
|
| 732 |
+
/// Returns a named pointer that is shared among all extension modules (using the same
|
| 733 |
+
/// pybind11 version) running in the current interpreter. Names starting with underscores
|
| 734 |
+
/// are reserved for internal usage. Returns `nullptr` if no matching entry was found.
|
| 735 |
+
PYBIND11_NOINLINE void *get_shared_data(const std::string &name) {
|
| 736 |
+
return detail::with_internals([&](detail::internals &internals) {
|
| 737 |
+
auto it = internals.shared_data.find(name);
|
| 738 |
+
return it != internals.shared_data.end() ? it->second : nullptr;
|
| 739 |
+
});
|
| 740 |
+
}
|
| 741 |
+
|
| 742 |
+
/// Set the shared data that can be later recovered by `get_shared_data()`.
|
| 743 |
+
PYBIND11_NOINLINE void *set_shared_data(const std::string &name, void *data) {
|
| 744 |
+
return detail::with_internals([&](detail::internals &internals) {
|
| 745 |
+
internals.shared_data[name] = data;
|
| 746 |
+
return data;
|
| 747 |
+
});
|
| 748 |
+
}
|
| 749 |
+
|
| 750 |
+
/// Returns a typed reference to a shared data entry (by using `get_shared_data()`) if
|
| 751 |
+
/// such entry exists. Otherwise, a new object of default-constructible type `T` is
|
| 752 |
+
/// added to the shared data under the given name and a reference to it is returned.
|
| 753 |
+
template <typename T>
|
| 754 |
+
T &get_or_create_shared_data(const std::string &name) {
|
| 755 |
+
return *detail::with_internals([&](detail::internals &internals) {
|
| 756 |
+
auto it = internals.shared_data.find(name);
|
| 757 |
+
T *ptr = (T *) (it != internals.shared_data.end() ? it->second : nullptr);
|
| 758 |
+
if (!ptr) {
|
| 759 |
+
ptr = new T();
|
| 760 |
+
internals.shared_data[name] = ptr;
|
| 761 |
+
}
|
| 762 |
+
return ptr;
|
| 763 |
+
});
|
| 764 |
+
}
|
| 765 |
+
|
| 766 |
+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
|
phivenv/Lib/site-packages/torch/include/pybind11/detail/type_caster_base.h
ADDED
|
@@ -0,0 +1,1195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
pybind11/detail/type_caster_base.h (originally first part of pybind11/cast.h)
|
| 3 |
+
|
| 4 |
+
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
| 5 |
+
|
| 6 |
+
All rights reserved. Use of this source code is governed by a
|
| 7 |
+
BSD-style license that can be found in the LICENSE file.
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include <pybind11/pytypes.h>
|
| 13 |
+
|
| 14 |
+
#include "common.h"
|
| 15 |
+
#include "cpp_conduit.h"
|
| 16 |
+
#include "descr.h"
|
| 17 |
+
#include "internals.h"
|
| 18 |
+
#include "typeid.h"
|
| 19 |
+
#include "value_and_holder.h"
|
| 20 |
+
|
| 21 |
+
#include <cstdint>
|
| 22 |
+
#include <cstring>
|
| 23 |
+
#include <iterator>
|
| 24 |
+
#include <new>
|
| 25 |
+
#include <stdexcept>
|
| 26 |
+
#include <string>
|
| 27 |
+
#include <type_traits>
|
| 28 |
+
#include <typeindex>
|
| 29 |
+
#include <typeinfo>
|
| 30 |
+
#include <unordered_map>
|
| 31 |
+
#include <utility>
|
| 32 |
+
#include <vector>
|
| 33 |
+
|
| 34 |
+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
| 35 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 36 |
+
|
| 37 |
+
/// A life support system for temporary objects created by `type_caster::load()`.
|
| 38 |
+
/// Adding a patient will keep it alive up until the enclosing function returns.
|
| 39 |
+
class loader_life_support {
|
| 40 |
+
private:
|
| 41 |
+
loader_life_support *parent = nullptr;
|
| 42 |
+
std::unordered_set<PyObject *> keep_alive;
|
| 43 |
+
|
| 44 |
+
// Store stack pointer in thread-local storage.
|
| 45 |
+
static PYBIND11_TLS_KEY_REF get_stack_tls_key() {
|
| 46 |
+
#if PYBIND11_INTERNALS_VERSION == 4
|
| 47 |
+
return get_local_internals().loader_life_support_tls_key;
|
| 48 |
+
#else
|
| 49 |
+
return get_internals().loader_life_support_tls_key;
|
| 50 |
+
#endif
|
| 51 |
+
}
|
| 52 |
+
static loader_life_support *get_stack_top() {
|
| 53 |
+
return static_cast<loader_life_support *>(PYBIND11_TLS_GET_VALUE(get_stack_tls_key()));
|
| 54 |
+
}
|
| 55 |
+
static void set_stack_top(loader_life_support *value) {
|
| 56 |
+
PYBIND11_TLS_REPLACE_VALUE(get_stack_tls_key(), value);
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
public:
|
| 60 |
+
/// A new patient frame is created when a function is entered
|
| 61 |
+
loader_life_support() : parent{get_stack_top()} { set_stack_top(this); }
|
| 62 |
+
|
| 63 |
+
/// ... and destroyed after it returns
|
| 64 |
+
~loader_life_support() {
|
| 65 |
+
if (get_stack_top() != this) {
|
| 66 |
+
pybind11_fail("loader_life_support: internal error");
|
| 67 |
+
}
|
| 68 |
+
set_stack_top(parent);
|
| 69 |
+
for (auto *item : keep_alive) {
|
| 70 |
+
Py_DECREF(item);
|
| 71 |
+
}
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
/// This can only be used inside a pybind11-bound function, either by `argument_loader`
|
| 75 |
+
/// at argument preparation time or by `py::cast()` at execution time.
|
| 76 |
+
PYBIND11_NOINLINE static void add_patient(handle h) {
|
| 77 |
+
loader_life_support *frame = get_stack_top();
|
| 78 |
+
if (!frame) {
|
| 79 |
+
// NOTE: It would be nice to include the stack frames here, as this indicates
|
| 80 |
+
// use of pybind11::cast<> outside the normal call framework, finding such
|
| 81 |
+
// a location is challenging. Developers could consider printing out
|
| 82 |
+
// stack frame addresses here using something like __builtin_frame_address(0)
|
| 83 |
+
throw cast_error("When called outside a bound function, py::cast() cannot "
|
| 84 |
+
"do Python -> C++ conversions which require the creation "
|
| 85 |
+
"of temporary values");
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
if (frame->keep_alive.insert(h.ptr()).second) {
|
| 89 |
+
Py_INCREF(h.ptr());
|
| 90 |
+
}
|
| 91 |
+
}
|
| 92 |
+
};
|
| 93 |
+
|
| 94 |
+
// Gets the cache entry for the given type, creating it if necessary. The return value is the pair
|
| 95 |
+
// returned by emplace, i.e. an iterator for the entry and a bool set to `true` if the entry was
|
| 96 |
+
// just created.
|
| 97 |
+
inline std::pair<decltype(internals::registered_types_py)::iterator, bool>
|
| 98 |
+
all_type_info_get_cache(PyTypeObject *type);
|
| 99 |
+
|
| 100 |
+
// Band-aid workaround to fix a subtle but serious bug in a minimalistic fashion. See PR #4762.
|
| 101 |
+
inline void all_type_info_add_base_most_derived_first(std::vector<type_info *> &bases,
|
| 102 |
+
type_info *addl_base) {
|
| 103 |
+
for (auto it = bases.begin(); it != bases.end(); it++) {
|
| 104 |
+
type_info *existing_base = *it;
|
| 105 |
+
if (PyType_IsSubtype(addl_base->type, existing_base->type) != 0) {
|
| 106 |
+
bases.insert(it, addl_base);
|
| 107 |
+
return;
|
| 108 |
+
}
|
| 109 |
+
}
|
| 110 |
+
bases.push_back(addl_base);
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
// Populates a just-created cache entry.
|
| 114 |
+
PYBIND11_NOINLINE void all_type_info_populate(PyTypeObject *t, std::vector<type_info *> &bases) {
|
| 115 |
+
assert(bases.empty());
|
| 116 |
+
std::vector<PyTypeObject *> check;
|
| 117 |
+
for (handle parent : reinterpret_borrow<tuple>(t->tp_bases)) {
|
| 118 |
+
check.push_back((PyTypeObject *) parent.ptr());
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
auto const &type_dict = get_internals().registered_types_py;
|
| 122 |
+
for (size_t i = 0; i < check.size(); i++) {
|
| 123 |
+
auto *type = check[i];
|
| 124 |
+
// Ignore Python2 old-style class super type:
|
| 125 |
+
if (!PyType_Check((PyObject *) type)) {
|
| 126 |
+
continue;
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
// Check `type` in the current set of registered python types:
|
| 130 |
+
auto it = type_dict.find(type);
|
| 131 |
+
if (it != type_dict.end()) {
|
| 132 |
+
// We found a cache entry for it, so it's either pybind-registered or has pre-computed
|
| 133 |
+
// pybind bases, but we have to make sure we haven't already seen the type(s) before:
|
| 134 |
+
// we want to follow Python/virtual C++ rules that there should only be one instance of
|
| 135 |
+
// a common base.
|
| 136 |
+
for (auto *tinfo : it->second) {
|
| 137 |
+
// NB: Could use a second set here, rather than doing a linear search, but since
|
| 138 |
+
// having a large number of immediate pybind11-registered types seems fairly
|
| 139 |
+
// unlikely, that probably isn't worthwhile.
|
| 140 |
+
bool found = false;
|
| 141 |
+
for (auto *known : bases) {
|
| 142 |
+
if (known == tinfo) {
|
| 143 |
+
found = true;
|
| 144 |
+
break;
|
| 145 |
+
}
|
| 146 |
+
}
|
| 147 |
+
if (!found) {
|
| 148 |
+
all_type_info_add_base_most_derived_first(bases, tinfo);
|
| 149 |
+
}
|
| 150 |
+
}
|
| 151 |
+
} else if (type->tp_bases) {
|
| 152 |
+
// It's some python type, so keep follow its bases classes to look for one or more
|
| 153 |
+
// registered types
|
| 154 |
+
if (i + 1 == check.size()) {
|
| 155 |
+
// When we're at the end, we can pop off the current element to avoid growing
|
| 156 |
+
// `check` when adding just one base (which is typical--i.e. when there is no
|
| 157 |
+
// multiple inheritance)
|
| 158 |
+
check.pop_back();
|
| 159 |
+
i--;
|
| 160 |
+
}
|
| 161 |
+
for (handle parent : reinterpret_borrow<tuple>(type->tp_bases)) {
|
| 162 |
+
check.push_back((PyTypeObject *) parent.ptr());
|
| 163 |
+
}
|
| 164 |
+
}
|
| 165 |
+
}
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
/**
|
| 169 |
+
* Extracts vector of type_info pointers of pybind-registered roots of the given Python type. Will
|
| 170 |
+
* be just 1 pybind type for the Python type of a pybind-registered class, or for any Python-side
|
| 171 |
+
* derived class that uses single inheritance. Will contain as many types as required for a Python
|
| 172 |
+
* class that uses multiple inheritance to inherit (directly or indirectly) from multiple
|
| 173 |
+
* pybind-registered classes. Will be empty if neither the type nor any base classes are
|
| 174 |
+
* pybind-registered.
|
| 175 |
+
*
|
| 176 |
+
* The value is cached for the lifetime of the Python type.
|
| 177 |
+
*/
|
| 178 |
+
inline const std::vector<detail::type_info *> &all_type_info(PyTypeObject *type) {
|
| 179 |
+
auto ins = all_type_info_get_cache(type);
|
| 180 |
+
if (ins.second) {
|
| 181 |
+
// New cache entry: populate it
|
| 182 |
+
all_type_info_populate(type, ins.first->second);
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
return ins.first->second;
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
/**
|
| 189 |
+
* Gets a single pybind11 type info for a python type. Returns nullptr if neither the type nor any
|
| 190 |
+
* ancestors are pybind11-registered. Throws an exception if there are multiple bases--use
|
| 191 |
+
* `all_type_info` instead if you want to support multiple bases.
|
| 192 |
+
*/
|
| 193 |
+
PYBIND11_NOINLINE detail::type_info *get_type_info(PyTypeObject *type) {
|
| 194 |
+
const auto &bases = all_type_info(type);
|
| 195 |
+
if (bases.empty()) {
|
| 196 |
+
return nullptr;
|
| 197 |
+
}
|
| 198 |
+
if (bases.size() > 1) {
|
| 199 |
+
pybind11_fail(
|
| 200 |
+
"pybind11::detail::get_type_info: type has multiple pybind11-registered bases");
|
| 201 |
+
}
|
| 202 |
+
return bases.front();
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
inline detail::type_info *get_local_type_info(const std::type_index &tp) {
|
| 206 |
+
auto &locals = get_local_internals().registered_types_cpp;
|
| 207 |
+
auto it = locals.find(tp);
|
| 208 |
+
if (it != locals.end()) {
|
| 209 |
+
return it->second;
|
| 210 |
+
}
|
| 211 |
+
return nullptr;
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
inline detail::type_info *get_global_type_info(const std::type_index &tp) {
|
| 215 |
+
return with_internals([&](internals &internals) {
|
| 216 |
+
detail::type_info *type_info = nullptr;
|
| 217 |
+
auto &types = internals.registered_types_cpp;
|
| 218 |
+
auto it = types.find(tp);
|
| 219 |
+
if (it != types.end()) {
|
| 220 |
+
type_info = it->second;
|
| 221 |
+
}
|
| 222 |
+
return type_info;
|
| 223 |
+
});
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
/// Return the type info for a given C++ type; on lookup failure can either throw or return
|
| 227 |
+
/// nullptr.
|
| 228 |
+
PYBIND11_NOINLINE detail::type_info *get_type_info(const std::type_index &tp,
|
| 229 |
+
bool throw_if_missing = false) {
|
| 230 |
+
if (auto *ltype = get_local_type_info(tp)) {
|
| 231 |
+
return ltype;
|
| 232 |
+
}
|
| 233 |
+
if (auto *gtype = get_global_type_info(tp)) {
|
| 234 |
+
return gtype;
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
if (throw_if_missing) {
|
| 238 |
+
std::string tname = tp.name();
|
| 239 |
+
detail::clean_type_id(tname);
|
| 240 |
+
pybind11_fail("pybind11::detail::get_type_info: unable to find type info for \""
|
| 241 |
+
+ std::move(tname) + '"');
|
| 242 |
+
}
|
| 243 |
+
return nullptr;
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
PYBIND11_NOINLINE handle get_type_handle(const std::type_info &tp, bool throw_if_missing) {
|
| 247 |
+
detail::type_info *type_info = get_type_info(tp, throw_if_missing);
|
| 248 |
+
return handle(type_info ? ((PyObject *) type_info->type) : nullptr);
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
// Searches the inheritance graph for a registered Python instance, using all_type_info().
|
| 252 |
+
PYBIND11_NOINLINE handle find_registered_python_instance(void *src,
|
| 253 |
+
const detail::type_info *tinfo) {
|
| 254 |
+
return with_instance_map(src, [&](instance_map &instances) {
|
| 255 |
+
auto it_instances = instances.equal_range(src);
|
| 256 |
+
for (auto it_i = it_instances.first; it_i != it_instances.second; ++it_i) {
|
| 257 |
+
for (auto *instance_type : detail::all_type_info(Py_TYPE(it_i->second))) {
|
| 258 |
+
if (instance_type && same_type(*instance_type->cpptype, *tinfo->cpptype)) {
|
| 259 |
+
return handle((PyObject *) it_i->second).inc_ref();
|
| 260 |
+
}
|
| 261 |
+
}
|
| 262 |
+
}
|
| 263 |
+
return handle();
|
| 264 |
+
});
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
// Container for accessing and iterating over an instance's values/holders
|
| 268 |
+
struct values_and_holders {
|
| 269 |
+
private:
|
| 270 |
+
instance *inst;
|
| 271 |
+
using type_vec = std::vector<detail::type_info *>;
|
| 272 |
+
const type_vec &tinfo;
|
| 273 |
+
|
| 274 |
+
public:
|
| 275 |
+
explicit values_and_holders(instance *inst)
|
| 276 |
+
: inst{inst}, tinfo(all_type_info(Py_TYPE(inst))) {}
|
| 277 |
+
|
| 278 |
+
explicit values_and_holders(PyObject *obj)
|
| 279 |
+
: inst{nullptr}, tinfo(all_type_info(Py_TYPE(obj))) {
|
| 280 |
+
if (!tinfo.empty()) {
|
| 281 |
+
inst = reinterpret_cast<instance *>(obj);
|
| 282 |
+
}
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
struct iterator {
|
| 286 |
+
private:
|
| 287 |
+
instance *inst = nullptr;
|
| 288 |
+
const type_vec *types = nullptr;
|
| 289 |
+
value_and_holder curr;
|
| 290 |
+
friend struct values_and_holders;
|
| 291 |
+
iterator(instance *inst, const type_vec *tinfo) : inst{inst}, types{tinfo} {
|
| 292 |
+
if (inst != nullptr) {
|
| 293 |
+
assert(!types->empty());
|
| 294 |
+
curr = value_and_holder(
|
| 295 |
+
inst /* instance */,
|
| 296 |
+
(*types)[0] /* type info */,
|
| 297 |
+
0, /* vpos: (non-simple types only): the first vptr comes first */
|
| 298 |
+
0 /* index */);
|
| 299 |
+
}
|
| 300 |
+
}
|
| 301 |
+
// Past-the-end iterator:
|
| 302 |
+
explicit iterator(size_t end) : curr(end) {}
|
| 303 |
+
|
| 304 |
+
public:
|
| 305 |
+
bool operator==(const iterator &other) const { return curr.index == other.curr.index; }
|
| 306 |
+
bool operator!=(const iterator &other) const { return curr.index != other.curr.index; }
|
| 307 |
+
iterator &operator++() {
|
| 308 |
+
if (!inst->simple_layout) {
|
| 309 |
+
curr.vh += 1 + (*types)[curr.index]->holder_size_in_ptrs;
|
| 310 |
+
}
|
| 311 |
+
++curr.index;
|
| 312 |
+
curr.type = curr.index < types->size() ? (*types)[curr.index] : nullptr;
|
| 313 |
+
return *this;
|
| 314 |
+
}
|
| 315 |
+
value_and_holder &operator*() { return curr; }
|
| 316 |
+
value_and_holder *operator->() { return &curr; }
|
| 317 |
+
};
|
| 318 |
+
|
| 319 |
+
iterator begin() { return iterator(inst, &tinfo); }
|
| 320 |
+
iterator end() { return iterator(tinfo.size()); }
|
| 321 |
+
|
| 322 |
+
iterator find(const type_info *find_type) {
|
| 323 |
+
auto it = begin(), endit = end();
|
| 324 |
+
while (it != endit && it->type != find_type) {
|
| 325 |
+
++it;
|
| 326 |
+
}
|
| 327 |
+
return it;
|
| 328 |
+
}
|
| 329 |
+
|
| 330 |
+
size_t size() { return tinfo.size(); }
|
| 331 |
+
|
| 332 |
+
// Band-aid workaround to fix a subtle but serious bug in a minimalistic fashion. See PR #4762.
|
| 333 |
+
bool is_redundant_value_and_holder(const value_and_holder &vh) {
|
| 334 |
+
for (size_t i = 0; i < vh.index; i++) {
|
| 335 |
+
if (PyType_IsSubtype(tinfo[i]->type, tinfo[vh.index]->type) != 0) {
|
| 336 |
+
return true;
|
| 337 |
+
}
|
| 338 |
+
}
|
| 339 |
+
return false;
|
| 340 |
+
}
|
| 341 |
+
};
|
| 342 |
+
|
| 343 |
+
/**
|
| 344 |
+
* Extracts C++ value and holder pointer references from an instance (which may contain multiple
|
| 345 |
+
* values/holders for python-side multiple inheritance) that match the given type. Throws an error
|
| 346 |
+
* if the given type (or ValueType, if omitted) is not a pybind11 base of the given instance. If
|
| 347 |
+
* `find_type` is omitted (or explicitly specified as nullptr) the first value/holder are returned,
|
| 348 |
+
* regardless of type (and the resulting .type will be nullptr).
|
| 349 |
+
*
|
| 350 |
+
* The returned object should be short-lived: in particular, it must not outlive the called-upon
|
| 351 |
+
* instance.
|
| 352 |
+
*/
|
| 353 |
+
PYBIND11_NOINLINE value_and_holder
|
| 354 |
+
instance::get_value_and_holder(const type_info *find_type /*= nullptr default in common.h*/,
|
| 355 |
+
bool throw_if_missing /*= true in common.h*/) {
|
| 356 |
+
// Optimize common case:
|
| 357 |
+
if (!find_type || Py_TYPE(this) == find_type->type) {
|
| 358 |
+
return value_and_holder(this, find_type, 0, 0);
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
detail::values_and_holders vhs(this);
|
| 362 |
+
auto it = vhs.find(find_type);
|
| 363 |
+
if (it != vhs.end()) {
|
| 364 |
+
return *it;
|
| 365 |
+
}
|
| 366 |
+
|
| 367 |
+
if (!throw_if_missing) {
|
| 368 |
+
return value_and_holder();
|
| 369 |
+
}
|
| 370 |
+
|
| 371 |
+
#if defined(PYBIND11_DETAILED_ERROR_MESSAGES)
|
| 372 |
+
pybind11_fail("pybind11::detail::instance::get_value_and_holder: `"
|
| 373 |
+
+ get_fully_qualified_tp_name(find_type->type)
|
| 374 |
+
+ "' is not a pybind11 base of the given `"
|
| 375 |
+
+ get_fully_qualified_tp_name(Py_TYPE(this)) + "' instance");
|
| 376 |
+
#else
|
| 377 |
+
pybind11_fail(
|
| 378 |
+
"pybind11::detail::instance::get_value_and_holder: "
|
| 379 |
+
"type is not a pybind11 base of the given instance "
|
| 380 |
+
"(#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for type details)");
|
| 381 |
+
#endif
|
| 382 |
+
}
|
| 383 |
+
|
| 384 |
+
PYBIND11_NOINLINE void instance::allocate_layout() {
|
| 385 |
+
const auto &tinfo = all_type_info(Py_TYPE(this));
|
| 386 |
+
|
| 387 |
+
const size_t n_types = tinfo.size();
|
| 388 |
+
|
| 389 |
+
if (n_types == 0) {
|
| 390 |
+
pybind11_fail(
|
| 391 |
+
"instance allocation failed: new instance has no pybind11-registered base types");
|
| 392 |
+
}
|
| 393 |
+
|
| 394 |
+
simple_layout
|
| 395 |
+
= n_types == 1 && tinfo.front()->holder_size_in_ptrs <= instance_simple_holder_in_ptrs();
|
| 396 |
+
|
| 397 |
+
// Simple path: no python-side multiple inheritance, and a small-enough holder
|
| 398 |
+
if (simple_layout) {
|
| 399 |
+
simple_value_holder[0] = nullptr;
|
| 400 |
+
simple_holder_constructed = false;
|
| 401 |
+
simple_instance_registered = false;
|
| 402 |
+
} else { // multiple base types or a too-large holder
|
| 403 |
+
// Allocate space to hold: [v1*][h1][v2*][h2]...[bb...] where [vN*] is a value pointer,
|
| 404 |
+
// [hN] is the (uninitialized) holder instance for value N, and [bb...] is a set of bool
|
| 405 |
+
// values that tracks whether each associated holder has been initialized. Each [block] is
|
| 406 |
+
// padded, if necessary, to an integer multiple of sizeof(void *).
|
| 407 |
+
size_t space = 0;
|
| 408 |
+
for (auto *t : tinfo) {
|
| 409 |
+
space += 1; // value pointer
|
| 410 |
+
space += t->holder_size_in_ptrs; // holder instance
|
| 411 |
+
}
|
| 412 |
+
size_t flags_at = space;
|
| 413 |
+
space += size_in_ptrs(n_types); // status bytes (holder_constructed and
|
| 414 |
+
// instance_registered)
|
| 415 |
+
|
| 416 |
+
// Allocate space for flags, values, and holders, and initialize it to 0 (flags and values,
|
| 417 |
+
// in particular, need to be 0). Use Python's memory allocation
|
| 418 |
+
// functions: Python is using pymalloc, which is designed to be
|
| 419 |
+
// efficient for small allocations like the one we're doing here;
|
| 420 |
+
// for larger allocations they are just wrappers around malloc.
|
| 421 |
+
// TODO: is this still true for pure Python 3.6?
|
| 422 |
+
nonsimple.values_and_holders = (void **) PyMem_Calloc(space, sizeof(void *));
|
| 423 |
+
if (!nonsimple.values_and_holders) {
|
| 424 |
+
throw std::bad_alloc();
|
| 425 |
+
}
|
| 426 |
+
nonsimple.status
|
| 427 |
+
= reinterpret_cast<std::uint8_t *>(&nonsimple.values_and_holders[flags_at]);
|
| 428 |
+
}
|
| 429 |
+
owned = true;
|
| 430 |
+
}
|
| 431 |
+
|
| 432 |
+
// NOLINTNEXTLINE(readability-make-member-function-const)
|
| 433 |
+
PYBIND11_NOINLINE void instance::deallocate_layout() {
|
| 434 |
+
if (!simple_layout) {
|
| 435 |
+
PyMem_Free(reinterpret_cast<void *>(nonsimple.values_and_holders));
|
| 436 |
+
}
|
| 437 |
+
}
|
| 438 |
+
|
| 439 |
+
PYBIND11_NOINLINE bool isinstance_generic(handle obj, const std::type_info &tp) {
|
| 440 |
+
handle type = detail::get_type_handle(tp, false);
|
| 441 |
+
if (!type) {
|
| 442 |
+
return false;
|
| 443 |
+
}
|
| 444 |
+
return isinstance(obj, type);
|
| 445 |
+
}
|
| 446 |
+
|
| 447 |
+
PYBIND11_NOINLINE handle get_object_handle(const void *ptr, const detail::type_info *type) {
|
| 448 |
+
return with_instance_map(ptr, [&](instance_map &instances) {
|
| 449 |
+
auto range = instances.equal_range(ptr);
|
| 450 |
+
for (auto it = range.first; it != range.second; ++it) {
|
| 451 |
+
for (const auto &vh : values_and_holders(it->second)) {
|
| 452 |
+
if (vh.type == type) {
|
| 453 |
+
return handle((PyObject *) it->second);
|
| 454 |
+
}
|
| 455 |
+
}
|
| 456 |
+
}
|
| 457 |
+
return handle();
|
| 458 |
+
});
|
| 459 |
+
}
|
| 460 |
+
|
| 461 |
+
inline PyThreadState *get_thread_state_unchecked() {
|
| 462 |
+
#if defined(PYPY_VERSION)
|
| 463 |
+
return PyThreadState_GET();
|
| 464 |
+
#elif PY_VERSION_HEX < 0x030D0000
|
| 465 |
+
return _PyThreadState_UncheckedGet();
|
| 466 |
+
#else
|
| 467 |
+
return PyThreadState_GetUnchecked();
|
| 468 |
+
#endif
|
| 469 |
+
}
|
| 470 |
+
|
| 471 |
+
// Forward declarations
|
| 472 |
+
void keep_alive_impl(handle nurse, handle patient);
|
| 473 |
+
inline PyObject *make_new_instance(PyTypeObject *type);
|
| 474 |
+
|
| 475 |
+
class type_caster_generic {
|
| 476 |
+
public:
|
| 477 |
+
PYBIND11_NOINLINE explicit type_caster_generic(const std::type_info &type_info)
|
| 478 |
+
: typeinfo(get_type_info(type_info)), cpptype(&type_info) {}
|
| 479 |
+
|
| 480 |
+
explicit type_caster_generic(const type_info *typeinfo)
|
| 481 |
+
: typeinfo(typeinfo), cpptype(typeinfo ? typeinfo->cpptype : nullptr) {}
|
| 482 |
+
|
| 483 |
+
bool load(handle src, bool convert) { return load_impl<type_caster_generic>(src, convert); }
|
| 484 |
+
|
| 485 |
+
PYBIND11_NOINLINE static handle cast(const void *_src,
|
| 486 |
+
return_value_policy policy,
|
| 487 |
+
handle parent,
|
| 488 |
+
const detail::type_info *tinfo,
|
| 489 |
+
void *(*copy_constructor)(const void *),
|
| 490 |
+
void *(*move_constructor)(const void *),
|
| 491 |
+
const void *existing_holder = nullptr) {
|
| 492 |
+
if (!tinfo) { // no type info: error will be set already
|
| 493 |
+
return handle();
|
| 494 |
+
}
|
| 495 |
+
|
| 496 |
+
void *src = const_cast<void *>(_src);
|
| 497 |
+
if (src == nullptr) {
|
| 498 |
+
return none().release();
|
| 499 |
+
}
|
| 500 |
+
|
| 501 |
+
if (handle registered_inst = find_registered_python_instance(src, tinfo)) {
|
| 502 |
+
return registered_inst;
|
| 503 |
+
}
|
| 504 |
+
|
| 505 |
+
auto inst = reinterpret_steal<object>(make_new_instance(tinfo->type));
|
| 506 |
+
auto *wrapper = reinterpret_cast<instance *>(inst.ptr());
|
| 507 |
+
wrapper->owned = false;
|
| 508 |
+
void *&valueptr = values_and_holders(wrapper).begin()->value_ptr();
|
| 509 |
+
|
| 510 |
+
switch (policy) {
|
| 511 |
+
case return_value_policy::automatic:
|
| 512 |
+
case return_value_policy::take_ownership:
|
| 513 |
+
valueptr = src;
|
| 514 |
+
wrapper->owned = true;
|
| 515 |
+
break;
|
| 516 |
+
|
| 517 |
+
case return_value_policy::automatic_reference:
|
| 518 |
+
case return_value_policy::reference:
|
| 519 |
+
valueptr = src;
|
| 520 |
+
wrapper->owned = false;
|
| 521 |
+
break;
|
| 522 |
+
|
| 523 |
+
case return_value_policy::copy:
|
| 524 |
+
if (copy_constructor) {
|
| 525 |
+
valueptr = copy_constructor(src);
|
| 526 |
+
} else {
|
| 527 |
+
#if defined(PYBIND11_DETAILED_ERROR_MESSAGES)
|
| 528 |
+
std::string type_name(tinfo->cpptype->name());
|
| 529 |
+
detail::clean_type_id(type_name);
|
| 530 |
+
throw cast_error("return_value_policy = copy, but type " + type_name
|
| 531 |
+
+ " is non-copyable!");
|
| 532 |
+
#else
|
| 533 |
+
throw cast_error("return_value_policy = copy, but type is "
|
| 534 |
+
"non-copyable! (#define PYBIND11_DETAILED_ERROR_MESSAGES or "
|
| 535 |
+
"compile in debug mode for details)");
|
| 536 |
+
#endif
|
| 537 |
+
}
|
| 538 |
+
wrapper->owned = true;
|
| 539 |
+
break;
|
| 540 |
+
|
| 541 |
+
case return_value_policy::move:
|
| 542 |
+
if (move_constructor) {
|
| 543 |
+
valueptr = move_constructor(src);
|
| 544 |
+
} else if (copy_constructor) {
|
| 545 |
+
valueptr = copy_constructor(src);
|
| 546 |
+
} else {
|
| 547 |
+
#if defined(PYBIND11_DETAILED_ERROR_MESSAGES)
|
| 548 |
+
std::string type_name(tinfo->cpptype->name());
|
| 549 |
+
detail::clean_type_id(type_name);
|
| 550 |
+
throw cast_error("return_value_policy = move, but type " + type_name
|
| 551 |
+
+ " is neither movable nor copyable!");
|
| 552 |
+
#else
|
| 553 |
+
throw cast_error("return_value_policy = move, but type is neither "
|
| 554 |
+
"movable nor copyable! "
|
| 555 |
+
"(#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in "
|
| 556 |
+
"debug mode for details)");
|
| 557 |
+
#endif
|
| 558 |
+
}
|
| 559 |
+
wrapper->owned = true;
|
| 560 |
+
break;
|
| 561 |
+
|
| 562 |
+
case return_value_policy::reference_internal:
|
| 563 |
+
valueptr = src;
|
| 564 |
+
wrapper->owned = false;
|
| 565 |
+
keep_alive_impl(inst, parent);
|
| 566 |
+
break;
|
| 567 |
+
|
| 568 |
+
default:
|
| 569 |
+
throw cast_error("unhandled return_value_policy: should not happen!");
|
| 570 |
+
}
|
| 571 |
+
|
| 572 |
+
tinfo->init_instance(wrapper, existing_holder);
|
| 573 |
+
|
| 574 |
+
return inst.release();
|
| 575 |
+
}
|
| 576 |
+
|
| 577 |
+
// Base methods for generic caster; there are overridden in copyable_holder_caster
|
| 578 |
+
void load_value(value_and_holder &&v_h) {
|
| 579 |
+
auto *&vptr = v_h.value_ptr();
|
| 580 |
+
// Lazy allocation for unallocated values:
|
| 581 |
+
if (vptr == nullptr) {
|
| 582 |
+
const auto *type = v_h.type ? v_h.type : typeinfo;
|
| 583 |
+
if (type->operator_new) {
|
| 584 |
+
vptr = type->operator_new(type->type_size);
|
| 585 |
+
} else {
|
| 586 |
+
#if defined(__cpp_aligned_new) && (!defined(_MSC_VER) || _MSC_VER >= 1912)
|
| 587 |
+
if (type->type_align > __STDCPP_DEFAULT_NEW_ALIGNMENT__) {
|
| 588 |
+
vptr = ::operator new(type->type_size, std::align_val_t(type->type_align));
|
| 589 |
+
} else {
|
| 590 |
+
vptr = ::operator new(type->type_size);
|
| 591 |
+
}
|
| 592 |
+
#else
|
| 593 |
+
vptr = ::operator new(type->type_size);
|
| 594 |
+
#endif
|
| 595 |
+
}
|
| 596 |
+
}
|
| 597 |
+
value = vptr;
|
| 598 |
+
}
|
| 599 |
+
bool try_implicit_casts(handle src, bool convert) {
|
| 600 |
+
for (const auto &cast : typeinfo->implicit_casts) {
|
| 601 |
+
type_caster_generic sub_caster(*cast.first);
|
| 602 |
+
if (sub_caster.load(src, convert)) {
|
| 603 |
+
value = cast.second(sub_caster.value);
|
| 604 |
+
return true;
|
| 605 |
+
}
|
| 606 |
+
}
|
| 607 |
+
return false;
|
| 608 |
+
}
|
| 609 |
+
bool try_direct_conversions(handle src) {
|
| 610 |
+
for (auto &converter : *typeinfo->direct_conversions) {
|
| 611 |
+
if (converter(src.ptr(), value)) {
|
| 612 |
+
return true;
|
| 613 |
+
}
|
| 614 |
+
}
|
| 615 |
+
return false;
|
| 616 |
+
}
|
| 617 |
+
bool try_cpp_conduit(handle src) {
|
| 618 |
+
value = try_raw_pointer_ephemeral_from_cpp_conduit(src, cpptype);
|
| 619 |
+
if (value != nullptr) {
|
| 620 |
+
return true;
|
| 621 |
+
}
|
| 622 |
+
return false;
|
| 623 |
+
}
|
| 624 |
+
void check_holder_compat() {}
|
| 625 |
+
|
| 626 |
+
PYBIND11_NOINLINE static void *local_load(PyObject *src, const type_info *ti) {
|
| 627 |
+
auto caster = type_caster_generic(ti);
|
| 628 |
+
if (caster.load(src, false)) {
|
| 629 |
+
return caster.value;
|
| 630 |
+
}
|
| 631 |
+
return nullptr;
|
| 632 |
+
}
|
| 633 |
+
|
| 634 |
+
/// Try to load with foreign typeinfo, if available. Used when there is no
|
| 635 |
+
/// native typeinfo, or when the native one wasn't able to produce a value.
|
| 636 |
+
PYBIND11_NOINLINE bool try_load_foreign_module_local(handle src) {
|
| 637 |
+
constexpr auto *local_key = PYBIND11_MODULE_LOCAL_ID;
|
| 638 |
+
const auto pytype = type::handle_of(src);
|
| 639 |
+
if (!hasattr(pytype, local_key)) {
|
| 640 |
+
return false;
|
| 641 |
+
}
|
| 642 |
+
|
| 643 |
+
type_info *foreign_typeinfo = reinterpret_borrow<capsule>(getattr(pytype, local_key));
|
| 644 |
+
// Only consider this foreign loader if actually foreign and is a loader of the correct cpp
|
| 645 |
+
// type
|
| 646 |
+
if (foreign_typeinfo->module_local_load == &local_load
|
| 647 |
+
|| (cpptype && !same_type(*cpptype, *foreign_typeinfo->cpptype))) {
|
| 648 |
+
return false;
|
| 649 |
+
}
|
| 650 |
+
|
| 651 |
+
if (auto *result = foreign_typeinfo->module_local_load(src.ptr(), foreign_typeinfo)) {
|
| 652 |
+
value = result;
|
| 653 |
+
return true;
|
| 654 |
+
}
|
| 655 |
+
return false;
|
| 656 |
+
}
|
| 657 |
+
|
| 658 |
+
// Implementation of `load`; this takes the type of `this` so that it can dispatch the relevant
|
| 659 |
+
// bits of code between here and copyable_holder_caster where the two classes need different
|
| 660 |
+
// logic (without having to resort to virtual inheritance).
|
| 661 |
+
template <typename ThisT>
|
| 662 |
+
PYBIND11_NOINLINE bool load_impl(handle src, bool convert) {
|
| 663 |
+
if (!src) {
|
| 664 |
+
return false;
|
| 665 |
+
}
|
| 666 |
+
if (!typeinfo) {
|
| 667 |
+
return try_load_foreign_module_local(src);
|
| 668 |
+
}
|
| 669 |
+
|
| 670 |
+
auto &this_ = static_cast<ThisT &>(*this);
|
| 671 |
+
this_.check_holder_compat();
|
| 672 |
+
|
| 673 |
+
PyTypeObject *srctype = Py_TYPE(src.ptr());
|
| 674 |
+
|
| 675 |
+
// Case 1: If src is an exact type match for the target type then we can reinterpret_cast
|
| 676 |
+
// the instance's value pointer to the target type:
|
| 677 |
+
if (srctype == typeinfo->type) {
|
| 678 |
+
this_.load_value(reinterpret_cast<instance *>(src.ptr())->get_value_and_holder());
|
| 679 |
+
return true;
|
| 680 |
+
}
|
| 681 |
+
// Case 2: We have a derived class
|
| 682 |
+
if (PyType_IsSubtype(srctype, typeinfo->type)) {
|
| 683 |
+
const auto &bases = all_type_info(srctype);
|
| 684 |
+
bool no_cpp_mi = typeinfo->simple_type;
|
| 685 |
+
|
| 686 |
+
// Case 2a: the python type is a Python-inherited derived class that inherits from just
|
| 687 |
+
// one simple (no MI) pybind11 class, or is an exact match, so the C++ instance is of
|
| 688 |
+
// the right type and we can use reinterpret_cast.
|
| 689 |
+
// (This is essentially the same as case 2b, but because not using multiple inheritance
|
| 690 |
+
// is extremely common, we handle it specially to avoid the loop iterator and type
|
| 691 |
+
// pointer lookup overhead)
|
| 692 |
+
if (bases.size() == 1 && (no_cpp_mi || bases.front()->type == typeinfo->type)) {
|
| 693 |
+
this_.load_value(reinterpret_cast<instance *>(src.ptr())->get_value_and_holder());
|
| 694 |
+
return true;
|
| 695 |
+
}
|
| 696 |
+
// Case 2b: the python type inherits from multiple C++ bases. Check the bases to see
|
| 697 |
+
// if we can find an exact match (or, for a simple C++ type, an inherited match); if
|
| 698 |
+
// so, we can safely reinterpret_cast to the relevant pointer.
|
| 699 |
+
if (bases.size() > 1) {
|
| 700 |
+
for (auto *base : bases) {
|
| 701 |
+
if (no_cpp_mi ? PyType_IsSubtype(base->type, typeinfo->type)
|
| 702 |
+
: base->type == typeinfo->type) {
|
| 703 |
+
this_.load_value(
|
| 704 |
+
reinterpret_cast<instance *>(src.ptr())->get_value_and_holder(base));
|
| 705 |
+
return true;
|
| 706 |
+
}
|
| 707 |
+
}
|
| 708 |
+
}
|
| 709 |
+
|
| 710 |
+
// Case 2c: C++ multiple inheritance is involved and we couldn't find an exact type
|
| 711 |
+
// match in the registered bases, above, so try implicit casting (needed for proper C++
|
| 712 |
+
// casting when MI is involved).
|
| 713 |
+
if (this_.try_implicit_casts(src, convert)) {
|
| 714 |
+
return true;
|
| 715 |
+
}
|
| 716 |
+
}
|
| 717 |
+
|
| 718 |
+
// Perform an implicit conversion
|
| 719 |
+
if (convert) {
|
| 720 |
+
for (const auto &converter : typeinfo->implicit_conversions) {
|
| 721 |
+
auto temp = reinterpret_steal<object>(converter(src.ptr(), typeinfo->type));
|
| 722 |
+
if (load_impl<ThisT>(temp, false)) {
|
| 723 |
+
loader_life_support::add_patient(temp);
|
| 724 |
+
return true;
|
| 725 |
+
}
|
| 726 |
+
}
|
| 727 |
+
if (this_.try_direct_conversions(src)) {
|
| 728 |
+
return true;
|
| 729 |
+
}
|
| 730 |
+
}
|
| 731 |
+
|
| 732 |
+
// Failed to match local typeinfo. Try again with global.
|
| 733 |
+
if (typeinfo->module_local) {
|
| 734 |
+
if (auto *gtype = get_global_type_info(*typeinfo->cpptype)) {
|
| 735 |
+
typeinfo = gtype;
|
| 736 |
+
return load(src, false);
|
| 737 |
+
}
|
| 738 |
+
}
|
| 739 |
+
|
| 740 |
+
// Global typeinfo has precedence over foreign module_local
|
| 741 |
+
if (try_load_foreign_module_local(src)) {
|
| 742 |
+
return true;
|
| 743 |
+
}
|
| 744 |
+
|
| 745 |
+
// Custom converters didn't take None, now we convert None to nullptr.
|
| 746 |
+
if (src.is_none()) {
|
| 747 |
+
// Defer accepting None to other overloads (if we aren't in convert mode):
|
| 748 |
+
if (!convert) {
|
| 749 |
+
return false;
|
| 750 |
+
}
|
| 751 |
+
value = nullptr;
|
| 752 |
+
return true;
|
| 753 |
+
}
|
| 754 |
+
|
| 755 |
+
if (convert && cpptype && this_.try_cpp_conduit(src)) {
|
| 756 |
+
return true;
|
| 757 |
+
}
|
| 758 |
+
|
| 759 |
+
return false;
|
| 760 |
+
}
|
| 761 |
+
|
| 762 |
+
// Called to do type lookup and wrap the pointer and type in a pair when a dynamic_cast
|
| 763 |
+
// isn't needed or can't be used. If the type is unknown, sets the error and returns a pair
|
| 764 |
+
// with .second = nullptr. (p.first = nullptr is not an error: it becomes None).
|
| 765 |
+
PYBIND11_NOINLINE static std::pair<const void *, const type_info *>
|
| 766 |
+
src_and_type(const void *src,
|
| 767 |
+
const std::type_info &cast_type,
|
| 768 |
+
const std::type_info *rtti_type = nullptr) {
|
| 769 |
+
if (auto *tpi = get_type_info(cast_type)) {
|
| 770 |
+
return {src, const_cast<const type_info *>(tpi)};
|
| 771 |
+
}
|
| 772 |
+
|
| 773 |
+
// Not found, set error:
|
| 774 |
+
std::string tname = rtti_type ? rtti_type->name() : cast_type.name();
|
| 775 |
+
detail::clean_type_id(tname);
|
| 776 |
+
std::string msg = "Unregistered type : " + tname;
|
| 777 |
+
set_error(PyExc_TypeError, msg.c_str());
|
| 778 |
+
return {nullptr, nullptr};
|
| 779 |
+
}
|
| 780 |
+
|
| 781 |
+
const type_info *typeinfo = nullptr;
|
| 782 |
+
const std::type_info *cpptype = nullptr;
|
| 783 |
+
void *value = nullptr;
|
| 784 |
+
};
|
| 785 |
+
|
| 786 |
+
inline object cpp_conduit_method(handle self,
|
| 787 |
+
const bytes &pybind11_platform_abi_id,
|
| 788 |
+
const capsule &cpp_type_info_capsule,
|
| 789 |
+
const bytes &pointer_kind) {
|
| 790 |
+
#ifdef PYBIND11_HAS_STRING_VIEW
|
| 791 |
+
using cpp_str = std::string_view;
|
| 792 |
+
#else
|
| 793 |
+
using cpp_str = std::string;
|
| 794 |
+
#endif
|
| 795 |
+
if (cpp_str(pybind11_platform_abi_id) != PYBIND11_PLATFORM_ABI_ID) {
|
| 796 |
+
return none();
|
| 797 |
+
}
|
| 798 |
+
if (std::strcmp(cpp_type_info_capsule.name(), typeid(std::type_info).name()) != 0) {
|
| 799 |
+
return none();
|
| 800 |
+
}
|
| 801 |
+
if (cpp_str(pointer_kind) != "raw_pointer_ephemeral") {
|
| 802 |
+
throw std::runtime_error("Invalid pointer_kind: \"" + std::string(pointer_kind) + "\"");
|
| 803 |
+
}
|
| 804 |
+
const auto *cpp_type_info = cpp_type_info_capsule.get_pointer<const std::type_info>();
|
| 805 |
+
type_caster_generic caster(*cpp_type_info);
|
| 806 |
+
if (!caster.load(self, false)) {
|
| 807 |
+
return none();
|
| 808 |
+
}
|
| 809 |
+
return capsule(caster.value, cpp_type_info->name());
|
| 810 |
+
}
|
| 811 |
+
|
| 812 |
+
/**
|
| 813 |
+
* Determine suitable casting operator for pointer-or-lvalue-casting type casters. The type caster
|
| 814 |
+
* needs to provide `operator T*()` and `operator T&()` operators.
|
| 815 |
+
*
|
| 816 |
+
* If the type supports moving the value away via an `operator T&&() &&` method, it should use
|
| 817 |
+
* `movable_cast_op_type` instead.
|
| 818 |
+
*/
|
| 819 |
+
template <typename T>
|
| 820 |
+
using cast_op_type = conditional_t<std::is_pointer<remove_reference_t<T>>::value,
|
| 821 |
+
typename std::add_pointer<intrinsic_t<T>>::type,
|
| 822 |
+
typename std::add_lvalue_reference<intrinsic_t<T>>::type>;
|
| 823 |
+
|
| 824 |
+
/**
|
| 825 |
+
* Determine suitable casting operator for a type caster with a movable value. Such a type caster
|
| 826 |
+
* needs to provide `operator T*()`, `operator T&()`, and `operator T&&() &&`. The latter will be
|
| 827 |
+
* called in appropriate contexts where the value can be moved rather than copied.
|
| 828 |
+
*
|
| 829 |
+
* These operator are automatically provided when using the PYBIND11_TYPE_CASTER macro.
|
| 830 |
+
*/
|
| 831 |
+
template <typename T>
|
| 832 |
+
using movable_cast_op_type
|
| 833 |
+
= conditional_t<std::is_pointer<typename std::remove_reference<T>::type>::value,
|
| 834 |
+
typename std::add_pointer<intrinsic_t<T>>::type,
|
| 835 |
+
conditional_t<std::is_rvalue_reference<T>::value,
|
| 836 |
+
typename std::add_rvalue_reference<intrinsic_t<T>>::type,
|
| 837 |
+
typename std::add_lvalue_reference<intrinsic_t<T>>::type>>;
|
| 838 |
+
|
| 839 |
+
// Does the container have a mapped type and is it recursive?
|
| 840 |
+
// Implemented by specializations below.
|
| 841 |
+
template <typename Container, typename SFINAE = void>
|
| 842 |
+
struct container_mapped_type_traits {
|
| 843 |
+
static constexpr bool has_mapped_type = false;
|
| 844 |
+
static constexpr bool has_recursive_mapped_type = false;
|
| 845 |
+
};
|
| 846 |
+
|
| 847 |
+
template <typename Container>
|
| 848 |
+
struct container_mapped_type_traits<
|
| 849 |
+
Container,
|
| 850 |
+
typename std::enable_if<
|
| 851 |
+
std::is_same<typename Container::mapped_type, Container>::value>::type> {
|
| 852 |
+
static constexpr bool has_mapped_type = true;
|
| 853 |
+
static constexpr bool has_recursive_mapped_type = true;
|
| 854 |
+
};
|
| 855 |
+
|
| 856 |
+
template <typename Container>
|
| 857 |
+
struct container_mapped_type_traits<
|
| 858 |
+
Container,
|
| 859 |
+
typename std::enable_if<
|
| 860 |
+
negation<std::is_same<typename Container::mapped_type, Container>>::value>::type> {
|
| 861 |
+
static constexpr bool has_mapped_type = true;
|
| 862 |
+
static constexpr bool has_recursive_mapped_type = false;
|
| 863 |
+
};
|
| 864 |
+
|
| 865 |
+
// Does the container have a value type and is it recursive?
|
| 866 |
+
// Implemented by specializations below.
|
| 867 |
+
template <typename Container, typename SFINAE = void>
|
| 868 |
+
struct container_value_type_traits : std::false_type {
|
| 869 |
+
static constexpr bool has_value_type = false;
|
| 870 |
+
static constexpr bool has_recursive_value_type = false;
|
| 871 |
+
};
|
| 872 |
+
|
| 873 |
+
template <typename Container>
|
| 874 |
+
struct container_value_type_traits<
|
| 875 |
+
Container,
|
| 876 |
+
typename std::enable_if<
|
| 877 |
+
std::is_same<typename Container::value_type, Container>::value>::type> {
|
| 878 |
+
static constexpr bool has_value_type = true;
|
| 879 |
+
static constexpr bool has_recursive_value_type = true;
|
| 880 |
+
};
|
| 881 |
+
|
| 882 |
+
template <typename Container>
|
| 883 |
+
struct container_value_type_traits<
|
| 884 |
+
Container,
|
| 885 |
+
typename std::enable_if<
|
| 886 |
+
negation<std::is_same<typename Container::value_type, Container>>::value>::type> {
|
| 887 |
+
static constexpr bool has_value_type = true;
|
| 888 |
+
static constexpr bool has_recursive_value_type = false;
|
| 889 |
+
};
|
| 890 |
+
|
| 891 |
+
/*
|
| 892 |
+
* Tag to be used for representing the bottom of recursively defined types.
|
| 893 |
+
* Define this tag so we don't have to use void.
|
| 894 |
+
*/
|
| 895 |
+
struct recursive_bottom {};
|
| 896 |
+
|
| 897 |
+
/*
|
| 898 |
+
* Implementation detail of `recursive_container_traits` below.
|
| 899 |
+
* `T` is the `value_type` of the container, which might need to be modified to
|
| 900 |
+
* avoid recursive types and const types.
|
| 901 |
+
*/
|
| 902 |
+
template <typename T, bool is_this_a_map>
|
| 903 |
+
struct impl_type_to_check_recursively {
|
| 904 |
+
/*
|
| 905 |
+
* If the container is recursive, then no further recursion should be done.
|
| 906 |
+
*/
|
| 907 |
+
using if_recursive = recursive_bottom;
|
| 908 |
+
/*
|
| 909 |
+
* Otherwise yield `T` unchanged.
|
| 910 |
+
*/
|
| 911 |
+
using if_not_recursive = T;
|
| 912 |
+
};
|
| 913 |
+
|
| 914 |
+
/*
|
| 915 |
+
* For pairs - only as value type of a map -, the first type should remove the `const`.
|
| 916 |
+
* Also, if the map is recursive, then the recursive checking should consider
|
| 917 |
+
* the first type only.
|
| 918 |
+
*/
|
| 919 |
+
template <typename A, typename B>
|
| 920 |
+
struct impl_type_to_check_recursively<std::pair<A, B>, /* is_this_a_map = */ true> {
|
| 921 |
+
using if_recursive = typename std::remove_const<A>::type;
|
| 922 |
+
using if_not_recursive = std::pair<typename std::remove_const<A>::type, B>;
|
| 923 |
+
};
|
| 924 |
+
|
| 925 |
+
/*
|
| 926 |
+
* Implementation of `recursive_container_traits` below.
|
| 927 |
+
*/
|
| 928 |
+
template <typename Container, typename SFINAE = void>
|
| 929 |
+
struct impl_recursive_container_traits {
|
| 930 |
+
using type_to_check_recursively = recursive_bottom;
|
| 931 |
+
};
|
| 932 |
+
|
| 933 |
+
template <typename Container>
|
| 934 |
+
struct impl_recursive_container_traits<
|
| 935 |
+
Container,
|
| 936 |
+
typename std::enable_if<container_value_type_traits<Container>::has_value_type>::type> {
|
| 937 |
+
static constexpr bool is_recursive
|
| 938 |
+
= container_mapped_type_traits<Container>::has_recursive_mapped_type
|
| 939 |
+
|| container_value_type_traits<Container>::has_recursive_value_type;
|
| 940 |
+
/*
|
| 941 |
+
* This member dictates which type Pybind11 should check recursively in traits
|
| 942 |
+
* such as `is_move_constructible`, `is_copy_constructible`, `is_move_assignable`, ...
|
| 943 |
+
* Direct access to `value_type` should be avoided:
|
| 944 |
+
* 1. `value_type` might recursively contain the type again
|
| 945 |
+
* 2. `value_type` of STL map types is `std::pair<A const, B>`, the `const`
|
| 946 |
+
* should be removed.
|
| 947 |
+
*
|
| 948 |
+
*/
|
| 949 |
+
using type_to_check_recursively = typename std::conditional<
|
| 950 |
+
is_recursive,
|
| 951 |
+
typename impl_type_to_check_recursively<
|
| 952 |
+
typename Container::value_type,
|
| 953 |
+
container_mapped_type_traits<Container>::has_mapped_type>::if_recursive,
|
| 954 |
+
typename impl_type_to_check_recursively<
|
| 955 |
+
typename Container::value_type,
|
| 956 |
+
container_mapped_type_traits<Container>::has_mapped_type>::if_not_recursive>::type;
|
| 957 |
+
};
|
| 958 |
+
|
| 959 |
+
/*
|
| 960 |
+
* This trait defines the `type_to_check_recursively` which is needed to properly
|
| 961 |
+
* handle recursively defined traits such as `is_move_constructible` without going
|
| 962 |
+
* into an infinite recursion.
|
| 963 |
+
* Should be used instead of directly accessing the `value_type`.
|
| 964 |
+
* It cancels the recursion by returning the `recursive_bottom` tag.
|
| 965 |
+
*
|
| 966 |
+
* The default definition of `type_to_check_recursively` is as follows:
|
| 967 |
+
*
|
| 968 |
+
* 1. By default, it is `recursive_bottom`, so that the recursion is canceled.
|
| 969 |
+
* 2. If the type is non-recursive and defines a `value_type`, then the `value_type` is used.
|
| 970 |
+
* If the `value_type` is a pair and a `mapped_type` is defined,
|
| 971 |
+
* then the `const` is removed from the first type.
|
| 972 |
+
* 3. If the type is recursive and `value_type` is not a pair, then `recursive_bottom` is returned.
|
| 973 |
+
* 4. If the type is recursive and `value_type` is a pair and a `mapped_type` is defined,
|
| 974 |
+
* then `const` is removed from the first type and the first type is returned.
|
| 975 |
+
*
|
| 976 |
+
* This behavior can be extended by the user as seen in test_stl_binders.cpp.
|
| 977 |
+
*
|
| 978 |
+
* This struct is exactly the same as impl_recursive_container_traits.
|
| 979 |
+
* The duplication achieves that user-defined specializations don't compete
|
| 980 |
+
* with internal specializations, but take precedence.
|
| 981 |
+
*/
|
| 982 |
+
template <typename Container, typename SFINAE = void>
|
| 983 |
+
struct recursive_container_traits : impl_recursive_container_traits<Container> {};
|
| 984 |
+
|
| 985 |
+
template <typename T>
|
| 986 |
+
struct is_move_constructible
|
| 987 |
+
: all_of<std::is_move_constructible<T>,
|
| 988 |
+
is_move_constructible<
|
| 989 |
+
typename recursive_container_traits<T>::type_to_check_recursively>> {};
|
| 990 |
+
|
| 991 |
+
template <>
|
| 992 |
+
struct is_move_constructible<recursive_bottom> : std::true_type {};
|
| 993 |
+
|
| 994 |
+
// Likewise for std::pair
|
| 995 |
+
// (after C++17 it is mandatory that the move constructor not exist when the two types aren't
|
| 996 |
+
// themselves move constructible, but this can not be relied upon when T1 or T2 are themselves
|
| 997 |
+
// containers).
|
| 998 |
+
template <typename T1, typename T2>
|
| 999 |
+
struct is_move_constructible<std::pair<T1, T2>>
|
| 1000 |
+
: all_of<is_move_constructible<T1>, is_move_constructible<T2>> {};
|
| 1001 |
+
|
| 1002 |
+
// std::is_copy_constructible isn't quite enough: it lets std::vector<T> (and similar) through when
|
| 1003 |
+
// T is non-copyable, but code containing such a copy constructor fails to actually compile.
|
| 1004 |
+
template <typename T>
|
| 1005 |
+
struct is_copy_constructible
|
| 1006 |
+
: all_of<std::is_copy_constructible<T>,
|
| 1007 |
+
is_copy_constructible<
|
| 1008 |
+
typename recursive_container_traits<T>::type_to_check_recursively>> {};
|
| 1009 |
+
|
| 1010 |
+
template <>
|
| 1011 |
+
struct is_copy_constructible<recursive_bottom> : std::true_type {};
|
| 1012 |
+
|
| 1013 |
+
// Likewise for std::pair
|
| 1014 |
+
// (after C++17 it is mandatory that the copy constructor not exist when the two types aren't
|
| 1015 |
+
// themselves copy constructible, but this can not be relied upon when T1 or T2 are themselves
|
| 1016 |
+
// containers).
|
| 1017 |
+
template <typename T1, typename T2>
|
| 1018 |
+
struct is_copy_constructible<std::pair<T1, T2>>
|
| 1019 |
+
: all_of<is_copy_constructible<T1>, is_copy_constructible<T2>> {};
|
| 1020 |
+
|
| 1021 |
+
// The same problems arise with std::is_copy_assignable, so we use the same workaround.
|
| 1022 |
+
template <typename T>
|
| 1023 |
+
struct is_copy_assignable
|
| 1024 |
+
: all_of<
|
| 1025 |
+
std::is_copy_assignable<T>,
|
| 1026 |
+
is_copy_assignable<typename recursive_container_traits<T>::type_to_check_recursively>> {
|
| 1027 |
+
};
|
| 1028 |
+
|
| 1029 |
+
template <>
|
| 1030 |
+
struct is_copy_assignable<recursive_bottom> : std::true_type {};
|
| 1031 |
+
|
| 1032 |
+
template <typename T1, typename T2>
|
| 1033 |
+
struct is_copy_assignable<std::pair<T1, T2>>
|
| 1034 |
+
: all_of<is_copy_assignable<T1>, is_copy_assignable<T2>> {};
|
| 1035 |
+
|
| 1036 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 1037 |
+
|
| 1038 |
+
// polymorphic_type_hook<itype>::get(src, tinfo) determines whether the object pointed
|
| 1039 |
+
// to by `src` actually is an instance of some class derived from `itype`.
|
| 1040 |
+
// If so, it sets `tinfo` to point to the std::type_info representing that derived
|
| 1041 |
+
// type, and returns a pointer to the start of the most-derived object of that type
|
| 1042 |
+
// (in which `src` is a subobject; this will be the same address as `src` in most
|
| 1043 |
+
// single inheritance cases). If not, or if `src` is nullptr, it simply returns `src`
|
| 1044 |
+
// and leaves `tinfo` at its default value of nullptr.
|
| 1045 |
+
//
|
| 1046 |
+
// The default polymorphic_type_hook just returns src. A specialization for polymorphic
|
| 1047 |
+
// types determines the runtime type of the passed object and adjusts the this-pointer
|
| 1048 |
+
// appropriately via dynamic_cast<void*>. This is what enables a C++ Animal* to appear
|
| 1049 |
+
// to Python as a Dog (if Dog inherits from Animal, Animal is polymorphic, Dog is
|
| 1050 |
+
// registered with pybind11, and this Animal is in fact a Dog).
|
| 1051 |
+
//
|
| 1052 |
+
// You may specialize polymorphic_type_hook yourself for types that want to appear
|
| 1053 |
+
// polymorphic to Python but do not use C++ RTTI. (This is a not uncommon pattern
|
| 1054 |
+
// in performance-sensitive applications, used most notably in LLVM.)
|
| 1055 |
+
//
|
| 1056 |
+
// polymorphic_type_hook_base allows users to specialize polymorphic_type_hook with
|
| 1057 |
+
// std::enable_if. User provided specializations will always have higher priority than
|
| 1058 |
+
// the default implementation and specialization provided in polymorphic_type_hook_base.
|
| 1059 |
+
template <typename itype, typename SFINAE = void>
|
| 1060 |
+
struct polymorphic_type_hook_base {
|
| 1061 |
+
static const void *get(const itype *src, const std::type_info *&) { return src; }
|
| 1062 |
+
};
|
| 1063 |
+
template <typename itype>
|
| 1064 |
+
struct polymorphic_type_hook_base<itype, detail::enable_if_t<std::is_polymorphic<itype>::value>> {
|
| 1065 |
+
static const void *get(const itype *src, const std::type_info *&type) {
|
| 1066 |
+
type = src ? &typeid(*src) : nullptr;
|
| 1067 |
+
return dynamic_cast<const void *>(src);
|
| 1068 |
+
}
|
| 1069 |
+
};
|
| 1070 |
+
template <typename itype, typename SFINAE = void>
|
| 1071 |
+
struct polymorphic_type_hook : public polymorphic_type_hook_base<itype> {};
|
| 1072 |
+
|
| 1073 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 1074 |
+
|
| 1075 |
+
/// Generic type caster for objects stored on the heap
|
| 1076 |
+
template <typename type>
|
| 1077 |
+
class type_caster_base : public type_caster_generic {
|
| 1078 |
+
using itype = intrinsic_t<type>;
|
| 1079 |
+
|
| 1080 |
+
public:
|
| 1081 |
+
static constexpr auto name = const_name<type>();
|
| 1082 |
+
|
| 1083 |
+
type_caster_base() : type_caster_base(typeid(type)) {}
|
| 1084 |
+
explicit type_caster_base(const std::type_info &info) : type_caster_generic(info) {}
|
| 1085 |
+
|
| 1086 |
+
static handle cast(const itype &src, return_value_policy policy, handle parent) {
|
| 1087 |
+
if (policy == return_value_policy::automatic
|
| 1088 |
+
|| policy == return_value_policy::automatic_reference) {
|
| 1089 |
+
policy = return_value_policy::copy;
|
| 1090 |
+
}
|
| 1091 |
+
return cast(std::addressof(src), policy, parent);
|
| 1092 |
+
}
|
| 1093 |
+
|
| 1094 |
+
static handle cast(itype &&src, return_value_policy, handle parent) {
|
| 1095 |
+
return cast(std::addressof(src), return_value_policy::move, parent);
|
| 1096 |
+
}
|
| 1097 |
+
|
| 1098 |
+
// Returns a (pointer, type_info) pair taking care of necessary type lookup for a
|
| 1099 |
+
// polymorphic type (using RTTI by default, but can be overridden by specializing
|
| 1100 |
+
// polymorphic_type_hook). If the instance isn't derived, returns the base version.
|
| 1101 |
+
static std::pair<const void *, const type_info *> src_and_type(const itype *src) {
|
| 1102 |
+
const auto &cast_type = typeid(itype);
|
| 1103 |
+
const std::type_info *instance_type = nullptr;
|
| 1104 |
+
const void *vsrc = polymorphic_type_hook<itype>::get(src, instance_type);
|
| 1105 |
+
if (instance_type && !same_type(cast_type, *instance_type)) {
|
| 1106 |
+
// This is a base pointer to a derived type. If the derived type is registered
|
| 1107 |
+
// with pybind11, we want to make the full derived object available.
|
| 1108 |
+
// In the typical case where itype is polymorphic, we get the correct
|
| 1109 |
+
// derived pointer (which may be != base pointer) by a dynamic_cast to
|
| 1110 |
+
// most derived type. If itype is not polymorphic, we won't get here
|
| 1111 |
+
// except via a user-provided specialization of polymorphic_type_hook,
|
| 1112 |
+
// and the user has promised that no this-pointer adjustment is
|
| 1113 |
+
// required in that case, so it's OK to use static_cast.
|
| 1114 |
+
if (const auto *tpi = get_type_info(*instance_type)) {
|
| 1115 |
+
return {vsrc, tpi};
|
| 1116 |
+
}
|
| 1117 |
+
}
|
| 1118 |
+
// Otherwise we have either a nullptr, an `itype` pointer, or an unknown derived pointer,
|
| 1119 |
+
// so don't do a cast
|
| 1120 |
+
return type_caster_generic::src_and_type(src, cast_type, instance_type);
|
| 1121 |
+
}
|
| 1122 |
+
|
| 1123 |
+
static handle cast(const itype *src, return_value_policy policy, handle parent) {
|
| 1124 |
+
auto st = src_and_type(src);
|
| 1125 |
+
return type_caster_generic::cast(st.first,
|
| 1126 |
+
policy,
|
| 1127 |
+
parent,
|
| 1128 |
+
st.second,
|
| 1129 |
+
make_copy_constructor(src),
|
| 1130 |
+
make_move_constructor(src));
|
| 1131 |
+
}
|
| 1132 |
+
|
| 1133 |
+
static handle cast_holder(const itype *src, const void *holder) {
|
| 1134 |
+
auto st = src_and_type(src);
|
| 1135 |
+
return type_caster_generic::cast(st.first,
|
| 1136 |
+
return_value_policy::take_ownership,
|
| 1137 |
+
{},
|
| 1138 |
+
st.second,
|
| 1139 |
+
nullptr,
|
| 1140 |
+
nullptr,
|
| 1141 |
+
holder);
|
| 1142 |
+
}
|
| 1143 |
+
|
| 1144 |
+
template <typename T>
|
| 1145 |
+
using cast_op_type = detail::cast_op_type<T>;
|
| 1146 |
+
|
| 1147 |
+
// NOLINTNEXTLINE(google-explicit-constructor)
|
| 1148 |
+
operator itype *() { return (type *) value; }
|
| 1149 |
+
// NOLINTNEXTLINE(google-explicit-constructor)
|
| 1150 |
+
operator itype &() {
|
| 1151 |
+
if (!value) {
|
| 1152 |
+
throw reference_cast_error();
|
| 1153 |
+
}
|
| 1154 |
+
return *((itype *) value);
|
| 1155 |
+
}
|
| 1156 |
+
|
| 1157 |
+
protected:
|
| 1158 |
+
using Constructor = void *(*) (const void *);
|
| 1159 |
+
|
| 1160 |
+
/* Only enabled when the types are {copy,move}-constructible *and* when the type
|
| 1161 |
+
does not have a private operator new implementation. A comma operator is used in the
|
| 1162 |
+
decltype argument to apply SFINAE to the public copy/move constructors.*/
|
| 1163 |
+
template <typename T, typename = enable_if_t<is_copy_constructible<T>::value>>
|
| 1164 |
+
static auto make_copy_constructor(const T *) -> decltype(new T(std::declval<const T>()),
|
| 1165 |
+
Constructor{}) {
|
| 1166 |
+
return [](const void *arg) -> void * { return new T(*reinterpret_cast<const T *>(arg)); };
|
| 1167 |
+
}
|
| 1168 |
+
|
| 1169 |
+
template <typename T, typename = enable_if_t<is_move_constructible<T>::value>>
|
| 1170 |
+
static auto make_move_constructor(const T *) -> decltype(new T(std::declval<T &&>()),
|
| 1171 |
+
Constructor{}) {
|
| 1172 |
+
return [](const void *arg) -> void * {
|
| 1173 |
+
return new T(std::move(*const_cast<T *>(reinterpret_cast<const T *>(arg))));
|
| 1174 |
+
};
|
| 1175 |
+
}
|
| 1176 |
+
|
| 1177 |
+
static Constructor make_copy_constructor(...) { return nullptr; }
|
| 1178 |
+
static Constructor make_move_constructor(...) { return nullptr; }
|
| 1179 |
+
};
|
| 1180 |
+
|
| 1181 |
+
inline std::string quote_cpp_type_name(const std::string &cpp_type_name) {
|
| 1182 |
+
return cpp_type_name; // No-op for now. See PR #4888
|
| 1183 |
+
}
|
| 1184 |
+
|
| 1185 |
+
PYBIND11_NOINLINE std::string type_info_description(const std::type_info &ti) {
|
| 1186 |
+
if (auto *type_data = get_type_info(ti)) {
|
| 1187 |
+
handle th((PyObject *) type_data->type);
|
| 1188 |
+
return th.attr("__module__").cast<std::string>() + '.'
|
| 1189 |
+
+ th.attr("__qualname__").cast<std::string>();
|
| 1190 |
+
}
|
| 1191 |
+
return quote_cpp_type_name(clean_type_id(ti.name()));
|
| 1192 |
+
}
|
| 1193 |
+
|
| 1194 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 1195 |
+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
|
phivenv/Lib/site-packages/torch/include/pybind11/detail/typeid.h
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
pybind11/detail/typeid.h: Compiler-independent access to type identifiers
|
| 3 |
+
|
| 4 |
+
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
| 5 |
+
|
| 6 |
+
All rights reserved. Use of this source code is governed by a
|
| 7 |
+
BSD-style license that can be found in the LICENSE file.
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include <cstdio>
|
| 13 |
+
#include <cstdlib>
|
| 14 |
+
|
| 15 |
+
#if defined(__GNUG__)
|
| 16 |
+
# include <cxxabi.h>
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
#include "common.h"
|
| 20 |
+
|
| 21 |
+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
| 22 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 23 |
+
|
| 24 |
+
/// Erase all occurrences of a substring
|
| 25 |
+
inline void erase_all(std::string &string, const std::string &search) {
|
| 26 |
+
for (size_t pos = 0;;) {
|
| 27 |
+
pos = string.find(search, pos);
|
| 28 |
+
if (pos == std::string::npos) {
|
| 29 |
+
break;
|
| 30 |
+
}
|
| 31 |
+
string.erase(pos, search.length());
|
| 32 |
+
}
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
PYBIND11_NOINLINE void clean_type_id(std::string &name) {
|
| 36 |
+
#if defined(__GNUG__)
|
| 37 |
+
int status = 0;
|
| 38 |
+
std::unique_ptr<char, void (*)(void *)> res{
|
| 39 |
+
abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free};
|
| 40 |
+
if (status == 0) {
|
| 41 |
+
name = res.get();
|
| 42 |
+
}
|
| 43 |
+
#else
|
| 44 |
+
detail::erase_all(name, "class ");
|
| 45 |
+
detail::erase_all(name, "struct ");
|
| 46 |
+
detail::erase_all(name, "enum ");
|
| 47 |
+
#endif
|
| 48 |
+
detail::erase_all(name, "pybind11::");
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
inline std::string clean_type_id(const char *typeid_name) {
|
| 52 |
+
std::string name(typeid_name);
|
| 53 |
+
detail::clean_type_id(name);
|
| 54 |
+
return name;
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 58 |
+
|
| 59 |
+
/// Return a string representation of a C++ type
|
| 60 |
+
template <typename T>
|
| 61 |
+
static std::string type_id() {
|
| 62 |
+
return detail::clean_type_id(typeid(T).name());
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
|
phivenv/Lib/site-packages/torch/include/pybind11/detail/value_and_holder.h
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) 2016-2024 The Pybind Development Team.
|
| 2 |
+
// All rights reserved. Use of this source code is governed by a
|
| 3 |
+
// BSD-style license that can be found in the LICENSE file.
|
| 4 |
+
|
| 5 |
+
#pragma once
|
| 6 |
+
|
| 7 |
+
#include "common.h"
|
| 8 |
+
|
| 9 |
+
#include <cstddef>
|
| 10 |
+
#include <typeinfo>
|
| 11 |
+
|
| 12 |
+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
| 13 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 14 |
+
|
| 15 |
+
struct value_and_holder {
|
| 16 |
+
instance *inst = nullptr;
|
| 17 |
+
size_t index = 0u;
|
| 18 |
+
const detail::type_info *type = nullptr;
|
| 19 |
+
void **vh = nullptr;
|
| 20 |
+
|
| 21 |
+
// Main constructor for a found value/holder:
|
| 22 |
+
value_and_holder(instance *i, const detail::type_info *type, size_t vpos, size_t index)
|
| 23 |
+
: inst{i}, index{index}, type{type},
|
| 24 |
+
vh{inst->simple_layout ? inst->simple_value_holder
|
| 25 |
+
: &inst->nonsimple.values_and_holders[vpos]} {}
|
| 26 |
+
|
| 27 |
+
// Default constructor (used to signal a value-and-holder not found by get_value_and_holder())
|
| 28 |
+
value_and_holder() = default;
|
| 29 |
+
|
| 30 |
+
// Used for past-the-end iterator
|
| 31 |
+
explicit value_and_holder(size_t index) : index{index} {}
|
| 32 |
+
|
| 33 |
+
template <typename V = void>
|
| 34 |
+
V *&value_ptr() const {
|
| 35 |
+
return reinterpret_cast<V *&>(vh[0]);
|
| 36 |
+
}
|
| 37 |
+
// True if this `value_and_holder` has a non-null value pointer
|
| 38 |
+
explicit operator bool() const { return value_ptr() != nullptr; }
|
| 39 |
+
|
| 40 |
+
template <typename H>
|
| 41 |
+
H &holder() const {
|
| 42 |
+
return reinterpret_cast<H &>(vh[1]);
|
| 43 |
+
}
|
| 44 |
+
bool holder_constructed() const {
|
| 45 |
+
return inst->simple_layout
|
| 46 |
+
? inst->simple_holder_constructed
|
| 47 |
+
: (inst->nonsimple.status[index] & instance::status_holder_constructed) != 0u;
|
| 48 |
+
}
|
| 49 |
+
// NOLINTNEXTLINE(readability-make-member-function-const)
|
| 50 |
+
void set_holder_constructed(bool v = true) {
|
| 51 |
+
if (inst->simple_layout) {
|
| 52 |
+
inst->simple_holder_constructed = v;
|
| 53 |
+
} else if (v) {
|
| 54 |
+
inst->nonsimple.status[index] |= instance::status_holder_constructed;
|
| 55 |
+
} else {
|
| 56 |
+
inst->nonsimple.status[index] &= (std::uint8_t) ~instance::status_holder_constructed;
|
| 57 |
+
}
|
| 58 |
+
}
|
| 59 |
+
bool instance_registered() const {
|
| 60 |
+
return inst->simple_layout
|
| 61 |
+
? inst->simple_instance_registered
|
| 62 |
+
: ((inst->nonsimple.status[index] & instance::status_instance_registered) != 0);
|
| 63 |
+
}
|
| 64 |
+
// NOLINTNEXTLINE(readability-make-member-function-const)
|
| 65 |
+
void set_instance_registered(bool v = true) {
|
| 66 |
+
if (inst->simple_layout) {
|
| 67 |
+
inst->simple_instance_registered = v;
|
| 68 |
+
} else if (v) {
|
| 69 |
+
inst->nonsimple.status[index] |= instance::status_instance_registered;
|
| 70 |
+
} else {
|
| 71 |
+
inst->nonsimple.status[index] &= (std::uint8_t) ~instance::status_instance_registered;
|
| 72 |
+
}
|
| 73 |
+
}
|
| 74 |
+
};
|
| 75 |
+
|
| 76 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 77 |
+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
|
phivenv/Lib/site-packages/torch/include/pybind11/eigen.h
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
pybind11/eigen.h: Transparent conversion for dense and sparse Eigen matrices
|
| 3 |
+
|
| 4 |
+
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
| 5 |
+
|
| 6 |
+
All rights reserved. Use of this source code is governed by a
|
| 7 |
+
BSD-style license that can be found in the LICENSE file.
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include "eigen/matrix.h"
|
phivenv/Lib/site-packages/torch/include/pybind11/eigen/common.h
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) 2023 The pybind Community.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
// Common message for `static_assert()`s, which are useful to easily
|
| 6 |
+
// preempt much less obvious errors.
|
| 7 |
+
#define PYBIND11_EIGEN_MESSAGE_POINTER_TYPES_ARE_NOT_SUPPORTED \
|
| 8 |
+
"Pointer types (in particular `PyObject *`) are not supported as scalar types for Eigen " \
|
| 9 |
+
"types."
|
phivenv/Lib/site-packages/torch/include/pybind11/eigen/matrix.h
ADDED
|
@@ -0,0 +1,715 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
pybind11/eigen/matrix.h: Transparent conversion for dense and sparse Eigen matrices
|
| 3 |
+
|
| 4 |
+
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
| 5 |
+
|
| 6 |
+
All rights reserved. Use of this source code is governed by a
|
| 7 |
+
BSD-style license that can be found in the LICENSE file.
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include <pybind11/numpy.h>
|
| 13 |
+
|
| 14 |
+
#include "common.h"
|
| 15 |
+
|
| 16 |
+
/* HINT: To suppress warnings originating from the Eigen headers, use -isystem.
|
| 17 |
+
See also:
|
| 18 |
+
https://stackoverflow.com/questions/2579576/i-dir-vs-isystem-dir
|
| 19 |
+
https://stackoverflow.com/questions/1741816/isystem-for-ms-visual-studio-c-compiler
|
| 20 |
+
*/
|
| 21 |
+
PYBIND11_WARNING_PUSH
|
| 22 |
+
PYBIND11_WARNING_DISABLE_MSVC(5054) // https://github.com/pybind/pybind11/pull/3741
|
| 23 |
+
// C5054: operator '&': deprecated between enumerations of different types
|
| 24 |
+
#if defined(__MINGW32__)
|
| 25 |
+
PYBIND11_WARNING_DISABLE_GCC("-Wmaybe-uninitialized")
|
| 26 |
+
#endif
|
| 27 |
+
|
| 28 |
+
#include <Eigen/Core>
|
| 29 |
+
#include <Eigen/SparseCore>
|
| 30 |
+
|
| 31 |
+
PYBIND11_WARNING_POP
|
| 32 |
+
|
| 33 |
+
// Eigen prior to 3.2.7 doesn't have proper move constructors--but worse, some classes get implicit
|
| 34 |
+
// move constructors that break things. We could detect this an explicitly copy, but an extra copy
|
| 35 |
+
// of matrices seems highly undesirable.
|
| 36 |
+
static_assert(EIGEN_VERSION_AT_LEAST(3, 2, 7),
|
| 37 |
+
"Eigen matrix support in pybind11 requires Eigen >= 3.2.7");
|
| 38 |
+
|
| 39 |
+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
| 40 |
+
|
| 41 |
+
PYBIND11_WARNING_DISABLE_MSVC(4127)
|
| 42 |
+
|
| 43 |
+
// Provide a convenience alias for easier pass-by-ref usage with fully dynamic strides:
|
| 44 |
+
using EigenDStride = Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>;
|
| 45 |
+
template <typename MatrixType>
|
| 46 |
+
using EigenDRef = Eigen::Ref<MatrixType, 0, EigenDStride>;
|
| 47 |
+
template <typename MatrixType>
|
| 48 |
+
using EigenDMap = Eigen::Map<MatrixType, 0, EigenDStride>;
|
| 49 |
+
|
| 50 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 51 |
+
|
| 52 |
+
#if EIGEN_VERSION_AT_LEAST(3, 3, 0)
|
| 53 |
+
using EigenIndex = Eigen::Index;
|
| 54 |
+
template <typename Scalar, int Flags, typename StorageIndex>
|
| 55 |
+
using EigenMapSparseMatrix = Eigen::Map<Eigen::SparseMatrix<Scalar, Flags, StorageIndex>>;
|
| 56 |
+
#else
|
| 57 |
+
using EigenIndex = EIGEN_DEFAULT_DENSE_INDEX_TYPE;
|
| 58 |
+
template <typename Scalar, int Flags, typename StorageIndex>
|
| 59 |
+
using EigenMapSparseMatrix = Eigen::MappedSparseMatrix<Scalar, Flags, StorageIndex>;
|
| 60 |
+
#endif
|
| 61 |
+
|
| 62 |
+
// Matches Eigen::Map, Eigen::Ref, blocks, etc:
|
| 63 |
+
template <typename T>
|
| 64 |
+
using is_eigen_dense_map = all_of<is_template_base_of<Eigen::DenseBase, T>,
|
| 65 |
+
std::is_base_of<Eigen::MapBase<T, Eigen::ReadOnlyAccessors>, T>>;
|
| 66 |
+
template <typename T>
|
| 67 |
+
using is_eigen_mutable_map = std::is_base_of<Eigen::MapBase<T, Eigen::WriteAccessors>, T>;
|
| 68 |
+
template <typename T>
|
| 69 |
+
using is_eigen_dense_plain
|
| 70 |
+
= all_of<negation<is_eigen_dense_map<T>>, is_template_base_of<Eigen::PlainObjectBase, T>>;
|
| 71 |
+
template <typename T>
|
| 72 |
+
using is_eigen_sparse = is_template_base_of<Eigen::SparseMatrixBase, T>;
|
| 73 |
+
// Test for objects inheriting from EigenBase<Derived> that aren't captured by the above. This
|
| 74 |
+
// basically covers anything that can be assigned to a dense matrix but that don't have a typical
|
| 75 |
+
// matrix data layout that can be copied from their .data(). For example, DiagonalMatrix and
|
| 76 |
+
// SelfAdjointView fall into this category.
|
| 77 |
+
template <typename T>
|
| 78 |
+
using is_eigen_other
|
| 79 |
+
= all_of<is_template_base_of<Eigen::EigenBase, T>,
|
| 80 |
+
negation<any_of<is_eigen_dense_map<T>, is_eigen_dense_plain<T>, is_eigen_sparse<T>>>>;
|
| 81 |
+
|
| 82 |
+
// Captures numpy/eigen conformability status (returned by EigenProps::conformable()):
|
| 83 |
+
template <bool EigenRowMajor>
|
| 84 |
+
struct EigenConformable {
|
| 85 |
+
bool conformable = false;
|
| 86 |
+
EigenIndex rows = 0, cols = 0;
|
| 87 |
+
EigenDStride stride{0, 0}; // Only valid if negativestrides is false!
|
| 88 |
+
bool negativestrides = false; // If true, do not use stride!
|
| 89 |
+
|
| 90 |
+
// NOLINTNEXTLINE(google-explicit-constructor)
|
| 91 |
+
EigenConformable(bool fits = false) : conformable{fits} {}
|
| 92 |
+
// Matrix type:
|
| 93 |
+
EigenConformable(EigenIndex r, EigenIndex c, EigenIndex rstride, EigenIndex cstride)
|
| 94 |
+
: conformable{true}, rows{r}, cols{c},
|
| 95 |
+
// TODO: when Eigen bug #747 is fixed, remove the tests for non-negativity.
|
| 96 |
+
// http://eigen.tuxfamily.org/bz/show_bug.cgi?id=747
|
| 97 |
+
stride{EigenRowMajor ? (rstride > 0 ? rstride : 0)
|
| 98 |
+
: (cstride > 0 ? cstride : 0) /* outer stride */,
|
| 99 |
+
EigenRowMajor ? (cstride > 0 ? cstride : 0)
|
| 100 |
+
: (rstride > 0 ? rstride : 0) /* inner stride */},
|
| 101 |
+
negativestrides{rstride < 0 || cstride < 0} {}
|
| 102 |
+
// Vector type:
|
| 103 |
+
EigenConformable(EigenIndex r, EigenIndex c, EigenIndex stride)
|
| 104 |
+
: EigenConformable(r, c, r == 1 ? c * stride : stride, c == 1 ? r : r * stride) {}
|
| 105 |
+
|
| 106 |
+
template <typename props>
|
| 107 |
+
bool stride_compatible() const {
|
| 108 |
+
// To have compatible strides, we need (on both dimensions) one of fully dynamic strides,
|
| 109 |
+
// matching strides, or a dimension size of 1 (in which case the stride value is
|
| 110 |
+
// irrelevant). Alternatively, if any dimension size is 0, the strides are not relevant
|
| 111 |
+
// (and numpy ≥ 1.23 sets the strides to 0 in that case, so we need to check explicitly).
|
| 112 |
+
if (negativestrides) {
|
| 113 |
+
return false;
|
| 114 |
+
}
|
| 115 |
+
if (rows == 0 || cols == 0) {
|
| 116 |
+
return true;
|
| 117 |
+
}
|
| 118 |
+
return (props::inner_stride == Eigen::Dynamic || props::inner_stride == stride.inner()
|
| 119 |
+
|| (EigenRowMajor ? cols : rows) == 1)
|
| 120 |
+
&& (props::outer_stride == Eigen::Dynamic || props::outer_stride == stride.outer()
|
| 121 |
+
|| (EigenRowMajor ? rows : cols) == 1);
|
| 122 |
+
}
|
| 123 |
+
// NOLINTNEXTLINE(google-explicit-constructor)
|
| 124 |
+
operator bool() const { return conformable; }
|
| 125 |
+
};
|
| 126 |
+
|
| 127 |
+
template <typename Type>
|
| 128 |
+
struct eigen_extract_stride {
|
| 129 |
+
using type = Type;
|
| 130 |
+
};
|
| 131 |
+
template <typename PlainObjectType, int MapOptions, typename StrideType>
|
| 132 |
+
struct eigen_extract_stride<Eigen::Map<PlainObjectType, MapOptions, StrideType>> {
|
| 133 |
+
using type = StrideType;
|
| 134 |
+
};
|
| 135 |
+
template <typename PlainObjectType, int Options, typename StrideType>
|
| 136 |
+
struct eigen_extract_stride<Eigen::Ref<PlainObjectType, Options, StrideType>> {
|
| 137 |
+
using type = StrideType;
|
| 138 |
+
};
|
| 139 |
+
|
| 140 |
+
// Helper struct for extracting information from an Eigen type
|
| 141 |
+
template <typename Type_>
|
| 142 |
+
struct EigenProps {
|
| 143 |
+
using Type = Type_;
|
| 144 |
+
using Scalar = typename Type::Scalar;
|
| 145 |
+
using StrideType = typename eigen_extract_stride<Type>::type;
|
| 146 |
+
static constexpr EigenIndex rows = Type::RowsAtCompileTime, cols = Type::ColsAtCompileTime,
|
| 147 |
+
size = Type::SizeAtCompileTime;
|
| 148 |
+
static constexpr bool row_major = Type::IsRowMajor,
|
| 149 |
+
vector
|
| 150 |
+
= Type::IsVectorAtCompileTime, // At least one dimension has fixed size 1
|
| 151 |
+
fixed_rows = rows != Eigen::Dynamic, fixed_cols = cols != Eigen::Dynamic,
|
| 152 |
+
fixed = size != Eigen::Dynamic, // Fully-fixed size
|
| 153 |
+
dynamic = !fixed_rows && !fixed_cols; // Fully-dynamic size
|
| 154 |
+
|
| 155 |
+
template <EigenIndex i, EigenIndex ifzero>
|
| 156 |
+
using if_zero = std::integral_constant<EigenIndex, i == 0 ? ifzero : i>;
|
| 157 |
+
static constexpr EigenIndex inner_stride
|
| 158 |
+
= if_zero<StrideType::InnerStrideAtCompileTime, 1>::value,
|
| 159 |
+
outer_stride = if_zero < StrideType::OuterStrideAtCompileTime,
|
| 160 |
+
vector ? size
|
| 161 |
+
: row_major ? cols
|
| 162 |
+
: rows > ::value;
|
| 163 |
+
static constexpr bool dynamic_stride
|
| 164 |
+
= inner_stride == Eigen::Dynamic && outer_stride == Eigen::Dynamic;
|
| 165 |
+
static constexpr bool requires_row_major
|
| 166 |
+
= !dynamic_stride && !vector && (row_major ? inner_stride : outer_stride) == 1;
|
| 167 |
+
static constexpr bool requires_col_major
|
| 168 |
+
= !dynamic_stride && !vector && (row_major ? outer_stride : inner_stride) == 1;
|
| 169 |
+
|
| 170 |
+
// Takes an input array and determines whether we can make it fit into the Eigen type. If
|
| 171 |
+
// the array is a vector, we attempt to fit it into either an Eigen 1xN or Nx1 vector
|
| 172 |
+
// (preferring the latter if it will fit in either, i.e. for a fully dynamic matrix type).
|
| 173 |
+
static EigenConformable<row_major> conformable(const array &a) {
|
| 174 |
+
const auto dims = a.ndim();
|
| 175 |
+
if (dims < 1 || dims > 2) {
|
| 176 |
+
return false;
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
if (dims == 2) { // Matrix type: require exact match (or dynamic)
|
| 180 |
+
|
| 181 |
+
EigenIndex np_rows = a.shape(0), np_cols = a.shape(1),
|
| 182 |
+
np_rstride = a.strides(0) / static_cast<ssize_t>(sizeof(Scalar)),
|
| 183 |
+
np_cstride = a.strides(1) / static_cast<ssize_t>(sizeof(Scalar));
|
| 184 |
+
if ((fixed_rows && np_rows != rows) || (fixed_cols && np_cols != cols)) {
|
| 185 |
+
return false;
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
return {np_rows, np_cols, np_rstride, np_cstride};
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
// Otherwise we're storing an n-vector. Only one of the strides will be used, but
|
| 192 |
+
// whichever is used, we want the (single) numpy stride value.
|
| 193 |
+
const EigenIndex n = a.shape(0),
|
| 194 |
+
stride = a.strides(0) / static_cast<ssize_t>(sizeof(Scalar));
|
| 195 |
+
|
| 196 |
+
if (vector) { // Eigen type is a compile-time vector
|
| 197 |
+
if (fixed && size != n) {
|
| 198 |
+
return false; // Vector size mismatch
|
| 199 |
+
}
|
| 200 |
+
return {rows == 1 ? 1 : n, cols == 1 ? 1 : n, stride};
|
| 201 |
+
}
|
| 202 |
+
if (fixed) {
|
| 203 |
+
// The type has a fixed size, but is not a vector: abort
|
| 204 |
+
return false;
|
| 205 |
+
}
|
| 206 |
+
if (fixed_cols) {
|
| 207 |
+
// Since this isn't a vector, cols must be != 1. We allow this only if it exactly
|
| 208 |
+
// equals the number of elements (rows is Dynamic, and so 1 row is allowed).
|
| 209 |
+
if (cols != n) {
|
| 210 |
+
return false;
|
| 211 |
+
}
|
| 212 |
+
return {1, n, stride};
|
| 213 |
+
} // Otherwise it's either fully dynamic, or column dynamic; both become a column vector
|
| 214 |
+
if (fixed_rows && rows != n) {
|
| 215 |
+
return false;
|
| 216 |
+
}
|
| 217 |
+
return {n, 1, stride};
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
static constexpr bool show_writeable
|
| 221 |
+
= is_eigen_dense_map<Type>::value && is_eigen_mutable_map<Type>::value;
|
| 222 |
+
static constexpr bool show_order = is_eigen_dense_map<Type>::value;
|
| 223 |
+
static constexpr bool show_c_contiguous = show_order && requires_row_major;
|
| 224 |
+
static constexpr bool show_f_contiguous
|
| 225 |
+
= !show_c_contiguous && show_order && requires_col_major;
|
| 226 |
+
|
| 227 |
+
static constexpr auto descriptor
|
| 228 |
+
= const_name("numpy.ndarray[") + npy_format_descriptor<Scalar>::name + const_name("[")
|
| 229 |
+
+ const_name<fixed_rows>(const_name<(size_t) rows>(), const_name("m")) + const_name(", ")
|
| 230 |
+
+ const_name<fixed_cols>(const_name<(size_t) cols>(), const_name("n")) + const_name("]")
|
| 231 |
+
+
|
| 232 |
+
// For a reference type (e.g. Ref<MatrixXd>) we have other constraints that might need to
|
| 233 |
+
// be satisfied: writeable=True (for a mutable reference), and, depending on the map's
|
| 234 |
+
// stride options, possibly f_contiguous or c_contiguous. We include them in the
|
| 235 |
+
// descriptor output to provide some hint as to why a TypeError is occurring (otherwise
|
| 236 |
+
// it can be confusing to see that a function accepts a 'numpy.ndarray[float64[3,2]]' and
|
| 237 |
+
// an error message that you *gave* a numpy.ndarray of the right type and dimensions.
|
| 238 |
+
const_name<show_writeable>(", flags.writeable", "")
|
| 239 |
+
+ const_name<show_c_contiguous>(", flags.c_contiguous", "")
|
| 240 |
+
+ const_name<show_f_contiguous>(", flags.f_contiguous", "") + const_name("]");
|
| 241 |
+
};
|
| 242 |
+
|
| 243 |
+
// Casts an Eigen type to numpy array. If given a base, the numpy array references the src data,
|
| 244 |
+
// otherwise it'll make a copy. writeable lets you turn off the writeable flag for the array.
|
| 245 |
+
template <typename props>
|
| 246 |
+
handle
|
| 247 |
+
eigen_array_cast(typename props::Type const &src, handle base = handle(), bool writeable = true) {
|
| 248 |
+
constexpr ssize_t elem_size = sizeof(typename props::Scalar);
|
| 249 |
+
array a;
|
| 250 |
+
if (props::vector) {
|
| 251 |
+
a = array({src.size()}, {elem_size * src.innerStride()}, src.data(), base);
|
| 252 |
+
} else {
|
| 253 |
+
a = array({src.rows(), src.cols()},
|
| 254 |
+
{elem_size * src.rowStride(), elem_size * src.colStride()},
|
| 255 |
+
src.data(),
|
| 256 |
+
base);
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
if (!writeable) {
|
| 260 |
+
array_proxy(a.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
return a.release();
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
// Takes an lvalue ref to some Eigen type and a (python) base object, creating a numpy array that
|
| 267 |
+
// reference the Eigen object's data with `base` as the python-registered base class (if omitted,
|
| 268 |
+
// the base will be set to None, and lifetime management is up to the caller). The numpy array is
|
| 269 |
+
// non-writeable if the given type is const.
|
| 270 |
+
template <typename props, typename Type>
|
| 271 |
+
handle eigen_ref_array(Type &src, handle parent = none()) {
|
| 272 |
+
// none here is to get past array's should-we-copy detection, which currently always
|
| 273 |
+
// copies when there is no base. Setting the base to None should be harmless.
|
| 274 |
+
return eigen_array_cast<props>(src, parent, !std::is_const<Type>::value);
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
// Takes a pointer to some dense, plain Eigen type, builds a capsule around it, then returns a
|
| 278 |
+
// numpy array that references the encapsulated data with a python-side reference to the capsule to
|
| 279 |
+
// tie its destruction to that of any dependent python objects. Const-ness is determined by
|
| 280 |
+
// whether or not the Type of the pointer given is const.
|
| 281 |
+
template <typename props, typename Type, typename = enable_if_t<is_eigen_dense_plain<Type>::value>>
|
| 282 |
+
handle eigen_encapsulate(Type *src) {
|
| 283 |
+
capsule base(src, [](void *o) { delete static_cast<Type *>(o); });
|
| 284 |
+
return eigen_ref_array<props>(*src, base);
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
// Type caster for regular, dense matrix types (e.g. MatrixXd), but not maps/refs/etc. of dense
|
| 288 |
+
// types.
|
| 289 |
+
template <typename Type>
|
| 290 |
+
struct type_caster<Type, enable_if_t<is_eigen_dense_plain<Type>::value>> {
|
| 291 |
+
using Scalar = typename Type::Scalar;
|
| 292 |
+
static_assert(!std::is_pointer<Scalar>::value,
|
| 293 |
+
PYBIND11_EIGEN_MESSAGE_POINTER_TYPES_ARE_NOT_SUPPORTED);
|
| 294 |
+
using props = EigenProps<Type>;
|
| 295 |
+
|
| 296 |
+
bool load(handle src, bool convert) {
|
| 297 |
+
// If we're in no-convert mode, only load if given an array of the correct type
|
| 298 |
+
if (!convert && !isinstance<array_t<Scalar>>(src)) {
|
| 299 |
+
return false;
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
// Coerce into an array, but don't do type conversion yet; the copy below handles it.
|
| 303 |
+
auto buf = array::ensure(src);
|
| 304 |
+
|
| 305 |
+
if (!buf) {
|
| 306 |
+
return false;
|
| 307 |
+
}
|
| 308 |
+
|
| 309 |
+
auto dims = buf.ndim();
|
| 310 |
+
if (dims < 1 || dims > 2) {
|
| 311 |
+
return false;
|
| 312 |
+
}
|
| 313 |
+
|
| 314 |
+
auto fits = props::conformable(buf);
|
| 315 |
+
if (!fits) {
|
| 316 |
+
return false;
|
| 317 |
+
}
|
| 318 |
+
|
| 319 |
+
// Allocate the new type, then build a numpy reference into it
|
| 320 |
+
value = Type(fits.rows, fits.cols);
|
| 321 |
+
auto ref = reinterpret_steal<array>(eigen_ref_array<props>(value));
|
| 322 |
+
if (dims == 1) {
|
| 323 |
+
ref = ref.squeeze();
|
| 324 |
+
} else if (ref.ndim() == 1) {
|
| 325 |
+
buf = buf.squeeze();
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
int result = detail::npy_api::get().PyArray_CopyInto_(ref.ptr(), buf.ptr());
|
| 329 |
+
|
| 330 |
+
if (result < 0) { // Copy failed!
|
| 331 |
+
PyErr_Clear();
|
| 332 |
+
return false;
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
return true;
|
| 336 |
+
}
|
| 337 |
+
|
| 338 |
+
private:
|
| 339 |
+
// Cast implementation
|
| 340 |
+
template <typename CType>
|
| 341 |
+
static handle cast_impl(CType *src, return_value_policy policy, handle parent) {
|
| 342 |
+
switch (policy) {
|
| 343 |
+
case return_value_policy::take_ownership:
|
| 344 |
+
case return_value_policy::automatic:
|
| 345 |
+
return eigen_encapsulate<props>(src);
|
| 346 |
+
case return_value_policy::move:
|
| 347 |
+
return eigen_encapsulate<props>(new CType(std::move(*src)));
|
| 348 |
+
case return_value_policy::copy:
|
| 349 |
+
return eigen_array_cast<props>(*src);
|
| 350 |
+
case return_value_policy::reference:
|
| 351 |
+
case return_value_policy::automatic_reference:
|
| 352 |
+
return eigen_ref_array<props>(*src);
|
| 353 |
+
case return_value_policy::reference_internal:
|
| 354 |
+
return eigen_ref_array<props>(*src, parent);
|
| 355 |
+
default:
|
| 356 |
+
throw cast_error("unhandled return_value_policy: should not happen!");
|
| 357 |
+
};
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
public:
|
| 361 |
+
// Normal returned non-reference, non-const value:
|
| 362 |
+
static handle cast(Type &&src, return_value_policy /* policy */, handle parent) {
|
| 363 |
+
return cast_impl(&src, return_value_policy::move, parent);
|
| 364 |
+
}
|
| 365 |
+
// If you return a non-reference const, we mark the numpy array readonly:
|
| 366 |
+
static handle cast(const Type &&src, return_value_policy /* policy */, handle parent) {
|
| 367 |
+
return cast_impl(&src, return_value_policy::move, parent);
|
| 368 |
+
}
|
| 369 |
+
// lvalue reference return; default (automatic) becomes copy
|
| 370 |
+
static handle cast(Type &src, return_value_policy policy, handle parent) {
|
| 371 |
+
if (policy == return_value_policy::automatic
|
| 372 |
+
|| policy == return_value_policy::automatic_reference) {
|
| 373 |
+
policy = return_value_policy::copy;
|
| 374 |
+
}
|
| 375 |
+
return cast_impl(&src, policy, parent);
|
| 376 |
+
}
|
| 377 |
+
// const lvalue reference return; default (automatic) becomes copy
|
| 378 |
+
static handle cast(const Type &src, return_value_policy policy, handle parent) {
|
| 379 |
+
if (policy == return_value_policy::automatic
|
| 380 |
+
|| policy == return_value_policy::automatic_reference) {
|
| 381 |
+
policy = return_value_policy::copy;
|
| 382 |
+
}
|
| 383 |
+
return cast(&src, policy, parent);
|
| 384 |
+
}
|
| 385 |
+
// non-const pointer return
|
| 386 |
+
static handle cast(Type *src, return_value_policy policy, handle parent) {
|
| 387 |
+
return cast_impl(src, policy, parent);
|
| 388 |
+
}
|
| 389 |
+
// const pointer return
|
| 390 |
+
static handle cast(const Type *src, return_value_policy policy, handle parent) {
|
| 391 |
+
return cast_impl(src, policy, parent);
|
| 392 |
+
}
|
| 393 |
+
|
| 394 |
+
static constexpr auto name = props::descriptor;
|
| 395 |
+
|
| 396 |
+
// NOLINTNEXTLINE(google-explicit-constructor)
|
| 397 |
+
operator Type *() { return &value; }
|
| 398 |
+
// NOLINTNEXTLINE(google-explicit-constructor)
|
| 399 |
+
operator Type &() { return value; }
|
| 400 |
+
// NOLINTNEXTLINE(google-explicit-constructor)
|
| 401 |
+
operator Type &&() && { return std::move(value); }
|
| 402 |
+
template <typename T>
|
| 403 |
+
using cast_op_type = movable_cast_op_type<T>;
|
| 404 |
+
|
| 405 |
+
private:
|
| 406 |
+
Type value;
|
| 407 |
+
};
|
| 408 |
+
|
| 409 |
+
// Base class for casting reference/map/block/etc. objects back to python.
|
| 410 |
+
template <typename MapType>
|
| 411 |
+
struct eigen_map_caster {
|
| 412 |
+
static_assert(!std::is_pointer<typename MapType::Scalar>::value,
|
| 413 |
+
PYBIND11_EIGEN_MESSAGE_POINTER_TYPES_ARE_NOT_SUPPORTED);
|
| 414 |
+
|
| 415 |
+
private:
|
| 416 |
+
using props = EigenProps<MapType>;
|
| 417 |
+
|
| 418 |
+
public:
|
| 419 |
+
// Directly referencing a ref/map's data is a bit dangerous (whatever the map/ref points to has
|
| 420 |
+
// to stay around), but we'll allow it under the assumption that you know what you're doing
|
| 421 |
+
// (and have an appropriate keep_alive in place). We return a numpy array pointing directly at
|
| 422 |
+
// the ref's data (The numpy array ends up read-only if the ref was to a const matrix type.)
|
| 423 |
+
// Note that this means you need to ensure you don't destroy the object in some other way (e.g.
|
| 424 |
+
// with an appropriate keep_alive, or with a reference to a statically allocated matrix).
|
| 425 |
+
static handle cast(const MapType &src, return_value_policy policy, handle parent) {
|
| 426 |
+
switch (policy) {
|
| 427 |
+
case return_value_policy::copy:
|
| 428 |
+
return eigen_array_cast<props>(src);
|
| 429 |
+
case return_value_policy::reference_internal:
|
| 430 |
+
return eigen_array_cast<props>(src, parent, is_eigen_mutable_map<MapType>::value);
|
| 431 |
+
case return_value_policy::reference:
|
| 432 |
+
case return_value_policy::automatic:
|
| 433 |
+
case return_value_policy::automatic_reference:
|
| 434 |
+
return eigen_array_cast<props>(src, none(), is_eigen_mutable_map<MapType>::value);
|
| 435 |
+
default:
|
| 436 |
+
// move, take_ownership don't make any sense for a ref/map:
|
| 437 |
+
pybind11_fail("Invalid return_value_policy for Eigen Map/Ref/Block type");
|
| 438 |
+
}
|
| 439 |
+
}
|
| 440 |
+
|
| 441 |
+
static constexpr auto name = props::descriptor;
|
| 442 |
+
|
| 443 |
+
// Explicitly delete these: support python -> C++ conversion on these (i.e. these can be return
|
| 444 |
+
// types but not bound arguments). We still provide them (with an explicitly delete) so that
|
| 445 |
+
// you end up here if you try anyway.
|
| 446 |
+
bool load(handle, bool) = delete;
|
| 447 |
+
operator MapType() = delete;
|
| 448 |
+
template <typename>
|
| 449 |
+
using cast_op_type = MapType;
|
| 450 |
+
};
|
| 451 |
+
|
| 452 |
+
// We can return any map-like object (but can only load Refs, specialized next):
|
| 453 |
+
template <typename Type>
|
| 454 |
+
struct type_caster<Type, enable_if_t<is_eigen_dense_map<Type>::value>> : eigen_map_caster<Type> {};
|
| 455 |
+
|
| 456 |
+
// Loader for Ref<...> arguments. See the documentation for info on how to make this work without
|
| 457 |
+
// copying (it requires some extra effort in many cases).
|
| 458 |
+
template <typename PlainObjectType, typename StrideType>
|
| 459 |
+
struct type_caster<
|
| 460 |
+
Eigen::Ref<PlainObjectType, 0, StrideType>,
|
| 461 |
+
enable_if_t<is_eigen_dense_map<Eigen::Ref<PlainObjectType, 0, StrideType>>::value>>
|
| 462 |
+
: public eigen_map_caster<Eigen::Ref<PlainObjectType, 0, StrideType>> {
|
| 463 |
+
private:
|
| 464 |
+
using Type = Eigen::Ref<PlainObjectType, 0, StrideType>;
|
| 465 |
+
using props = EigenProps<Type>;
|
| 466 |
+
using Scalar = typename props::Scalar;
|
| 467 |
+
static_assert(!std::is_pointer<Scalar>::value,
|
| 468 |
+
PYBIND11_EIGEN_MESSAGE_POINTER_TYPES_ARE_NOT_SUPPORTED);
|
| 469 |
+
using MapType = Eigen::Map<PlainObjectType, 0, StrideType>;
|
| 470 |
+
using Array
|
| 471 |
+
= array_t<Scalar,
|
| 472 |
+
array::forcecast
|
| 473 |
+
| ((props::row_major ? props::inner_stride : props::outer_stride) == 1
|
| 474 |
+
? array::c_style
|
| 475 |
+
: (props::row_major ? props::outer_stride : props::inner_stride) == 1
|
| 476 |
+
? array::f_style
|
| 477 |
+
: 0)>;
|
| 478 |
+
static constexpr bool need_writeable = is_eigen_mutable_map<Type>::value;
|
| 479 |
+
// Delay construction (these have no default constructor)
|
| 480 |
+
std::unique_ptr<MapType> map;
|
| 481 |
+
std::unique_ptr<Type> ref;
|
| 482 |
+
// Our array. When possible, this is just a numpy array pointing to the source data, but
|
| 483 |
+
// sometimes we can't avoid copying (e.g. input is not a numpy array at all, has an
|
| 484 |
+
// incompatible layout, or is an array of a type that needs to be converted). Using a numpy
|
| 485 |
+
// temporary (rather than an Eigen temporary) saves an extra copy when we need both type
|
| 486 |
+
// conversion and storage order conversion. (Note that we refuse to use this temporary copy
|
| 487 |
+
// when loading an argument for a Ref<M> with M non-const, i.e. a read-write reference).
|
| 488 |
+
Array copy_or_ref;
|
| 489 |
+
|
| 490 |
+
public:
|
| 491 |
+
bool load(handle src, bool convert) {
|
| 492 |
+
// First check whether what we have is already an array of the right type. If not, we
|
| 493 |
+
// can't avoid a copy (because the copy is also going to do type conversion).
|
| 494 |
+
bool need_copy = !isinstance<Array>(src);
|
| 495 |
+
|
| 496 |
+
EigenConformable<props::row_major> fits;
|
| 497 |
+
if (!need_copy) {
|
| 498 |
+
// We don't need a converting copy, but we also need to check whether the strides are
|
| 499 |
+
// compatible with the Ref's stride requirements
|
| 500 |
+
auto aref = reinterpret_borrow<Array>(src);
|
| 501 |
+
|
| 502 |
+
if (aref && (!need_writeable || aref.writeable())) {
|
| 503 |
+
fits = props::conformable(aref);
|
| 504 |
+
if (!fits) {
|
| 505 |
+
return false; // Incompatible dimensions
|
| 506 |
+
}
|
| 507 |
+
if (!fits.template stride_compatible<props>()) {
|
| 508 |
+
need_copy = true;
|
| 509 |
+
} else {
|
| 510 |
+
copy_or_ref = std::move(aref);
|
| 511 |
+
}
|
| 512 |
+
} else {
|
| 513 |
+
need_copy = true;
|
| 514 |
+
}
|
| 515 |
+
}
|
| 516 |
+
|
| 517 |
+
if (need_copy) {
|
| 518 |
+
// We need to copy: If we need a mutable reference, or we're not supposed to convert
|
| 519 |
+
// (either because we're in the no-convert overload pass, or because we're explicitly
|
| 520 |
+
// instructed not to copy (via `py::arg().noconvert()`) we have to fail loading.
|
| 521 |
+
if (!convert || need_writeable) {
|
| 522 |
+
return false;
|
| 523 |
+
}
|
| 524 |
+
|
| 525 |
+
Array copy = Array::ensure(src);
|
| 526 |
+
if (!copy) {
|
| 527 |
+
return false;
|
| 528 |
+
}
|
| 529 |
+
fits = props::conformable(copy);
|
| 530 |
+
if (!fits || !fits.template stride_compatible<props>()) {
|
| 531 |
+
return false;
|
| 532 |
+
}
|
| 533 |
+
copy_or_ref = std::move(copy);
|
| 534 |
+
loader_life_support::add_patient(copy_or_ref);
|
| 535 |
+
}
|
| 536 |
+
|
| 537 |
+
ref.reset();
|
| 538 |
+
map.reset(new MapType(data(copy_or_ref),
|
| 539 |
+
fits.rows,
|
| 540 |
+
fits.cols,
|
| 541 |
+
make_stride(fits.stride.outer(), fits.stride.inner())));
|
| 542 |
+
ref.reset(new Type(*map));
|
| 543 |
+
|
| 544 |
+
return true;
|
| 545 |
+
}
|
| 546 |
+
|
| 547 |
+
// NOLINTNEXTLINE(google-explicit-constructor)
|
| 548 |
+
operator Type *() { return ref.get(); }
|
| 549 |
+
// NOLINTNEXTLINE(google-explicit-constructor)
|
| 550 |
+
operator Type &() { return *ref; }
|
| 551 |
+
template <typename _T>
|
| 552 |
+
using cast_op_type = pybind11::detail::cast_op_type<_T>;
|
| 553 |
+
|
| 554 |
+
private:
|
| 555 |
+
template <typename T = Type, enable_if_t<is_eigen_mutable_map<T>::value, int> = 0>
|
| 556 |
+
Scalar *data(Array &a) {
|
| 557 |
+
return a.mutable_data();
|
| 558 |
+
}
|
| 559 |
+
|
| 560 |
+
template <typename T = Type, enable_if_t<!is_eigen_mutable_map<T>::value, int> = 0>
|
| 561 |
+
const Scalar *data(Array &a) {
|
| 562 |
+
return a.data();
|
| 563 |
+
}
|
| 564 |
+
|
| 565 |
+
// Attempt to figure out a constructor of `Stride` that will work.
|
| 566 |
+
// If both strides are fixed, use a default constructor:
|
| 567 |
+
template <typename S>
|
| 568 |
+
using stride_ctor_default = bool_constant<S::InnerStrideAtCompileTime != Eigen::Dynamic
|
| 569 |
+
&& S::OuterStrideAtCompileTime != Eigen::Dynamic
|
| 570 |
+
&& std::is_default_constructible<S>::value>;
|
| 571 |
+
// Otherwise, if there is a two-index constructor, assume it is (outer,inner) like
|
| 572 |
+
// Eigen::Stride, and use it:
|
| 573 |
+
template <typename S>
|
| 574 |
+
using stride_ctor_dual
|
| 575 |
+
= bool_constant<!stride_ctor_default<S>::value
|
| 576 |
+
&& std::is_constructible<S, EigenIndex, EigenIndex>::value>;
|
| 577 |
+
// Otherwise, if there is a one-index constructor, and just one of the strides is dynamic, use
|
| 578 |
+
// it (passing whichever stride is dynamic).
|
| 579 |
+
template <typename S>
|
| 580 |
+
using stride_ctor_outer
|
| 581 |
+
= bool_constant<!any_of<stride_ctor_default<S>, stride_ctor_dual<S>>::value
|
| 582 |
+
&& S::OuterStrideAtCompileTime == Eigen::Dynamic
|
| 583 |
+
&& S::InnerStrideAtCompileTime != Eigen::Dynamic
|
| 584 |
+
&& std::is_constructible<S, EigenIndex>::value>;
|
| 585 |
+
template <typename S>
|
| 586 |
+
using stride_ctor_inner
|
| 587 |
+
= bool_constant<!any_of<stride_ctor_default<S>, stride_ctor_dual<S>>::value
|
| 588 |
+
&& S::InnerStrideAtCompileTime == Eigen::Dynamic
|
| 589 |
+
&& S::OuterStrideAtCompileTime != Eigen::Dynamic
|
| 590 |
+
&& std::is_constructible<S, EigenIndex>::value>;
|
| 591 |
+
|
| 592 |
+
template <typename S = StrideType, enable_if_t<stride_ctor_default<S>::value, int> = 0>
|
| 593 |
+
static S make_stride(EigenIndex, EigenIndex) {
|
| 594 |
+
return S();
|
| 595 |
+
}
|
| 596 |
+
template <typename S = StrideType, enable_if_t<stride_ctor_dual<S>::value, int> = 0>
|
| 597 |
+
static S make_stride(EigenIndex outer, EigenIndex inner) {
|
| 598 |
+
return S(outer, inner);
|
| 599 |
+
}
|
| 600 |
+
template <typename S = StrideType, enable_if_t<stride_ctor_outer<S>::value, int> = 0>
|
| 601 |
+
static S make_stride(EigenIndex outer, EigenIndex) {
|
| 602 |
+
return S(outer);
|
| 603 |
+
}
|
| 604 |
+
template <typename S = StrideType, enable_if_t<stride_ctor_inner<S>::value, int> = 0>
|
| 605 |
+
static S make_stride(EigenIndex, EigenIndex inner) {
|
| 606 |
+
return S(inner);
|
| 607 |
+
}
|
| 608 |
+
};
|
| 609 |
+
|
| 610 |
+
// type_caster for special matrix types (e.g. DiagonalMatrix), which are EigenBase, but not
|
| 611 |
+
// EigenDense (i.e. they don't have a data(), at least not with the usual matrix layout).
|
| 612 |
+
// load() is not supported, but we can cast them into the python domain by first copying to a
|
| 613 |
+
// regular Eigen::Matrix, then casting that.
|
| 614 |
+
template <typename Type>
|
| 615 |
+
struct type_caster<Type, enable_if_t<is_eigen_other<Type>::value>> {
|
| 616 |
+
static_assert(!std::is_pointer<typename Type::Scalar>::value,
|
| 617 |
+
PYBIND11_EIGEN_MESSAGE_POINTER_TYPES_ARE_NOT_SUPPORTED);
|
| 618 |
+
|
| 619 |
+
protected:
|
| 620 |
+
using Matrix
|
| 621 |
+
= Eigen::Matrix<typename Type::Scalar, Type::RowsAtCompileTime, Type::ColsAtCompileTime>;
|
| 622 |
+
using props = EigenProps<Matrix>;
|
| 623 |
+
|
| 624 |
+
public:
|
| 625 |
+
static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
|
| 626 |
+
handle h = eigen_encapsulate<props>(new Matrix(src));
|
| 627 |
+
return h;
|
| 628 |
+
}
|
| 629 |
+
static handle cast(const Type *src, return_value_policy policy, handle parent) {
|
| 630 |
+
return cast(*src, policy, parent);
|
| 631 |
+
}
|
| 632 |
+
|
| 633 |
+
static constexpr auto name = props::descriptor;
|
| 634 |
+
|
| 635 |
+
// Explicitly delete these: support python -> C++ conversion on these (i.e. these can be return
|
| 636 |
+
// types but not bound arguments). We still provide them (with an explicitly delete) so that
|
| 637 |
+
// you end up here if you try anyway.
|
| 638 |
+
bool load(handle, bool) = delete;
|
| 639 |
+
operator Type() = delete;
|
| 640 |
+
template <typename>
|
| 641 |
+
using cast_op_type = Type;
|
| 642 |
+
};
|
| 643 |
+
|
| 644 |
+
template <typename Type>
|
| 645 |
+
struct type_caster<Type, enable_if_t<is_eigen_sparse<Type>::value>> {
|
| 646 |
+
using Scalar = typename Type::Scalar;
|
| 647 |
+
static_assert(!std::is_pointer<Scalar>::value,
|
| 648 |
+
PYBIND11_EIGEN_MESSAGE_POINTER_TYPES_ARE_NOT_SUPPORTED);
|
| 649 |
+
using StorageIndex = remove_reference_t<decltype(*std::declval<Type>().outerIndexPtr())>;
|
| 650 |
+
using Index = typename Type::Index;
|
| 651 |
+
static constexpr bool rowMajor = Type::IsRowMajor;
|
| 652 |
+
|
| 653 |
+
bool load(handle src, bool) {
|
| 654 |
+
if (!src) {
|
| 655 |
+
return false;
|
| 656 |
+
}
|
| 657 |
+
|
| 658 |
+
auto obj = reinterpret_borrow<object>(src);
|
| 659 |
+
object sparse_module = module_::import("scipy.sparse");
|
| 660 |
+
object matrix_type = sparse_module.attr(rowMajor ? "csr_matrix" : "csc_matrix");
|
| 661 |
+
|
| 662 |
+
if (!type::handle_of(obj).is(matrix_type)) {
|
| 663 |
+
try {
|
| 664 |
+
obj = matrix_type(obj);
|
| 665 |
+
} catch (const error_already_set &) {
|
| 666 |
+
return false;
|
| 667 |
+
}
|
| 668 |
+
}
|
| 669 |
+
|
| 670 |
+
auto values = array_t<Scalar>((object) obj.attr("data"));
|
| 671 |
+
auto innerIndices = array_t<StorageIndex>((object) obj.attr("indices"));
|
| 672 |
+
auto outerIndices = array_t<StorageIndex>((object) obj.attr("indptr"));
|
| 673 |
+
auto shape = pybind11::tuple((pybind11::object) obj.attr("shape"));
|
| 674 |
+
auto nnz = obj.attr("nnz").cast<Index>();
|
| 675 |
+
|
| 676 |
+
if (!values || !innerIndices || !outerIndices) {
|
| 677 |
+
return false;
|
| 678 |
+
}
|
| 679 |
+
|
| 680 |
+
value = EigenMapSparseMatrix<Scalar,
|
| 681 |
+
Type::Flags &(Eigen::RowMajor | Eigen::ColMajor),
|
| 682 |
+
StorageIndex>(shape[0].cast<Index>(),
|
| 683 |
+
shape[1].cast<Index>(),
|
| 684 |
+
std::move(nnz),
|
| 685 |
+
outerIndices.mutable_data(),
|
| 686 |
+
innerIndices.mutable_data(),
|
| 687 |
+
values.mutable_data());
|
| 688 |
+
|
| 689 |
+
return true;
|
| 690 |
+
}
|
| 691 |
+
|
| 692 |
+
static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
|
| 693 |
+
const_cast<Type &>(src).makeCompressed();
|
| 694 |
+
|
| 695 |
+
object matrix_type
|
| 696 |
+
= module_::import("scipy.sparse").attr(rowMajor ? "csr_matrix" : "csc_matrix");
|
| 697 |
+
|
| 698 |
+
array data(src.nonZeros(), src.valuePtr());
|
| 699 |
+
array outerIndices((rowMajor ? src.rows() : src.cols()) + 1, src.outerIndexPtr());
|
| 700 |
+
array innerIndices(src.nonZeros(), src.innerIndexPtr());
|
| 701 |
+
|
| 702 |
+
return matrix_type(pybind11::make_tuple(
|
| 703 |
+
std::move(data), std::move(innerIndices), std::move(outerIndices)),
|
| 704 |
+
pybind11::make_tuple(src.rows(), src.cols()))
|
| 705 |
+
.release();
|
| 706 |
+
}
|
| 707 |
+
|
| 708 |
+
PYBIND11_TYPE_CASTER(Type,
|
| 709 |
+
const_name<(Type::IsRowMajor) != 0>("scipy.sparse.csr_matrix[",
|
| 710 |
+
"scipy.sparse.csc_matrix[")
|
| 711 |
+
+ npy_format_descriptor<Scalar>::name + const_name("]"));
|
| 712 |
+
};
|
| 713 |
+
|
| 714 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 715 |
+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
|
phivenv/Lib/site-packages/torch/include/pybind11/eigen/tensor.h
ADDED
|
@@ -0,0 +1,515 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
pybind11/eigen/tensor.h: Transparent conversion for Eigen tensors
|
| 3 |
+
|
| 4 |
+
All rights reserved. Use of this source code is governed by a
|
| 5 |
+
BSD-style license that can be found in the LICENSE file.
|
| 6 |
+
*/
|
| 7 |
+
|
| 8 |
+
#pragma once
|
| 9 |
+
|
| 10 |
+
#include <pybind11/numpy.h>
|
| 11 |
+
|
| 12 |
+
#include "common.h"
|
| 13 |
+
|
| 14 |
+
#if defined(__GNUC__) && !defined(__clang__) && !defined(__INTEL_COMPILER)
|
| 15 |
+
static_assert(__GNUC__ > 5, "Eigen Tensor support in pybind11 requires GCC > 5.0");
|
| 16 |
+
#endif
|
| 17 |
+
|
| 18 |
+
// Disable warnings for Eigen
|
| 19 |
+
PYBIND11_WARNING_PUSH
|
| 20 |
+
PYBIND11_WARNING_DISABLE_MSVC(4554)
|
| 21 |
+
PYBIND11_WARNING_DISABLE_MSVC(4127)
|
| 22 |
+
#if defined(__MINGW32__)
|
| 23 |
+
PYBIND11_WARNING_DISABLE_GCC("-Wmaybe-uninitialized")
|
| 24 |
+
#endif
|
| 25 |
+
|
| 26 |
+
#include <unsupported/Eigen/CXX11/Tensor>
|
| 27 |
+
|
| 28 |
+
PYBIND11_WARNING_POP
|
| 29 |
+
|
| 30 |
+
static_assert(EIGEN_VERSION_AT_LEAST(3, 3, 0),
|
| 31 |
+
"Eigen Tensor support in pybind11 requires Eigen >= 3.3.0");
|
| 32 |
+
|
| 33 |
+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
| 34 |
+
|
| 35 |
+
PYBIND11_WARNING_DISABLE_MSVC(4127)
|
| 36 |
+
|
| 37 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 38 |
+
|
| 39 |
+
inline bool is_tensor_aligned(const void *data) {
|
| 40 |
+
return (reinterpret_cast<std::size_t>(data) % EIGEN_DEFAULT_ALIGN_BYTES) == 0;
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
template <typename T>
|
| 44 |
+
constexpr int compute_array_flag_from_tensor() {
|
| 45 |
+
static_assert((static_cast<int>(T::Layout) == static_cast<int>(Eigen::RowMajor))
|
| 46 |
+
|| (static_cast<int>(T::Layout) == static_cast<int>(Eigen::ColMajor)),
|
| 47 |
+
"Layout must be row or column major");
|
| 48 |
+
return (static_cast<int>(T::Layout) == static_cast<int>(Eigen::RowMajor)) ? array::c_style
|
| 49 |
+
: array::f_style;
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
template <typename T>
|
| 53 |
+
struct eigen_tensor_helper {};
|
| 54 |
+
|
| 55 |
+
template <typename Scalar_, int NumIndices_, int Options_, typename IndexType>
|
| 56 |
+
struct eigen_tensor_helper<Eigen::Tensor<Scalar_, NumIndices_, Options_, IndexType>> {
|
| 57 |
+
using Type = Eigen::Tensor<Scalar_, NumIndices_, Options_, IndexType>;
|
| 58 |
+
using ValidType = void;
|
| 59 |
+
|
| 60 |
+
static Eigen::DSizes<typename Type::Index, Type::NumIndices> get_shape(const Type &f) {
|
| 61 |
+
return f.dimensions();
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
static constexpr bool
|
| 65 |
+
is_correct_shape(const Eigen::DSizes<typename Type::Index, Type::NumIndices> & /*shape*/) {
|
| 66 |
+
return true;
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
template <typename T>
|
| 70 |
+
struct helper {};
|
| 71 |
+
|
| 72 |
+
template <size_t... Is>
|
| 73 |
+
struct helper<index_sequence<Is...>> {
|
| 74 |
+
static constexpr auto value = ::pybind11::detail::concat(const_name(((void) Is, "?"))...);
|
| 75 |
+
};
|
| 76 |
+
|
| 77 |
+
static constexpr auto dimensions_descriptor
|
| 78 |
+
= helper<decltype(make_index_sequence<Type::NumIndices>())>::value;
|
| 79 |
+
|
| 80 |
+
template <typename... Args>
|
| 81 |
+
static Type *alloc(Args &&...args) {
|
| 82 |
+
return new Type(std::forward<Args>(args)...);
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
static void free(Type *tensor) { delete tensor; }
|
| 86 |
+
};
|
| 87 |
+
|
| 88 |
+
template <typename Scalar_, typename std::ptrdiff_t... Indices, int Options_, typename IndexType>
|
| 89 |
+
struct eigen_tensor_helper<
|
| 90 |
+
Eigen::TensorFixedSize<Scalar_, Eigen::Sizes<Indices...>, Options_, IndexType>> {
|
| 91 |
+
using Type = Eigen::TensorFixedSize<Scalar_, Eigen::Sizes<Indices...>, Options_, IndexType>;
|
| 92 |
+
using ValidType = void;
|
| 93 |
+
|
| 94 |
+
static constexpr Eigen::DSizes<typename Type::Index, Type::NumIndices>
|
| 95 |
+
get_shape(const Type & /*f*/) {
|
| 96 |
+
return get_shape();
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
static constexpr Eigen::DSizes<typename Type::Index, Type::NumIndices> get_shape() {
|
| 100 |
+
return Eigen::DSizes<typename Type::Index, Type::NumIndices>(Indices...);
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
static bool
|
| 104 |
+
is_correct_shape(const Eigen::DSizes<typename Type::Index, Type::NumIndices> &shape) {
|
| 105 |
+
return get_shape() == shape;
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
static constexpr auto dimensions_descriptor
|
| 109 |
+
= ::pybind11::detail::concat(const_name<Indices>()...);
|
| 110 |
+
|
| 111 |
+
template <typename... Args>
|
| 112 |
+
static Type *alloc(Args &&...args) {
|
| 113 |
+
Eigen::aligned_allocator<Type> allocator;
|
| 114 |
+
return ::new (allocator.allocate(1)) Type(std::forward<Args>(args)...);
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
static void free(Type *tensor) {
|
| 118 |
+
Eigen::aligned_allocator<Type> allocator;
|
| 119 |
+
tensor->~Type();
|
| 120 |
+
allocator.deallocate(tensor, 1);
|
| 121 |
+
}
|
| 122 |
+
};
|
| 123 |
+
|
| 124 |
+
template <typename Type, bool ShowDetails, bool NeedsWriteable = false>
|
| 125 |
+
struct get_tensor_descriptor {
|
| 126 |
+
static constexpr auto details
|
| 127 |
+
= const_name<NeedsWriteable>(", flags.writeable", "")
|
| 128 |
+
+ const_name<static_cast<int>(Type::Layout) == static_cast<int>(Eigen::RowMajor)>(
|
| 129 |
+
", flags.c_contiguous", ", flags.f_contiguous");
|
| 130 |
+
static constexpr auto value
|
| 131 |
+
= const_name("numpy.ndarray[") + npy_format_descriptor<typename Type::Scalar>::name
|
| 132 |
+
+ const_name("[") + eigen_tensor_helper<remove_cv_t<Type>>::dimensions_descriptor
|
| 133 |
+
+ const_name("]") + const_name<ShowDetails>(details, const_name("")) + const_name("]");
|
| 134 |
+
};
|
| 135 |
+
|
| 136 |
+
// When EIGEN_AVOID_STL_ARRAY is defined, Eigen::DSizes<T, 0> does not have the begin() member
|
| 137 |
+
// function. Falling back to a simple loop works around this issue.
|
| 138 |
+
//
|
| 139 |
+
// We need to disable the type-limits warning for the inner loop when size = 0.
|
| 140 |
+
|
| 141 |
+
PYBIND11_WARNING_PUSH
|
| 142 |
+
PYBIND11_WARNING_DISABLE_GCC("-Wtype-limits")
|
| 143 |
+
|
| 144 |
+
template <typename T, int size>
|
| 145 |
+
std::vector<T> convert_dsizes_to_vector(const Eigen::DSizes<T, size> &arr) {
|
| 146 |
+
std::vector<T> result(size);
|
| 147 |
+
|
| 148 |
+
for (size_t i = 0; i < size; i++) {
|
| 149 |
+
result[i] = arr[i];
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
return result;
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
template <typename T, int size>
|
| 156 |
+
Eigen::DSizes<T, size> get_shape_for_array(const array &arr) {
|
| 157 |
+
Eigen::DSizes<T, size> result;
|
| 158 |
+
const T *shape = arr.shape();
|
| 159 |
+
for (size_t i = 0; i < size; i++) {
|
| 160 |
+
result[i] = shape[i];
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
return result;
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
PYBIND11_WARNING_POP
|
| 167 |
+
|
| 168 |
+
template <typename Type>
|
| 169 |
+
struct type_caster<Type, typename eigen_tensor_helper<Type>::ValidType> {
|
| 170 |
+
static_assert(!std::is_pointer<typename Type::Scalar>::value,
|
| 171 |
+
PYBIND11_EIGEN_MESSAGE_POINTER_TYPES_ARE_NOT_SUPPORTED);
|
| 172 |
+
using Helper = eigen_tensor_helper<Type>;
|
| 173 |
+
static constexpr auto temp_name = get_tensor_descriptor<Type, false>::value;
|
| 174 |
+
PYBIND11_TYPE_CASTER(Type, temp_name);
|
| 175 |
+
|
| 176 |
+
bool load(handle src, bool convert) {
|
| 177 |
+
if (!convert) {
|
| 178 |
+
if (!isinstance<array>(src)) {
|
| 179 |
+
return false;
|
| 180 |
+
}
|
| 181 |
+
array temp = array::ensure(src);
|
| 182 |
+
if (!temp) {
|
| 183 |
+
return false;
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
if (!temp.dtype().is(dtype::of<typename Type::Scalar>())) {
|
| 187 |
+
return false;
|
| 188 |
+
}
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
array_t<typename Type::Scalar, compute_array_flag_from_tensor<Type>()> arr(
|
| 192 |
+
reinterpret_borrow<object>(src));
|
| 193 |
+
|
| 194 |
+
if (arr.ndim() != Type::NumIndices) {
|
| 195 |
+
return false;
|
| 196 |
+
}
|
| 197 |
+
auto shape = get_shape_for_array<typename Type::Index, Type::NumIndices>(arr);
|
| 198 |
+
|
| 199 |
+
if (!Helper::is_correct_shape(shape)) {
|
| 200 |
+
return false;
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
#if EIGEN_VERSION_AT_LEAST(3, 4, 0)
|
| 204 |
+
auto data_pointer = arr.data();
|
| 205 |
+
#else
|
| 206 |
+
// Handle Eigen bug
|
| 207 |
+
auto data_pointer = const_cast<typename Type::Scalar *>(arr.data());
|
| 208 |
+
#endif
|
| 209 |
+
|
| 210 |
+
if (is_tensor_aligned(arr.data())) {
|
| 211 |
+
value = Eigen::TensorMap<const Type, Eigen::Aligned>(data_pointer, shape);
|
| 212 |
+
} else {
|
| 213 |
+
value = Eigen::TensorMap<const Type>(data_pointer, shape);
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
return true;
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
static handle cast(Type &&src, return_value_policy policy, handle parent) {
|
| 220 |
+
if (policy == return_value_policy::reference
|
| 221 |
+
|| policy == return_value_policy::reference_internal) {
|
| 222 |
+
pybind11_fail("Cannot use a reference return value policy for an rvalue");
|
| 223 |
+
}
|
| 224 |
+
return cast_impl(&src, return_value_policy::move, parent);
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
static handle cast(const Type &&src, return_value_policy policy, handle parent) {
|
| 228 |
+
if (policy == return_value_policy::reference
|
| 229 |
+
|| policy == return_value_policy::reference_internal) {
|
| 230 |
+
pybind11_fail("Cannot use a reference return value policy for an rvalue");
|
| 231 |
+
}
|
| 232 |
+
return cast_impl(&src, return_value_policy::move, parent);
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
static handle cast(Type &src, return_value_policy policy, handle parent) {
|
| 236 |
+
if (policy == return_value_policy::automatic
|
| 237 |
+
|| policy == return_value_policy::automatic_reference) {
|
| 238 |
+
policy = return_value_policy::copy;
|
| 239 |
+
}
|
| 240 |
+
return cast_impl(&src, policy, parent);
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
static handle cast(const Type &src, return_value_policy policy, handle parent) {
|
| 244 |
+
if (policy == return_value_policy::automatic
|
| 245 |
+
|| policy == return_value_policy::automatic_reference) {
|
| 246 |
+
policy = return_value_policy::copy;
|
| 247 |
+
}
|
| 248 |
+
return cast(&src, policy, parent);
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
static handle cast(Type *src, return_value_policy policy, handle parent) {
|
| 252 |
+
if (policy == return_value_policy::automatic) {
|
| 253 |
+
policy = return_value_policy::take_ownership;
|
| 254 |
+
} else if (policy == return_value_policy::automatic_reference) {
|
| 255 |
+
policy = return_value_policy::reference;
|
| 256 |
+
}
|
| 257 |
+
return cast_impl(src, policy, parent);
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
static handle cast(const Type *src, return_value_policy policy, handle parent) {
|
| 261 |
+
if (policy == return_value_policy::automatic) {
|
| 262 |
+
policy = return_value_policy::take_ownership;
|
| 263 |
+
} else if (policy == return_value_policy::automatic_reference) {
|
| 264 |
+
policy = return_value_policy::reference;
|
| 265 |
+
}
|
| 266 |
+
return cast_impl(src, policy, parent);
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
template <typename C>
|
| 270 |
+
static handle cast_impl(C *src, return_value_policy policy, handle parent) {
|
| 271 |
+
object parent_object;
|
| 272 |
+
bool writeable = false;
|
| 273 |
+
switch (policy) {
|
| 274 |
+
case return_value_policy::move:
|
| 275 |
+
if (std::is_const<C>::value) {
|
| 276 |
+
pybind11_fail("Cannot move from a constant reference");
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
src = Helper::alloc(std::move(*src));
|
| 280 |
+
|
| 281 |
+
parent_object
|
| 282 |
+
= capsule(src, [](void *ptr) { Helper::free(reinterpret_cast<Type *>(ptr)); });
|
| 283 |
+
writeable = true;
|
| 284 |
+
break;
|
| 285 |
+
|
| 286 |
+
case return_value_policy::take_ownership:
|
| 287 |
+
if (std::is_const<C>::value) {
|
| 288 |
+
// This cast is ugly, and might be UB in some cases, but we don't have an
|
| 289 |
+
// alternative here as we must free that memory
|
| 290 |
+
Helper::free(const_cast<Type *>(src));
|
| 291 |
+
pybind11_fail("Cannot take ownership of a const reference");
|
| 292 |
+
}
|
| 293 |
+
|
| 294 |
+
parent_object
|
| 295 |
+
= capsule(src, [](void *ptr) { Helper::free(reinterpret_cast<Type *>(ptr)); });
|
| 296 |
+
writeable = true;
|
| 297 |
+
break;
|
| 298 |
+
|
| 299 |
+
case return_value_policy::copy:
|
| 300 |
+
writeable = true;
|
| 301 |
+
break;
|
| 302 |
+
|
| 303 |
+
case return_value_policy::reference:
|
| 304 |
+
parent_object = none();
|
| 305 |
+
writeable = !std::is_const<C>::value;
|
| 306 |
+
break;
|
| 307 |
+
|
| 308 |
+
case return_value_policy::reference_internal:
|
| 309 |
+
// Default should do the right thing
|
| 310 |
+
if (!parent) {
|
| 311 |
+
pybind11_fail("Cannot use reference internal when there is no parent");
|
| 312 |
+
}
|
| 313 |
+
parent_object = reinterpret_borrow<object>(parent);
|
| 314 |
+
writeable = !std::is_const<C>::value;
|
| 315 |
+
break;
|
| 316 |
+
|
| 317 |
+
default:
|
| 318 |
+
pybind11_fail("pybind11 bug in eigen.h, please file a bug report");
|
| 319 |
+
}
|
| 320 |
+
|
| 321 |
+
auto result = array_t<typename Type::Scalar, compute_array_flag_from_tensor<Type>()>(
|
| 322 |
+
convert_dsizes_to_vector(Helper::get_shape(*src)), src->data(), parent_object);
|
| 323 |
+
|
| 324 |
+
if (!writeable) {
|
| 325 |
+
array_proxy(result.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
return result.release();
|
| 329 |
+
}
|
| 330 |
+
};
|
| 331 |
+
|
| 332 |
+
template <typename StoragePointerType,
|
| 333 |
+
bool needs_writeable,
|
| 334 |
+
enable_if_t<!needs_writeable, bool> = true>
|
| 335 |
+
StoragePointerType get_array_data_for_type(array &arr) {
|
| 336 |
+
#if EIGEN_VERSION_AT_LEAST(3, 4, 0)
|
| 337 |
+
return reinterpret_cast<StoragePointerType>(arr.data());
|
| 338 |
+
#else
|
| 339 |
+
// Handle Eigen bug
|
| 340 |
+
return reinterpret_cast<StoragePointerType>(const_cast<void *>(arr.data()));
|
| 341 |
+
#endif
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
template <typename StoragePointerType,
|
| 345 |
+
bool needs_writeable,
|
| 346 |
+
enable_if_t<needs_writeable, bool> = true>
|
| 347 |
+
StoragePointerType get_array_data_for_type(array &arr) {
|
| 348 |
+
return reinterpret_cast<StoragePointerType>(arr.mutable_data());
|
| 349 |
+
}
|
| 350 |
+
|
| 351 |
+
template <typename T, typename = void>
|
| 352 |
+
struct get_storage_pointer_type;
|
| 353 |
+
|
| 354 |
+
template <typename MapType>
|
| 355 |
+
struct get_storage_pointer_type<MapType, void_t<typename MapType::StoragePointerType>> {
|
| 356 |
+
using SPT = typename MapType::StoragePointerType;
|
| 357 |
+
};
|
| 358 |
+
|
| 359 |
+
template <typename MapType>
|
| 360 |
+
struct get_storage_pointer_type<MapType, void_t<typename MapType::PointerArgType>> {
|
| 361 |
+
using SPT = typename MapType::PointerArgType;
|
| 362 |
+
};
|
| 363 |
+
|
| 364 |
+
template <typename Type, int Options>
|
| 365 |
+
struct type_caster<Eigen::TensorMap<Type, Options>,
|
| 366 |
+
typename eigen_tensor_helper<remove_cv_t<Type>>::ValidType> {
|
| 367 |
+
static_assert(!std::is_pointer<typename Type::Scalar>::value,
|
| 368 |
+
PYBIND11_EIGEN_MESSAGE_POINTER_TYPES_ARE_NOT_SUPPORTED);
|
| 369 |
+
using MapType = Eigen::TensorMap<Type, Options>;
|
| 370 |
+
using Helper = eigen_tensor_helper<remove_cv_t<Type>>;
|
| 371 |
+
|
| 372 |
+
bool load(handle src, bool /*convert*/) {
|
| 373 |
+
// Note that we have a lot more checks here as we want to make sure to avoid copies
|
| 374 |
+
if (!isinstance<array>(src)) {
|
| 375 |
+
return false;
|
| 376 |
+
}
|
| 377 |
+
auto arr = reinterpret_borrow<array>(src);
|
| 378 |
+
if ((arr.flags() & compute_array_flag_from_tensor<Type>()) == 0) {
|
| 379 |
+
return false;
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
if (!arr.dtype().is(dtype::of<typename Type::Scalar>())) {
|
| 383 |
+
return false;
|
| 384 |
+
}
|
| 385 |
+
|
| 386 |
+
if (arr.ndim() != Type::NumIndices) {
|
| 387 |
+
return false;
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
constexpr bool is_aligned = (Options & Eigen::Aligned) != 0;
|
| 391 |
+
|
| 392 |
+
if (is_aligned && !is_tensor_aligned(arr.data())) {
|
| 393 |
+
return false;
|
| 394 |
+
}
|
| 395 |
+
|
| 396 |
+
auto shape = get_shape_for_array<typename Type::Index, Type::NumIndices>(arr);
|
| 397 |
+
|
| 398 |
+
if (!Helper::is_correct_shape(shape)) {
|
| 399 |
+
return false;
|
| 400 |
+
}
|
| 401 |
+
|
| 402 |
+
if (needs_writeable && !arr.writeable()) {
|
| 403 |
+
return false;
|
| 404 |
+
}
|
| 405 |
+
|
| 406 |
+
auto result = get_array_data_for_type<typename get_storage_pointer_type<MapType>::SPT,
|
| 407 |
+
needs_writeable>(arr);
|
| 408 |
+
|
| 409 |
+
value.reset(new MapType(std::move(result), std::move(shape)));
|
| 410 |
+
|
| 411 |
+
return true;
|
| 412 |
+
}
|
| 413 |
+
|
| 414 |
+
static handle cast(MapType &&src, return_value_policy policy, handle parent) {
|
| 415 |
+
return cast_impl(&src, policy, parent);
|
| 416 |
+
}
|
| 417 |
+
|
| 418 |
+
static handle cast(const MapType &&src, return_value_policy policy, handle parent) {
|
| 419 |
+
return cast_impl(&src, policy, parent);
|
| 420 |
+
}
|
| 421 |
+
|
| 422 |
+
static handle cast(MapType &src, return_value_policy policy, handle parent) {
|
| 423 |
+
if (policy == return_value_policy::automatic
|
| 424 |
+
|| policy == return_value_policy::automatic_reference) {
|
| 425 |
+
policy = return_value_policy::copy;
|
| 426 |
+
}
|
| 427 |
+
return cast_impl(&src, policy, parent);
|
| 428 |
+
}
|
| 429 |
+
|
| 430 |
+
static handle cast(const MapType &src, return_value_policy policy, handle parent) {
|
| 431 |
+
if (policy == return_value_policy::automatic
|
| 432 |
+
|| policy == return_value_policy::automatic_reference) {
|
| 433 |
+
policy = return_value_policy::copy;
|
| 434 |
+
}
|
| 435 |
+
return cast(&src, policy, parent);
|
| 436 |
+
}
|
| 437 |
+
|
| 438 |
+
static handle cast(MapType *src, return_value_policy policy, handle parent) {
|
| 439 |
+
if (policy == return_value_policy::automatic) {
|
| 440 |
+
policy = return_value_policy::take_ownership;
|
| 441 |
+
} else if (policy == return_value_policy::automatic_reference) {
|
| 442 |
+
policy = return_value_policy::reference;
|
| 443 |
+
}
|
| 444 |
+
return cast_impl(src, policy, parent);
|
| 445 |
+
}
|
| 446 |
+
|
| 447 |
+
static handle cast(const MapType *src, return_value_policy policy, handle parent) {
|
| 448 |
+
if (policy == return_value_policy::automatic) {
|
| 449 |
+
policy = return_value_policy::take_ownership;
|
| 450 |
+
} else if (policy == return_value_policy::automatic_reference) {
|
| 451 |
+
policy = return_value_policy::reference;
|
| 452 |
+
}
|
| 453 |
+
return cast_impl(src, policy, parent);
|
| 454 |
+
}
|
| 455 |
+
|
| 456 |
+
template <typename C>
|
| 457 |
+
static handle cast_impl(C *src, return_value_policy policy, handle parent) {
|
| 458 |
+
object parent_object;
|
| 459 |
+
constexpr bool writeable = !std::is_const<C>::value;
|
| 460 |
+
switch (policy) {
|
| 461 |
+
case return_value_policy::reference:
|
| 462 |
+
parent_object = none();
|
| 463 |
+
break;
|
| 464 |
+
|
| 465 |
+
case return_value_policy::reference_internal:
|
| 466 |
+
// Default should do the right thing
|
| 467 |
+
if (!parent) {
|
| 468 |
+
pybind11_fail("Cannot use reference internal when there is no parent");
|
| 469 |
+
}
|
| 470 |
+
parent_object = reinterpret_borrow<object>(parent);
|
| 471 |
+
break;
|
| 472 |
+
|
| 473 |
+
default:
|
| 474 |
+
// move, take_ownership don't make any sense for a ref/map:
|
| 475 |
+
pybind11_fail("Invalid return_value_policy for Eigen Map type, must be either "
|
| 476 |
+
"reference or reference_internal");
|
| 477 |
+
}
|
| 478 |
+
|
| 479 |
+
auto result = array_t<typename Type::Scalar, compute_array_flag_from_tensor<Type>()>(
|
| 480 |
+
convert_dsizes_to_vector(Helper::get_shape(*src)),
|
| 481 |
+
src->data(),
|
| 482 |
+
std::move(parent_object));
|
| 483 |
+
|
| 484 |
+
if (!writeable) {
|
| 485 |
+
array_proxy(result.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
|
| 486 |
+
}
|
| 487 |
+
|
| 488 |
+
return result.release();
|
| 489 |
+
}
|
| 490 |
+
|
| 491 |
+
#if EIGEN_VERSION_AT_LEAST(3, 4, 0)
|
| 492 |
+
|
| 493 |
+
static constexpr bool needs_writeable = !std::is_const<typename std::remove_pointer<
|
| 494 |
+
typename get_storage_pointer_type<MapType>::SPT>::type>::value;
|
| 495 |
+
#else
|
| 496 |
+
// Handle Eigen bug
|
| 497 |
+
static constexpr bool needs_writeable = !std::is_const<Type>::value;
|
| 498 |
+
#endif
|
| 499 |
+
|
| 500 |
+
protected:
|
| 501 |
+
// TODO: Move to std::optional once std::optional has more support
|
| 502 |
+
std::unique_ptr<MapType> value;
|
| 503 |
+
|
| 504 |
+
public:
|
| 505 |
+
static constexpr auto name = get_tensor_descriptor<Type, true, needs_writeable>::value;
|
| 506 |
+
explicit operator MapType *() { return value.get(); }
|
| 507 |
+
explicit operator MapType &() { return *value; }
|
| 508 |
+
explicit operator MapType &&() && { return std::move(*value); }
|
| 509 |
+
|
| 510 |
+
template <typename T_>
|
| 511 |
+
using cast_op_type = ::pybind11::detail::movable_cast_op_type<T_>;
|
| 512 |
+
};
|
| 513 |
+
|
| 514 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 515 |
+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
|
phivenv/Lib/site-packages/torch/include/pybind11/embed.h
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
pybind11/embed.h: Support for embedding the interpreter
|
| 3 |
+
|
| 4 |
+
Copyright (c) 2017 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
| 5 |
+
|
| 6 |
+
All rights reserved. Use of this source code is governed by a
|
| 7 |
+
BSD-style license that can be found in the LICENSE file.
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include "pybind11.h"
|
| 13 |
+
#include "eval.h"
|
| 14 |
+
|
| 15 |
+
#include <memory>
|
| 16 |
+
#include <vector>
|
| 17 |
+
|
| 18 |
+
#if defined(PYPY_VERSION)
|
| 19 |
+
# error Embedding the interpreter is not supported with PyPy
|
| 20 |
+
#endif
|
| 21 |
+
|
| 22 |
+
#define PYBIND11_EMBEDDED_MODULE_IMPL(name) \
|
| 23 |
+
extern "C" PyObject *pybind11_init_impl_##name(); \
|
| 24 |
+
extern "C" PyObject *pybind11_init_impl_##name() { return pybind11_init_wrapper_##name(); }
|
| 25 |
+
|
| 26 |
+
/** \rst
|
| 27 |
+
Add a new module to the table of builtins for the interpreter. Must be
|
| 28 |
+
defined in global scope. The first macro parameter is the name of the
|
| 29 |
+
module (without quotes). The second parameter is the variable which will
|
| 30 |
+
be used as the interface to add functions and classes to the module.
|
| 31 |
+
|
| 32 |
+
.. code-block:: cpp
|
| 33 |
+
|
| 34 |
+
PYBIND11_EMBEDDED_MODULE(example, m) {
|
| 35 |
+
// ... initialize functions and classes here
|
| 36 |
+
m.def("foo", []() {
|
| 37 |
+
return "Hello, World!";
|
| 38 |
+
});
|
| 39 |
+
}
|
| 40 |
+
\endrst */
|
| 41 |
+
#define PYBIND11_EMBEDDED_MODULE(name, variable) \
|
| 42 |
+
static ::pybind11::module_::module_def PYBIND11_CONCAT(pybind11_module_def_, name); \
|
| 43 |
+
static void PYBIND11_CONCAT(pybind11_init_, name)(::pybind11::module_ &); \
|
| 44 |
+
static PyObject PYBIND11_CONCAT(*pybind11_init_wrapper_, name)() { \
|
| 45 |
+
auto m = ::pybind11::module_::create_extension_module( \
|
| 46 |
+
PYBIND11_TOSTRING(name), nullptr, &PYBIND11_CONCAT(pybind11_module_def_, name)); \
|
| 47 |
+
try { \
|
| 48 |
+
PYBIND11_CONCAT(pybind11_init_, name)(m); \
|
| 49 |
+
return m.ptr(); \
|
| 50 |
+
} \
|
| 51 |
+
PYBIND11_CATCH_INIT_EXCEPTIONS \
|
| 52 |
+
} \
|
| 53 |
+
PYBIND11_EMBEDDED_MODULE_IMPL(name) \
|
| 54 |
+
::pybind11::detail::embedded_module PYBIND11_CONCAT(pybind11_module_, name)( \
|
| 55 |
+
PYBIND11_TOSTRING(name), PYBIND11_CONCAT(pybind11_init_impl_, name)); \
|
| 56 |
+
void PYBIND11_CONCAT(pybind11_init_, name)(::pybind11::module_ \
|
| 57 |
+
& variable) // NOLINT(bugprone-macro-parentheses)
|
| 58 |
+
|
| 59 |
+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
| 60 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 61 |
+
|
| 62 |
+
/// Python 2.7/3.x compatible version of `PyImport_AppendInittab` and error checks.
|
| 63 |
+
struct embedded_module {
|
| 64 |
+
using init_t = PyObject *(*) ();
|
| 65 |
+
embedded_module(const char *name, init_t init) {
|
| 66 |
+
if (Py_IsInitialized() != 0) {
|
| 67 |
+
pybind11_fail("Can't add new modules after the interpreter has been initialized");
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
auto result = PyImport_AppendInittab(name, init);
|
| 71 |
+
if (result == -1) {
|
| 72 |
+
pybind11_fail("Insufficient memory to add a new module");
|
| 73 |
+
}
|
| 74 |
+
}
|
| 75 |
+
};
|
| 76 |
+
|
| 77 |
+
struct wide_char_arg_deleter {
|
| 78 |
+
void operator()(wchar_t *ptr) const {
|
| 79 |
+
// API docs: https://docs.python.org/3/c-api/sys.html#c.Py_DecodeLocale
|
| 80 |
+
PyMem_RawFree(ptr);
|
| 81 |
+
}
|
| 82 |
+
};
|
| 83 |
+
|
| 84 |
+
inline wchar_t *widen_chars(const char *safe_arg) {
|
| 85 |
+
wchar_t *widened_arg = Py_DecodeLocale(safe_arg, nullptr);
|
| 86 |
+
return widened_arg;
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
inline void precheck_interpreter() {
|
| 90 |
+
if (Py_IsInitialized() != 0) {
|
| 91 |
+
pybind11_fail("The interpreter is already running");
|
| 92 |
+
}
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
#if !defined(PYBIND11_PYCONFIG_SUPPORT_PY_VERSION_HEX)
|
| 96 |
+
# define PYBIND11_PYCONFIG_SUPPORT_PY_VERSION_HEX (0x03080000)
|
| 97 |
+
#endif
|
| 98 |
+
|
| 99 |
+
#if PY_VERSION_HEX < PYBIND11_PYCONFIG_SUPPORT_PY_VERSION_HEX
|
| 100 |
+
inline void initialize_interpreter_pre_pyconfig(bool init_signal_handlers,
|
| 101 |
+
int argc,
|
| 102 |
+
const char *const *argv,
|
| 103 |
+
bool add_program_dir_to_path) {
|
| 104 |
+
detail::precheck_interpreter();
|
| 105 |
+
Py_InitializeEx(init_signal_handlers ? 1 : 0);
|
| 106 |
+
|
| 107 |
+
// Before it was special-cased in python 3.8, passing an empty or null argv
|
| 108 |
+
// caused a segfault, so we have to reimplement the special case ourselves.
|
| 109 |
+
bool special_case = (argv == nullptr || argc <= 0);
|
| 110 |
+
|
| 111 |
+
const char *const empty_argv[]{"\0"};
|
| 112 |
+
const char *const *safe_argv = special_case ? empty_argv : argv;
|
| 113 |
+
if (special_case) {
|
| 114 |
+
argc = 1;
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
auto argv_size = static_cast<size_t>(argc);
|
| 118 |
+
// SetArgv* on python 3 takes wchar_t, so we have to convert.
|
| 119 |
+
std::unique_ptr<wchar_t *[]> widened_argv(new wchar_t *[argv_size]);
|
| 120 |
+
std::vector<std::unique_ptr<wchar_t[], detail::wide_char_arg_deleter>> widened_argv_entries;
|
| 121 |
+
widened_argv_entries.reserve(argv_size);
|
| 122 |
+
for (size_t ii = 0; ii < argv_size; ++ii) {
|
| 123 |
+
widened_argv_entries.emplace_back(detail::widen_chars(safe_argv[ii]));
|
| 124 |
+
if (!widened_argv_entries.back()) {
|
| 125 |
+
// A null here indicates a character-encoding failure or the python
|
| 126 |
+
// interpreter out of memory. Give up.
|
| 127 |
+
return;
|
| 128 |
+
}
|
| 129 |
+
widened_argv[ii] = widened_argv_entries.back().get();
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
auto *pysys_argv = widened_argv.get();
|
| 133 |
+
|
| 134 |
+
PySys_SetArgvEx(argc, pysys_argv, static_cast<int>(add_program_dir_to_path));
|
| 135 |
+
}
|
| 136 |
+
#endif
|
| 137 |
+
|
| 138 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 139 |
+
|
| 140 |
+
#if PY_VERSION_HEX >= PYBIND11_PYCONFIG_SUPPORT_PY_VERSION_HEX
|
| 141 |
+
inline void initialize_interpreter(PyConfig *config,
|
| 142 |
+
int argc = 0,
|
| 143 |
+
const char *const *argv = nullptr,
|
| 144 |
+
bool add_program_dir_to_path = true) {
|
| 145 |
+
detail::precheck_interpreter();
|
| 146 |
+
PyStatus status = PyConfig_SetBytesArgv(config, argc, const_cast<char *const *>(argv));
|
| 147 |
+
if (PyStatus_Exception(status) != 0) {
|
| 148 |
+
// A failure here indicates a character-encoding failure or the python
|
| 149 |
+
// interpreter out of memory. Give up.
|
| 150 |
+
PyConfig_Clear(config);
|
| 151 |
+
throw std::runtime_error(PyStatus_IsError(status) != 0 ? status.err_msg
|
| 152 |
+
: "Failed to prepare CPython");
|
| 153 |
+
}
|
| 154 |
+
status = Py_InitializeFromConfig(config);
|
| 155 |
+
if (PyStatus_Exception(status) != 0) {
|
| 156 |
+
PyConfig_Clear(config);
|
| 157 |
+
throw std::runtime_error(PyStatus_IsError(status) != 0 ? status.err_msg
|
| 158 |
+
: "Failed to init CPython");
|
| 159 |
+
}
|
| 160 |
+
if (add_program_dir_to_path) {
|
| 161 |
+
PyRun_SimpleString("import sys, os.path; "
|
| 162 |
+
"sys.path.insert(0, "
|
| 163 |
+
"os.path.abspath(os.path.dirname(sys.argv[0])) "
|
| 164 |
+
"if sys.argv and os.path.exists(sys.argv[0]) else '')");
|
| 165 |
+
}
|
| 166 |
+
PyConfig_Clear(config);
|
| 167 |
+
}
|
| 168 |
+
#endif
|
| 169 |
+
|
| 170 |
+
/** \rst
|
| 171 |
+
Initialize the Python interpreter. No other pybind11 or CPython API functions can be
|
| 172 |
+
called before this is done; with the exception of `PYBIND11_EMBEDDED_MODULE`. The
|
| 173 |
+
optional `init_signal_handlers` parameter can be used to skip the registration of
|
| 174 |
+
signal handlers (see the `Python documentation`_ for details). Calling this function
|
| 175 |
+
again after the interpreter has already been initialized is a fatal error.
|
| 176 |
+
|
| 177 |
+
If initializing the Python interpreter fails, then the program is terminated. (This
|
| 178 |
+
is controlled by the CPython runtime and is an exception to pybind11's normal behavior
|
| 179 |
+
of throwing exceptions on errors.)
|
| 180 |
+
|
| 181 |
+
The remaining optional parameters, `argc`, `argv`, and `add_program_dir_to_path` are
|
| 182 |
+
used to populate ``sys.argv`` and ``sys.path``.
|
| 183 |
+
See the |PySys_SetArgvEx documentation|_ for details.
|
| 184 |
+
|
| 185 |
+
.. _Python documentation: https://docs.python.org/3/c-api/init.html#c.Py_InitializeEx
|
| 186 |
+
.. |PySys_SetArgvEx documentation| replace:: ``PySys_SetArgvEx`` documentation
|
| 187 |
+
.. _PySys_SetArgvEx documentation: https://docs.python.org/3/c-api/init.html#c.PySys_SetArgvEx
|
| 188 |
+
\endrst */
|
| 189 |
+
inline void initialize_interpreter(bool init_signal_handlers = true,
|
| 190 |
+
int argc = 0,
|
| 191 |
+
const char *const *argv = nullptr,
|
| 192 |
+
bool add_program_dir_to_path = true) {
|
| 193 |
+
#if PY_VERSION_HEX < PYBIND11_PYCONFIG_SUPPORT_PY_VERSION_HEX
|
| 194 |
+
detail::initialize_interpreter_pre_pyconfig(
|
| 195 |
+
init_signal_handlers, argc, argv, add_program_dir_to_path);
|
| 196 |
+
#else
|
| 197 |
+
PyConfig config;
|
| 198 |
+
PyConfig_InitPythonConfig(&config);
|
| 199 |
+
// See PR #4473 for background
|
| 200 |
+
config.parse_argv = 0;
|
| 201 |
+
|
| 202 |
+
config.install_signal_handlers = init_signal_handlers ? 1 : 0;
|
| 203 |
+
initialize_interpreter(&config, argc, argv, add_program_dir_to_path);
|
| 204 |
+
#endif
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
/** \rst
|
| 208 |
+
Shut down the Python interpreter. No pybind11 or CPython API functions can be called
|
| 209 |
+
after this. In addition, pybind11 objects must not outlive the interpreter:
|
| 210 |
+
|
| 211 |
+
.. code-block:: cpp
|
| 212 |
+
|
| 213 |
+
{ // BAD
|
| 214 |
+
py::initialize_interpreter();
|
| 215 |
+
auto hello = py::str("Hello, World!");
|
| 216 |
+
py::finalize_interpreter();
|
| 217 |
+
} // <-- BOOM, hello's destructor is called after interpreter shutdown
|
| 218 |
+
|
| 219 |
+
{ // GOOD
|
| 220 |
+
py::initialize_interpreter();
|
| 221 |
+
{ // scoped
|
| 222 |
+
auto hello = py::str("Hello, World!");
|
| 223 |
+
} // <-- OK, hello is cleaned up properly
|
| 224 |
+
py::finalize_interpreter();
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
{ // BETTER
|
| 228 |
+
py::scoped_interpreter guard{};
|
| 229 |
+
auto hello = py::str("Hello, World!");
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
.. warning::
|
| 233 |
+
|
| 234 |
+
The interpreter can be restarted by calling `initialize_interpreter` again.
|
| 235 |
+
Modules created using pybind11 can be safely re-initialized. However, Python
|
| 236 |
+
itself cannot completely unload binary extension modules and there are several
|
| 237 |
+
caveats with regard to interpreter restarting. All the details can be found
|
| 238 |
+
in the CPython documentation. In short, not all interpreter memory may be
|
| 239 |
+
freed, either due to reference cycles or user-created global data.
|
| 240 |
+
|
| 241 |
+
\endrst */
|
| 242 |
+
inline void finalize_interpreter() {
|
| 243 |
+
// Get the internals pointer (without creating it if it doesn't exist). It's possible for the
|
| 244 |
+
// internals to be created during Py_Finalize() (e.g. if a py::capsule calls `get_internals()`
|
| 245 |
+
// during destruction), so we get the pointer-pointer here and check it after Py_Finalize().
|
| 246 |
+
detail::internals **internals_ptr_ptr = detail::get_internals_pp();
|
| 247 |
+
// It could also be stashed in state_dict, so look there too:
|
| 248 |
+
if (object internals_obj
|
| 249 |
+
= get_internals_obj_from_state_dict(detail::get_python_state_dict())) {
|
| 250 |
+
internals_ptr_ptr = detail::get_internals_pp_from_capsule(internals_obj);
|
| 251 |
+
}
|
| 252 |
+
// Local internals contains data managed by the current interpreter, so we must clear them to
|
| 253 |
+
// avoid undefined behaviors when initializing another interpreter
|
| 254 |
+
detail::get_local_internals().registered_types_cpp.clear();
|
| 255 |
+
detail::get_local_internals().registered_exception_translators.clear();
|
| 256 |
+
|
| 257 |
+
Py_Finalize();
|
| 258 |
+
|
| 259 |
+
if (internals_ptr_ptr) {
|
| 260 |
+
delete *internals_ptr_ptr;
|
| 261 |
+
*internals_ptr_ptr = nullptr;
|
| 262 |
+
}
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
/** \rst
|
| 266 |
+
Scope guard version of `initialize_interpreter` and `finalize_interpreter`.
|
| 267 |
+
This a move-only guard and only a single instance can exist.
|
| 268 |
+
|
| 269 |
+
See `initialize_interpreter` for a discussion of its constructor arguments.
|
| 270 |
+
|
| 271 |
+
.. code-block:: cpp
|
| 272 |
+
|
| 273 |
+
#include <pybind11/embed.h>
|
| 274 |
+
|
| 275 |
+
int main() {
|
| 276 |
+
py::scoped_interpreter guard{};
|
| 277 |
+
py::print(Hello, World!);
|
| 278 |
+
} // <-- interpreter shutdown
|
| 279 |
+
\endrst */
|
| 280 |
+
class scoped_interpreter {
|
| 281 |
+
public:
|
| 282 |
+
explicit scoped_interpreter(bool init_signal_handlers = true,
|
| 283 |
+
int argc = 0,
|
| 284 |
+
const char *const *argv = nullptr,
|
| 285 |
+
bool add_program_dir_to_path = true) {
|
| 286 |
+
initialize_interpreter(init_signal_handlers, argc, argv, add_program_dir_to_path);
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
#if PY_VERSION_HEX >= PYBIND11_PYCONFIG_SUPPORT_PY_VERSION_HEX
|
| 290 |
+
explicit scoped_interpreter(PyConfig *config,
|
| 291 |
+
int argc = 0,
|
| 292 |
+
const char *const *argv = nullptr,
|
| 293 |
+
bool add_program_dir_to_path = true) {
|
| 294 |
+
initialize_interpreter(config, argc, argv, add_program_dir_to_path);
|
| 295 |
+
}
|
| 296 |
+
#endif
|
| 297 |
+
|
| 298 |
+
scoped_interpreter(const scoped_interpreter &) = delete;
|
| 299 |
+
scoped_interpreter(scoped_interpreter &&other) noexcept { other.is_valid = false; }
|
| 300 |
+
scoped_interpreter &operator=(const scoped_interpreter &) = delete;
|
| 301 |
+
scoped_interpreter &operator=(scoped_interpreter &&) = delete;
|
| 302 |
+
|
| 303 |
+
~scoped_interpreter() {
|
| 304 |
+
if (is_valid) {
|
| 305 |
+
finalize_interpreter();
|
| 306 |
+
}
|
| 307 |
+
}
|
| 308 |
+
|
| 309 |
+
private:
|
| 310 |
+
bool is_valid = true;
|
| 311 |
+
};
|
| 312 |
+
|
| 313 |
+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
|
phivenv/Lib/site-packages/torch/include/pybind11/eval.h
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
pybind11/eval.h: Support for evaluating Python expressions and statements
|
| 3 |
+
from strings and files
|
| 4 |
+
|
| 5 |
+
Copyright (c) 2016 Klemens Morgenstern <klemens.morgenstern@ed-chemnitz.de> and
|
| 6 |
+
Wenzel Jakob <wenzel.jakob@epfl.ch>
|
| 7 |
+
|
| 8 |
+
All rights reserved. Use of this source code is governed by a
|
| 9 |
+
BSD-style license that can be found in the LICENSE file.
|
| 10 |
+
*/
|
| 11 |
+
|
| 12 |
+
#pragma once
|
| 13 |
+
|
| 14 |
+
#include "pybind11.h"
|
| 15 |
+
|
| 16 |
+
#include <utility>
|
| 17 |
+
|
| 18 |
+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
| 19 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 20 |
+
|
| 21 |
+
inline void ensure_builtins_in_globals(object &global) {
|
| 22 |
+
#if defined(PYPY_VERSION) || PY_VERSION_HEX < 0x03080000
|
| 23 |
+
// Running exec and eval adds `builtins` module under `__builtins__` key to
|
| 24 |
+
// globals if not yet present. Python 3.8 made PyRun_String behave
|
| 25 |
+
// similarly. Let's also do that for older versions, for consistency. This
|
| 26 |
+
// was missing from PyPy3.8 7.3.7.
|
| 27 |
+
if (!global.contains("__builtins__"))
|
| 28 |
+
global["__builtins__"] = module_::import(PYBIND11_BUILTINS_MODULE);
|
| 29 |
+
#else
|
| 30 |
+
(void) global;
|
| 31 |
+
#endif
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 35 |
+
|
| 36 |
+
enum eval_mode {
|
| 37 |
+
/// Evaluate a string containing an isolated expression
|
| 38 |
+
eval_expr,
|
| 39 |
+
|
| 40 |
+
/// Evaluate a string containing a single statement. Returns \c none
|
| 41 |
+
eval_single_statement,
|
| 42 |
+
|
| 43 |
+
/// Evaluate a string containing a sequence of statement. Returns \c none
|
| 44 |
+
eval_statements
|
| 45 |
+
};
|
| 46 |
+
|
| 47 |
+
template <eval_mode mode = eval_expr>
|
| 48 |
+
object eval(const str &expr, object global = globals(), object local = object()) {
|
| 49 |
+
if (!local) {
|
| 50 |
+
local = global;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
detail::ensure_builtins_in_globals(global);
|
| 54 |
+
|
| 55 |
+
/* PyRun_String does not accept a PyObject / encoding specifier,
|
| 56 |
+
this seems to be the only alternative */
|
| 57 |
+
std::string buffer = "# -*- coding: utf-8 -*-\n" + (std::string) expr;
|
| 58 |
+
|
| 59 |
+
int start = 0;
|
| 60 |
+
switch (mode) {
|
| 61 |
+
case eval_expr:
|
| 62 |
+
start = Py_eval_input;
|
| 63 |
+
break;
|
| 64 |
+
case eval_single_statement:
|
| 65 |
+
start = Py_single_input;
|
| 66 |
+
break;
|
| 67 |
+
case eval_statements:
|
| 68 |
+
start = Py_file_input;
|
| 69 |
+
break;
|
| 70 |
+
default:
|
| 71 |
+
pybind11_fail("invalid evaluation mode");
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
PyObject *result = PyRun_String(buffer.c_str(), start, global.ptr(), local.ptr());
|
| 75 |
+
if (!result) {
|
| 76 |
+
throw error_already_set();
|
| 77 |
+
}
|
| 78 |
+
return reinterpret_steal<object>(result);
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
template <eval_mode mode = eval_expr, size_t N>
|
| 82 |
+
object eval(const char (&s)[N], object global = globals(), object local = object()) {
|
| 83 |
+
/* Support raw string literals by removing common leading whitespace */
|
| 84 |
+
auto expr = (s[0] == '\n') ? str(module_::import("textwrap").attr("dedent")(s)) : str(s);
|
| 85 |
+
return eval<mode>(expr, std::move(global), std::move(local));
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
inline void exec(const str &expr, object global = globals(), object local = object()) {
|
| 89 |
+
eval<eval_statements>(expr, std::move(global), std::move(local));
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
template <size_t N>
|
| 93 |
+
void exec(const char (&s)[N], object global = globals(), object local = object()) {
|
| 94 |
+
eval<eval_statements>(s, std::move(global), std::move(local));
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
#if defined(PYPY_VERSION)
|
| 98 |
+
template <eval_mode mode = eval_statements>
|
| 99 |
+
object eval_file(str, object, object) {
|
| 100 |
+
pybind11_fail("eval_file not supported in PyPy3. Use eval");
|
| 101 |
+
}
|
| 102 |
+
template <eval_mode mode = eval_statements>
|
| 103 |
+
object eval_file(str, object) {
|
| 104 |
+
pybind11_fail("eval_file not supported in PyPy3. Use eval");
|
| 105 |
+
}
|
| 106 |
+
template <eval_mode mode = eval_statements>
|
| 107 |
+
object eval_file(str) {
|
| 108 |
+
pybind11_fail("eval_file not supported in PyPy3. Use eval");
|
| 109 |
+
}
|
| 110 |
+
#else
|
| 111 |
+
template <eval_mode mode = eval_statements>
|
| 112 |
+
object eval_file(str fname, object global = globals(), object local = object()) {
|
| 113 |
+
if (!local) {
|
| 114 |
+
local = global;
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
detail::ensure_builtins_in_globals(global);
|
| 118 |
+
|
| 119 |
+
int start = 0;
|
| 120 |
+
switch (mode) {
|
| 121 |
+
case eval_expr:
|
| 122 |
+
start = Py_eval_input;
|
| 123 |
+
break;
|
| 124 |
+
case eval_single_statement:
|
| 125 |
+
start = Py_single_input;
|
| 126 |
+
break;
|
| 127 |
+
case eval_statements:
|
| 128 |
+
start = Py_file_input;
|
| 129 |
+
break;
|
| 130 |
+
default:
|
| 131 |
+
pybind11_fail("invalid evaluation mode");
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
int closeFile = 1;
|
| 135 |
+
std::string fname_str = (std::string) fname;
|
| 136 |
+
FILE *f = _Py_fopen_obj(fname.ptr(), "r");
|
| 137 |
+
if (!f) {
|
| 138 |
+
PyErr_Clear();
|
| 139 |
+
pybind11_fail("File \"" + fname_str + "\" could not be opened!");
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
if (!global.contains("__file__")) {
|
| 143 |
+
global["__file__"] = std::move(fname);
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
PyObject *result
|
| 147 |
+
= PyRun_FileEx(f, fname_str.c_str(), start, global.ptr(), local.ptr(), closeFile);
|
| 148 |
+
|
| 149 |
+
if (!result) {
|
| 150 |
+
throw error_already_set();
|
| 151 |
+
}
|
| 152 |
+
return reinterpret_steal<object>(result);
|
| 153 |
+
}
|
| 154 |
+
#endif
|
| 155 |
+
|
| 156 |
+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
|
phivenv/Lib/site-packages/torch/include/pybind11/functional.h
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
pybind11/functional.h: std::function<> support
|
| 3 |
+
|
| 4 |
+
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
| 5 |
+
|
| 6 |
+
All rights reserved. Use of this source code is governed by a
|
| 7 |
+
BSD-style license that can be found in the LICENSE file.
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#define PYBIND11_HAS_TYPE_CASTER_STD_FUNCTION_SPECIALIZATIONS
|
| 13 |
+
|
| 14 |
+
#include "pybind11.h"
|
| 15 |
+
|
| 16 |
+
#include <functional>
|
| 17 |
+
|
| 18 |
+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
| 19 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 20 |
+
PYBIND11_NAMESPACE_BEGIN(type_caster_std_function_specializations)
|
| 21 |
+
|
| 22 |
+
// ensure GIL is held during functor destruction
|
| 23 |
+
struct func_handle {
|
| 24 |
+
function f;
|
| 25 |
+
#if !(defined(_MSC_VER) && _MSC_VER == 1916 && defined(PYBIND11_CPP17))
|
| 26 |
+
// This triggers a syntax error under very special conditions (very weird indeed).
|
| 27 |
+
explicit
|
| 28 |
+
#endif
|
| 29 |
+
func_handle(function &&f_) noexcept
|
| 30 |
+
: f(std::move(f_)) {
|
| 31 |
+
}
|
| 32 |
+
func_handle(const func_handle &f_) { operator=(f_); }
|
| 33 |
+
func_handle &operator=(const func_handle &f_) {
|
| 34 |
+
gil_scoped_acquire acq;
|
| 35 |
+
f = f_.f;
|
| 36 |
+
return *this;
|
| 37 |
+
}
|
| 38 |
+
~func_handle() {
|
| 39 |
+
gil_scoped_acquire acq;
|
| 40 |
+
function kill_f(std::move(f));
|
| 41 |
+
}
|
| 42 |
+
};
|
| 43 |
+
|
| 44 |
+
// to emulate 'move initialization capture' in C++11
|
| 45 |
+
struct func_wrapper_base {
|
| 46 |
+
func_handle hfunc;
|
| 47 |
+
explicit func_wrapper_base(func_handle &&hf) noexcept : hfunc(hf) {}
|
| 48 |
+
};
|
| 49 |
+
|
| 50 |
+
template <typename Return, typename... Args>
|
| 51 |
+
struct func_wrapper : func_wrapper_base {
|
| 52 |
+
using func_wrapper_base::func_wrapper_base;
|
| 53 |
+
Return operator()(Args... args) const {
|
| 54 |
+
gil_scoped_acquire acq;
|
| 55 |
+
// casts the returned object as a rvalue to the return type
|
| 56 |
+
return hfunc.f(std::forward<Args>(args)...).template cast<Return>();
|
| 57 |
+
}
|
| 58 |
+
};
|
| 59 |
+
|
| 60 |
+
PYBIND11_NAMESPACE_END(type_caster_std_function_specializations)
|
| 61 |
+
|
| 62 |
+
template <typename Return, typename... Args>
|
| 63 |
+
struct type_caster<std::function<Return(Args...)>> {
|
| 64 |
+
using type = std::function<Return(Args...)>;
|
| 65 |
+
using retval_type = conditional_t<std::is_same<Return, void>::value, void_type, Return>;
|
| 66 |
+
using function_type = Return (*)(Args...);
|
| 67 |
+
|
| 68 |
+
public:
|
| 69 |
+
bool load(handle src, bool convert) {
|
| 70 |
+
if (src.is_none()) {
|
| 71 |
+
// Defer accepting None to other overloads (if we aren't in convert mode):
|
| 72 |
+
if (!convert) {
|
| 73 |
+
return false;
|
| 74 |
+
}
|
| 75 |
+
return true;
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
if (!isinstance<function>(src)) {
|
| 79 |
+
return false;
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
auto func = reinterpret_borrow<function>(src);
|
| 83 |
+
|
| 84 |
+
/*
|
| 85 |
+
When passing a C++ function as an argument to another C++
|
| 86 |
+
function via Python, every function call would normally involve
|
| 87 |
+
a full C++ -> Python -> C++ roundtrip, which can be prohibitive.
|
| 88 |
+
Here, we try to at least detect the case where the function is
|
| 89 |
+
stateless (i.e. function pointer or lambda function without
|
| 90 |
+
captured variables), in which case the roundtrip can be avoided.
|
| 91 |
+
*/
|
| 92 |
+
if (auto cfunc = func.cpp_function()) {
|
| 93 |
+
auto *cfunc_self = PyCFunction_GET_SELF(cfunc.ptr());
|
| 94 |
+
if (cfunc_self == nullptr) {
|
| 95 |
+
PyErr_Clear();
|
| 96 |
+
} else if (isinstance<capsule>(cfunc_self)) {
|
| 97 |
+
auto c = reinterpret_borrow<capsule>(cfunc_self);
|
| 98 |
+
|
| 99 |
+
function_record *rec = nullptr;
|
| 100 |
+
// Check that we can safely reinterpret the capsule into a function_record
|
| 101 |
+
if (detail::is_function_record_capsule(c)) {
|
| 102 |
+
rec = c.get_pointer<function_record>();
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
while (rec != nullptr) {
|
| 106 |
+
if (rec->is_stateless
|
| 107 |
+
&& same_type(typeid(function_type),
|
| 108 |
+
*reinterpret_cast<const std::type_info *>(rec->data[1]))) {
|
| 109 |
+
struct capture {
|
| 110 |
+
function_type f;
|
| 111 |
+
};
|
| 112 |
+
value = ((capture *) &rec->data)->f;
|
| 113 |
+
return true;
|
| 114 |
+
}
|
| 115 |
+
rec = rec->next;
|
| 116 |
+
}
|
| 117 |
+
}
|
| 118 |
+
// PYPY segfaults here when passing builtin function like sum.
|
| 119 |
+
// Raising an fail exception here works to prevent the segfault, but only on gcc.
|
| 120 |
+
// See PR #1413 for full details
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
value = type_caster_std_function_specializations::func_wrapper<Return, Args...>(
|
| 124 |
+
type_caster_std_function_specializations::func_handle(std::move(func)));
|
| 125 |
+
return true;
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
template <typename Func>
|
| 129 |
+
static handle cast(Func &&f_, return_value_policy policy, handle /* parent */) {
|
| 130 |
+
if (!f_) {
|
| 131 |
+
return none().release();
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
auto result = f_.template target<function_type>();
|
| 135 |
+
if (result) {
|
| 136 |
+
return cpp_function(*result, policy).release();
|
| 137 |
+
}
|
| 138 |
+
return cpp_function(std::forward<Func>(f_), policy).release();
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
PYBIND11_TYPE_CASTER(type,
|
| 142 |
+
const_name("Callable[[")
|
| 143 |
+
+ ::pybind11::detail::concat(make_caster<Args>::name...)
|
| 144 |
+
+ const_name("], ") + make_caster<retval_type>::name
|
| 145 |
+
+ const_name("]"));
|
| 146 |
+
};
|
| 147 |
+
|
| 148 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 149 |
+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
|
phivenv/Lib/site-packages/torch/include/pybind11/gil.h
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
pybind11/gil.h: RAII helpers for managing the GIL
|
| 3 |
+
|
| 4 |
+
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
| 5 |
+
|
| 6 |
+
All rights reserved. Use of this source code is governed by a
|
| 7 |
+
BSD-style license that can be found in the LICENSE file.
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include "detail/common.h"
|
| 13 |
+
|
| 14 |
+
#include <cassert>
|
| 15 |
+
|
| 16 |
+
#if !defined(PYBIND11_SIMPLE_GIL_MANAGEMENT)
|
| 17 |
+
# include "detail/internals.h"
|
| 18 |
+
#endif
|
| 19 |
+
|
| 20 |
+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
| 21 |
+
|
| 22 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 23 |
+
|
| 24 |
+
// forward declarations
|
| 25 |
+
PyThreadState *get_thread_state_unchecked();
|
| 26 |
+
|
| 27 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 28 |
+
|
| 29 |
+
#if !defined(PYBIND11_SIMPLE_GIL_MANAGEMENT)
|
| 30 |
+
|
| 31 |
+
/* The functions below essentially reproduce the PyGILState_* API using a RAII
|
| 32 |
+
* pattern, but there are a few important differences:
|
| 33 |
+
*
|
| 34 |
+
* 1. When acquiring the GIL from an non-main thread during the finalization
|
| 35 |
+
* phase, the GILState API blindly terminates the calling thread, which
|
| 36 |
+
* is often not what is wanted. This API does not do this.
|
| 37 |
+
*
|
| 38 |
+
* 2. The gil_scoped_release function can optionally cut the relationship
|
| 39 |
+
* of a PyThreadState and its associated thread, which allows moving it to
|
| 40 |
+
* another thread (this is a fairly rare/advanced use case).
|
| 41 |
+
*
|
| 42 |
+
* 3. The reference count of an acquired thread state can be controlled. This
|
| 43 |
+
* can be handy to prevent cases where callbacks issued from an external
|
| 44 |
+
* thread would otherwise constantly construct and destroy thread state data
|
| 45 |
+
* structures.
|
| 46 |
+
*
|
| 47 |
+
* See the Python bindings of NanoGUI (http://github.com/wjakob/nanogui) for an
|
| 48 |
+
* example which uses features 2 and 3 to migrate the Python thread of
|
| 49 |
+
* execution to another thread (to run the event loop on the original thread,
|
| 50 |
+
* in this case).
|
| 51 |
+
*/
|
| 52 |
+
|
| 53 |
+
class gil_scoped_acquire {
|
| 54 |
+
public:
|
| 55 |
+
PYBIND11_NOINLINE gil_scoped_acquire() {
|
| 56 |
+
auto &internals = detail::get_internals();
|
| 57 |
+
tstate = (PyThreadState *) PYBIND11_TLS_GET_VALUE(internals.tstate);
|
| 58 |
+
|
| 59 |
+
if (!tstate) {
|
| 60 |
+
/* Check if the GIL was acquired using the PyGILState_* API instead (e.g. if
|
| 61 |
+
calling from a Python thread). Since we use a different key, this ensures
|
| 62 |
+
we don't create a new thread state and deadlock in PyEval_AcquireThread
|
| 63 |
+
below. Note we don't save this state with internals.tstate, since we don't
|
| 64 |
+
create it we would fail to clear it (its reference count should be > 0). */
|
| 65 |
+
tstate = PyGILState_GetThisThreadState();
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
if (!tstate) {
|
| 69 |
+
tstate = PyThreadState_New(internals.istate);
|
| 70 |
+
# if defined(PYBIND11_DETAILED_ERROR_MESSAGES)
|
| 71 |
+
if (!tstate) {
|
| 72 |
+
pybind11_fail("scoped_acquire: could not create thread state!");
|
| 73 |
+
}
|
| 74 |
+
# endif
|
| 75 |
+
tstate->gilstate_counter = 0;
|
| 76 |
+
PYBIND11_TLS_REPLACE_VALUE(internals.tstate, tstate);
|
| 77 |
+
} else {
|
| 78 |
+
release = detail::get_thread_state_unchecked() != tstate;
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
if (release) {
|
| 82 |
+
PyEval_AcquireThread(tstate);
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
inc_ref();
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
gil_scoped_acquire(const gil_scoped_acquire &) = delete;
|
| 89 |
+
gil_scoped_acquire &operator=(const gil_scoped_acquire &) = delete;
|
| 90 |
+
|
| 91 |
+
void inc_ref() { ++tstate->gilstate_counter; }
|
| 92 |
+
|
| 93 |
+
PYBIND11_NOINLINE void dec_ref() {
|
| 94 |
+
--tstate->gilstate_counter;
|
| 95 |
+
# if defined(PYBIND11_DETAILED_ERROR_MESSAGES)
|
| 96 |
+
if (detail::get_thread_state_unchecked() != tstate) {
|
| 97 |
+
pybind11_fail("scoped_acquire::dec_ref(): thread state must be current!");
|
| 98 |
+
}
|
| 99 |
+
if (tstate->gilstate_counter < 0) {
|
| 100 |
+
pybind11_fail("scoped_acquire::dec_ref(): reference count underflow!");
|
| 101 |
+
}
|
| 102 |
+
# endif
|
| 103 |
+
if (tstate->gilstate_counter == 0) {
|
| 104 |
+
# if defined(PYBIND11_DETAILED_ERROR_MESSAGES)
|
| 105 |
+
if (!release) {
|
| 106 |
+
pybind11_fail("scoped_acquire::dec_ref(): internal error!");
|
| 107 |
+
}
|
| 108 |
+
# endif
|
| 109 |
+
PyThreadState_Clear(tstate);
|
| 110 |
+
if (active) {
|
| 111 |
+
PyThreadState_DeleteCurrent();
|
| 112 |
+
}
|
| 113 |
+
PYBIND11_TLS_DELETE_VALUE(detail::get_internals().tstate);
|
| 114 |
+
release = false;
|
| 115 |
+
}
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
/// This method will disable the PyThreadState_DeleteCurrent call and the
|
| 119 |
+
/// GIL won't be acquired. This method should be used if the interpreter
|
| 120 |
+
/// could be shutting down when this is called, as thread deletion is not
|
| 121 |
+
/// allowed during shutdown. Check _Py_IsFinalizing() on Python 3.7+, and
|
| 122 |
+
/// protect subsequent code.
|
| 123 |
+
PYBIND11_NOINLINE void disarm() { active = false; }
|
| 124 |
+
|
| 125 |
+
PYBIND11_NOINLINE ~gil_scoped_acquire() {
|
| 126 |
+
dec_ref();
|
| 127 |
+
if (release) {
|
| 128 |
+
PyEval_SaveThread();
|
| 129 |
+
}
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
private:
|
| 133 |
+
PyThreadState *tstate = nullptr;
|
| 134 |
+
bool release = true;
|
| 135 |
+
bool active = true;
|
| 136 |
+
};
|
| 137 |
+
|
| 138 |
+
class gil_scoped_release {
|
| 139 |
+
public:
|
| 140 |
+
// PRECONDITION: The GIL must be held when this constructor is called.
|
| 141 |
+
explicit gil_scoped_release(bool disassoc = false) : disassoc(disassoc) {
|
| 142 |
+
assert(PyGILState_Check());
|
| 143 |
+
// `get_internals()` must be called here unconditionally in order to initialize
|
| 144 |
+
// `internals.tstate` for subsequent `gil_scoped_acquire` calls. Otherwise, an
|
| 145 |
+
// initialization race could occur as multiple threads try `gil_scoped_acquire`.
|
| 146 |
+
auto &internals = detail::get_internals();
|
| 147 |
+
// NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
|
| 148 |
+
tstate = PyEval_SaveThread();
|
| 149 |
+
if (disassoc) {
|
| 150 |
+
// Python >= 3.7 can remove this, it's an int before 3.7
|
| 151 |
+
// NOLINTNEXTLINE(readability-qualified-auto)
|
| 152 |
+
auto key = internals.tstate;
|
| 153 |
+
PYBIND11_TLS_DELETE_VALUE(key);
|
| 154 |
+
}
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
gil_scoped_release(const gil_scoped_release &) = delete;
|
| 158 |
+
gil_scoped_release &operator=(const gil_scoped_release &) = delete;
|
| 159 |
+
|
| 160 |
+
/// This method will disable the PyThreadState_DeleteCurrent call and the
|
| 161 |
+
/// GIL won't be acquired. This method should be used if the interpreter
|
| 162 |
+
/// could be shutting down when this is called, as thread deletion is not
|
| 163 |
+
/// allowed during shutdown. Check _Py_IsFinalizing() on Python 3.7+, and
|
| 164 |
+
/// protect subsequent code.
|
| 165 |
+
PYBIND11_NOINLINE void disarm() { active = false; }
|
| 166 |
+
|
| 167 |
+
~gil_scoped_release() {
|
| 168 |
+
if (!tstate) {
|
| 169 |
+
return;
|
| 170 |
+
}
|
| 171 |
+
// `PyEval_RestoreThread()` should not be called if runtime is finalizing
|
| 172 |
+
if (active) {
|
| 173 |
+
PyEval_RestoreThread(tstate);
|
| 174 |
+
}
|
| 175 |
+
if (disassoc) {
|
| 176 |
+
// Python >= 3.7 can remove this, it's an int before 3.7
|
| 177 |
+
// NOLINTNEXTLINE(readability-qualified-auto)
|
| 178 |
+
auto key = detail::get_internals().tstate;
|
| 179 |
+
PYBIND11_TLS_REPLACE_VALUE(key, tstate);
|
| 180 |
+
}
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
private:
|
| 184 |
+
PyThreadState *tstate;
|
| 185 |
+
bool disassoc;
|
| 186 |
+
bool active = true;
|
| 187 |
+
};
|
| 188 |
+
|
| 189 |
+
#else // PYBIND11_SIMPLE_GIL_MANAGEMENT
|
| 190 |
+
|
| 191 |
+
class gil_scoped_acquire {
|
| 192 |
+
PyGILState_STATE state;
|
| 193 |
+
|
| 194 |
+
public:
|
| 195 |
+
gil_scoped_acquire() : state{PyGILState_Ensure()} {}
|
| 196 |
+
gil_scoped_acquire(const gil_scoped_acquire &) = delete;
|
| 197 |
+
gil_scoped_acquire &operator=(const gil_scoped_acquire &) = delete;
|
| 198 |
+
~gil_scoped_acquire() { PyGILState_Release(state); }
|
| 199 |
+
void disarm() {}
|
| 200 |
+
};
|
| 201 |
+
|
| 202 |
+
class gil_scoped_release {
|
| 203 |
+
PyThreadState *state;
|
| 204 |
+
|
| 205 |
+
public:
|
| 206 |
+
// PRECONDITION: The GIL must be held when this constructor is called.
|
| 207 |
+
gil_scoped_release() {
|
| 208 |
+
assert(PyGILState_Check());
|
| 209 |
+
state = PyEval_SaveThread();
|
| 210 |
+
}
|
| 211 |
+
gil_scoped_release(const gil_scoped_release &) = delete;
|
| 212 |
+
gil_scoped_release &operator=(const gil_scoped_release &) = delete;
|
| 213 |
+
~gil_scoped_release() { PyEval_RestoreThread(state); }
|
| 214 |
+
void disarm() {}
|
| 215 |
+
};
|
| 216 |
+
|
| 217 |
+
#endif // PYBIND11_SIMPLE_GIL_MANAGEMENT
|
| 218 |
+
|
| 219 |
+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
|
phivenv/Lib/site-packages/torch/include/pybind11/gil_safe_call_once.h
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) 2023 The pybind Community.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include "detail/common.h"
|
| 6 |
+
#include "gil.h"
|
| 7 |
+
|
| 8 |
+
#include <cassert>
|
| 9 |
+
#include <mutex>
|
| 10 |
+
|
| 11 |
+
#ifdef Py_GIL_DISABLED
|
| 12 |
+
# include <atomic>
|
| 13 |
+
#endif
|
| 14 |
+
|
| 15 |
+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
| 16 |
+
|
| 17 |
+
// Use the `gil_safe_call_once_and_store` class below instead of the naive
|
| 18 |
+
//
|
| 19 |
+
// static auto imported_obj = py::module_::import("module_name"); // BAD, DO NOT USE!
|
| 20 |
+
//
|
| 21 |
+
// which has two serious issues:
|
| 22 |
+
//
|
| 23 |
+
// 1. Py_DECREF() calls potentially after the Python interpreter was finalized already, and
|
| 24 |
+
// 2. deadlocks in multi-threaded processes (because of missing lock ordering).
|
| 25 |
+
//
|
| 26 |
+
// The following alternative avoids both problems:
|
| 27 |
+
//
|
| 28 |
+
// PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object> storage;
|
| 29 |
+
// auto &imported_obj = storage // Do NOT make this `static`!
|
| 30 |
+
// .call_once_and_store_result([]() {
|
| 31 |
+
// return py::module_::import("module_name");
|
| 32 |
+
// })
|
| 33 |
+
// .get_stored();
|
| 34 |
+
//
|
| 35 |
+
// The parameter of `call_once_and_store_result()` must be callable. It can make
|
| 36 |
+
// CPython API calls, and in particular, it can temporarily release the GIL.
|
| 37 |
+
//
|
| 38 |
+
// `T` can be any C++ type, it does not have to involve CPython API types.
|
| 39 |
+
//
|
| 40 |
+
// The behavior with regard to signals, e.g. `SIGINT` (`KeyboardInterrupt`),
|
| 41 |
+
// is not ideal. If the main thread is the one to actually run the `Callable`,
|
| 42 |
+
// then a `KeyboardInterrupt` will interrupt it if it is running normal Python
|
| 43 |
+
// code. The situation is different if a non-main thread runs the
|
| 44 |
+
// `Callable`, and then the main thread starts waiting for it to complete:
|
| 45 |
+
// a `KeyboardInterrupt` will not interrupt the non-main thread, but it will
|
| 46 |
+
// get processed only when it is the main thread's turn again and it is running
|
| 47 |
+
// normal Python code. However, this will be unnoticeable for quick call-once
|
| 48 |
+
// functions, which is usually the case.
|
| 49 |
+
template <typename T>
|
| 50 |
+
class gil_safe_call_once_and_store {
|
| 51 |
+
public:
|
| 52 |
+
// PRECONDITION: The GIL must be held when `call_once_and_store_result()` is called.
|
| 53 |
+
template <typename Callable>
|
| 54 |
+
gil_safe_call_once_and_store &call_once_and_store_result(Callable &&fn) {
|
| 55 |
+
if (!is_initialized_) { // This read is guarded by the GIL.
|
| 56 |
+
// Multiple threads may enter here, because the GIL is released in the next line and
|
| 57 |
+
// CPython API calls in the `fn()` call below may release and reacquire the GIL.
|
| 58 |
+
gil_scoped_release gil_rel; // Needed to establish lock ordering.
|
| 59 |
+
std::call_once(once_flag_, [&] {
|
| 60 |
+
// Only one thread will ever enter here.
|
| 61 |
+
gil_scoped_acquire gil_acq;
|
| 62 |
+
::new (storage_) T(fn()); // fn may release, but will reacquire, the GIL.
|
| 63 |
+
is_initialized_ = true; // This write is guarded by the GIL.
|
| 64 |
+
});
|
| 65 |
+
// All threads will observe `is_initialized_` as true here.
|
| 66 |
+
}
|
| 67 |
+
// Intentionally not returning `T &` to ensure the calling code is self-documenting.
|
| 68 |
+
return *this;
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
// This must only be called after `call_once_and_store_result()` was called.
|
| 72 |
+
T &get_stored() {
|
| 73 |
+
assert(is_initialized_);
|
| 74 |
+
PYBIND11_WARNING_PUSH
|
| 75 |
+
#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ < 5
|
| 76 |
+
// Needed for gcc 4.8.5
|
| 77 |
+
PYBIND11_WARNING_DISABLE_GCC("-Wstrict-aliasing")
|
| 78 |
+
#endif
|
| 79 |
+
return *reinterpret_cast<T *>(storage_);
|
| 80 |
+
PYBIND11_WARNING_POP
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
constexpr gil_safe_call_once_and_store() = default;
|
| 84 |
+
PYBIND11_DTOR_CONSTEXPR ~gil_safe_call_once_and_store() = default;
|
| 85 |
+
|
| 86 |
+
private:
|
| 87 |
+
alignas(T) char storage_[sizeof(T)] = {};
|
| 88 |
+
std::once_flag once_flag_ = {};
|
| 89 |
+
#ifdef Py_GIL_DISABLED
|
| 90 |
+
std::atomic_bool
|
| 91 |
+
#else
|
| 92 |
+
bool
|
| 93 |
+
#endif
|
| 94 |
+
is_initialized_{false};
|
| 95 |
+
// The `is_initialized_`-`storage_` pair is very similar to `std::optional`,
|
| 96 |
+
// but the latter does not have the triviality properties of former,
|
| 97 |
+
// therefore `std::optional` is not a viable alternative here.
|
| 98 |
+
};
|
| 99 |
+
|
| 100 |
+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
|
phivenv/Lib/site-packages/torch/include/pybind11/iostream.h
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
pybind11/iostream.h -- Tools to assist with redirecting cout and cerr to Python
|
| 3 |
+
|
| 4 |
+
Copyright (c) 2017 Henry F. Schreiner
|
| 5 |
+
|
| 6 |
+
All rights reserved. Use of this source code is governed by a
|
| 7 |
+
BSD-style license that can be found in the LICENSE file.
|
| 8 |
+
|
| 9 |
+
WARNING: The implementation in this file is NOT thread safe. Multiple
|
| 10 |
+
threads writing to a redirected ostream concurrently cause data races
|
| 11 |
+
and potentially buffer overflows. Therefore it is currently a requirement
|
| 12 |
+
that all (possibly) concurrent redirected ostream writes are protected by
|
| 13 |
+
a mutex.
|
| 14 |
+
#HelpAppreciated: Work on iostream.h thread safety.
|
| 15 |
+
For more background see the discussions under
|
| 16 |
+
https://github.com/pybind/pybind11/pull/2982 and
|
| 17 |
+
https://github.com/pybind/pybind11/pull/2995.
|
| 18 |
+
*/
|
| 19 |
+
|
| 20 |
+
#pragma once
|
| 21 |
+
|
| 22 |
+
#include "pybind11.h"
|
| 23 |
+
|
| 24 |
+
#include <algorithm>
|
| 25 |
+
#include <cstring>
|
| 26 |
+
#include <iostream>
|
| 27 |
+
#include <iterator>
|
| 28 |
+
#include <memory>
|
| 29 |
+
#include <ostream>
|
| 30 |
+
#include <streambuf>
|
| 31 |
+
#include <string>
|
| 32 |
+
#include <utility>
|
| 33 |
+
|
| 34 |
+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
| 35 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 36 |
+
|
| 37 |
+
// Buffer that writes to Python instead of C++
|
| 38 |
+
class pythonbuf : public std::streambuf {
|
| 39 |
+
private:
|
| 40 |
+
using traits_type = std::streambuf::traits_type;
|
| 41 |
+
|
| 42 |
+
const size_t buf_size;
|
| 43 |
+
std::unique_ptr<char[]> d_buffer;
|
| 44 |
+
object pywrite;
|
| 45 |
+
object pyflush;
|
| 46 |
+
|
| 47 |
+
int overflow(int c) override {
|
| 48 |
+
if (!traits_type::eq_int_type(c, traits_type::eof())) {
|
| 49 |
+
*pptr() = traits_type::to_char_type(c);
|
| 50 |
+
pbump(1);
|
| 51 |
+
}
|
| 52 |
+
return sync() == 0 ? traits_type::not_eof(c) : traits_type::eof();
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
// Computes how many bytes at the end of the buffer are part of an
|
| 56 |
+
// incomplete sequence of UTF-8 bytes.
|
| 57 |
+
// Precondition: pbase() < pptr()
|
| 58 |
+
size_t utf8_remainder() const {
|
| 59 |
+
const auto rbase = std::reverse_iterator<char *>(pbase());
|
| 60 |
+
const auto rpptr = std::reverse_iterator<char *>(pptr());
|
| 61 |
+
auto is_ascii = [](char c) { return (static_cast<unsigned char>(c) & 0x80) == 0x00; };
|
| 62 |
+
auto is_leading = [](char c) { return (static_cast<unsigned char>(c) & 0xC0) == 0xC0; };
|
| 63 |
+
auto is_leading_2b = [](char c) { return static_cast<unsigned char>(c) <= 0xDF; };
|
| 64 |
+
auto is_leading_3b = [](char c) { return static_cast<unsigned char>(c) <= 0xEF; };
|
| 65 |
+
// If the last character is ASCII, there are no incomplete code points
|
| 66 |
+
if (is_ascii(*rpptr)) {
|
| 67 |
+
return 0;
|
| 68 |
+
}
|
| 69 |
+
// Otherwise, work back from the end of the buffer and find the first
|
| 70 |
+
// UTF-8 leading byte
|
| 71 |
+
const auto rpend = rbase - rpptr >= 3 ? rpptr + 3 : rbase;
|
| 72 |
+
const auto leading = std::find_if(rpptr, rpend, is_leading);
|
| 73 |
+
if (leading == rbase) {
|
| 74 |
+
return 0;
|
| 75 |
+
}
|
| 76 |
+
const auto dist = static_cast<size_t>(leading - rpptr);
|
| 77 |
+
size_t remainder = 0;
|
| 78 |
+
|
| 79 |
+
if (dist == 0) {
|
| 80 |
+
remainder = 1; // 1-byte code point is impossible
|
| 81 |
+
} else if (dist == 1) {
|
| 82 |
+
remainder = is_leading_2b(*leading) ? 0 : dist + 1;
|
| 83 |
+
} else if (dist == 2) {
|
| 84 |
+
remainder = is_leading_3b(*leading) ? 0 : dist + 1;
|
| 85 |
+
}
|
| 86 |
+
// else if (dist >= 3), at least 4 bytes before encountering an UTF-8
|
| 87 |
+
// leading byte, either no remainder or invalid UTF-8.
|
| 88 |
+
// Invalid UTF-8 will cause an exception later when converting
|
| 89 |
+
// to a Python string, so that's not handled here.
|
| 90 |
+
return remainder;
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
// This function must be non-virtual to be called in a destructor.
|
| 94 |
+
int _sync() {
|
| 95 |
+
if (pbase() != pptr()) { // If buffer is not empty
|
| 96 |
+
gil_scoped_acquire tmp;
|
| 97 |
+
// This subtraction cannot be negative, so dropping the sign.
|
| 98 |
+
auto size = static_cast<size_t>(pptr() - pbase());
|
| 99 |
+
size_t remainder = utf8_remainder();
|
| 100 |
+
|
| 101 |
+
if (size > remainder) {
|
| 102 |
+
str line(pbase(), size - remainder);
|
| 103 |
+
pywrite(std::move(line));
|
| 104 |
+
pyflush();
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
// Copy the remainder at the end of the buffer to the beginning:
|
| 108 |
+
if (remainder > 0) {
|
| 109 |
+
std::memmove(pbase(), pptr() - remainder, remainder);
|
| 110 |
+
}
|
| 111 |
+
setp(pbase(), epptr());
|
| 112 |
+
pbump(static_cast<int>(remainder));
|
| 113 |
+
}
|
| 114 |
+
return 0;
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
int sync() override { return _sync(); }
|
| 118 |
+
|
| 119 |
+
public:
|
| 120 |
+
explicit pythonbuf(const object &pyostream, size_t buffer_size = 1024)
|
| 121 |
+
: buf_size(buffer_size), d_buffer(new char[buf_size]), pywrite(pyostream.attr("write")),
|
| 122 |
+
pyflush(pyostream.attr("flush")) {
|
| 123 |
+
setp(d_buffer.get(), d_buffer.get() + buf_size - 1);
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
pythonbuf(pythonbuf &&) = default;
|
| 127 |
+
|
| 128 |
+
/// Sync before destroy
|
| 129 |
+
~pythonbuf() override { _sync(); }
|
| 130 |
+
};
|
| 131 |
+
|
| 132 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 133 |
+
|
| 134 |
+
/** \rst
|
| 135 |
+
This a move-only guard that redirects output.
|
| 136 |
+
|
| 137 |
+
.. code-block:: cpp
|
| 138 |
+
|
| 139 |
+
#include <pybind11/iostream.h>
|
| 140 |
+
|
| 141 |
+
...
|
| 142 |
+
|
| 143 |
+
{
|
| 144 |
+
py::scoped_ostream_redirect output;
|
| 145 |
+
std::cout << "Hello, World!"; // Python stdout
|
| 146 |
+
} // <-- return std::cout to normal
|
| 147 |
+
|
| 148 |
+
You can explicitly pass the c++ stream and the python object,
|
| 149 |
+
for example to guard stderr instead.
|
| 150 |
+
|
| 151 |
+
.. code-block:: cpp
|
| 152 |
+
|
| 153 |
+
{
|
| 154 |
+
py::scoped_ostream_redirect output{
|
| 155 |
+
std::cerr, py::module::import("sys").attr("stderr")};
|
| 156 |
+
std::cout << "Hello, World!";
|
| 157 |
+
}
|
| 158 |
+
\endrst */
|
| 159 |
+
class scoped_ostream_redirect {
|
| 160 |
+
protected:
|
| 161 |
+
std::streambuf *old;
|
| 162 |
+
std::ostream &costream;
|
| 163 |
+
detail::pythonbuf buffer;
|
| 164 |
+
|
| 165 |
+
public:
|
| 166 |
+
explicit scoped_ostream_redirect(std::ostream &costream = std::cout,
|
| 167 |
+
const object &pyostream
|
| 168 |
+
= module_::import("sys").attr("stdout"))
|
| 169 |
+
: costream(costream), buffer(pyostream) {
|
| 170 |
+
old = costream.rdbuf(&buffer);
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
~scoped_ostream_redirect() { costream.rdbuf(old); }
|
| 174 |
+
|
| 175 |
+
scoped_ostream_redirect(const scoped_ostream_redirect &) = delete;
|
| 176 |
+
scoped_ostream_redirect(scoped_ostream_redirect &&other) = default;
|
| 177 |
+
scoped_ostream_redirect &operator=(const scoped_ostream_redirect &) = delete;
|
| 178 |
+
scoped_ostream_redirect &operator=(scoped_ostream_redirect &&) = delete;
|
| 179 |
+
};
|
| 180 |
+
|
| 181 |
+
/** \rst
|
| 182 |
+
Like `scoped_ostream_redirect`, but redirects cerr by default. This class
|
| 183 |
+
is provided primary to make ``py::call_guard`` easier to make.
|
| 184 |
+
|
| 185 |
+
.. code-block:: cpp
|
| 186 |
+
|
| 187 |
+
m.def("noisy_func", &noisy_func,
|
| 188 |
+
py::call_guard<scoped_ostream_redirect,
|
| 189 |
+
scoped_estream_redirect>());
|
| 190 |
+
|
| 191 |
+
\endrst */
|
| 192 |
+
class scoped_estream_redirect : public scoped_ostream_redirect {
|
| 193 |
+
public:
|
| 194 |
+
explicit scoped_estream_redirect(std::ostream &costream = std::cerr,
|
| 195 |
+
const object &pyostream
|
| 196 |
+
= module_::import("sys").attr("stderr"))
|
| 197 |
+
: scoped_ostream_redirect(costream, pyostream) {}
|
| 198 |
+
};
|
| 199 |
+
|
| 200 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 201 |
+
|
| 202 |
+
// Class to redirect output as a context manager. C++ backend.
|
| 203 |
+
class OstreamRedirect {
|
| 204 |
+
bool do_stdout_;
|
| 205 |
+
bool do_stderr_;
|
| 206 |
+
std::unique_ptr<scoped_ostream_redirect> redirect_stdout;
|
| 207 |
+
std::unique_ptr<scoped_estream_redirect> redirect_stderr;
|
| 208 |
+
|
| 209 |
+
public:
|
| 210 |
+
explicit OstreamRedirect(bool do_stdout = true, bool do_stderr = true)
|
| 211 |
+
: do_stdout_(do_stdout), do_stderr_(do_stderr) {}
|
| 212 |
+
|
| 213 |
+
void enter() {
|
| 214 |
+
if (do_stdout_) {
|
| 215 |
+
redirect_stdout.reset(new scoped_ostream_redirect());
|
| 216 |
+
}
|
| 217 |
+
if (do_stderr_) {
|
| 218 |
+
redirect_stderr.reset(new scoped_estream_redirect());
|
| 219 |
+
}
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
void exit() {
|
| 223 |
+
redirect_stdout.reset();
|
| 224 |
+
redirect_stderr.reset();
|
| 225 |
+
}
|
| 226 |
+
};
|
| 227 |
+
|
| 228 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 229 |
+
|
| 230 |
+
/** \rst
|
| 231 |
+
This is a helper function to add a C++ redirect context manager to Python
|
| 232 |
+
instead of using a C++ guard. To use it, add the following to your binding code:
|
| 233 |
+
|
| 234 |
+
.. code-block:: cpp
|
| 235 |
+
|
| 236 |
+
#include <pybind11/iostream.h>
|
| 237 |
+
|
| 238 |
+
...
|
| 239 |
+
|
| 240 |
+
py::add_ostream_redirect(m, "ostream_redirect");
|
| 241 |
+
|
| 242 |
+
You now have a Python context manager that redirects your output:
|
| 243 |
+
|
| 244 |
+
.. code-block:: python
|
| 245 |
+
|
| 246 |
+
with m.ostream_redirect():
|
| 247 |
+
m.print_to_cout_function()
|
| 248 |
+
|
| 249 |
+
This manager can optionally be told which streams to operate on:
|
| 250 |
+
|
| 251 |
+
.. code-block:: python
|
| 252 |
+
|
| 253 |
+
with m.ostream_redirect(stdout=true, stderr=true):
|
| 254 |
+
m.noisy_function_with_error_printing()
|
| 255 |
+
|
| 256 |
+
\endrst */
|
| 257 |
+
inline class_<detail::OstreamRedirect>
|
| 258 |
+
add_ostream_redirect(module_ m, const std::string &name = "ostream_redirect") {
|
| 259 |
+
return class_<detail::OstreamRedirect>(std::move(m), name.c_str(), module_local())
|
| 260 |
+
.def(init<bool, bool>(), arg("stdout") = true, arg("stderr") = true)
|
| 261 |
+
.def("__enter__", &detail::OstreamRedirect::enter)
|
| 262 |
+
.def("__exit__", [](detail::OstreamRedirect &self_, const args &) { self_.exit(); });
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
|
phivenv/Lib/site-packages/torch/include/pybind11/numpy.h
ADDED
|
@@ -0,0 +1,2139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
pybind11/numpy.h: Basic NumPy support, vectorize() wrapper
|
| 3 |
+
|
| 4 |
+
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
| 5 |
+
|
| 6 |
+
All rights reserved. Use of this source code is governed by a
|
| 7 |
+
BSD-style license that can be found in the LICENSE file.
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include "pybind11.h"
|
| 13 |
+
#include "detail/common.h"
|
| 14 |
+
#include "complex.h"
|
| 15 |
+
#include "gil_safe_call_once.h"
|
| 16 |
+
#include "pytypes.h"
|
| 17 |
+
|
| 18 |
+
#include <algorithm>
|
| 19 |
+
#include <array>
|
| 20 |
+
#include <cstdint>
|
| 21 |
+
#include <cstdlib>
|
| 22 |
+
#include <cstring>
|
| 23 |
+
#include <functional>
|
| 24 |
+
#include <numeric>
|
| 25 |
+
#include <sstream>
|
| 26 |
+
#include <string>
|
| 27 |
+
#include <type_traits>
|
| 28 |
+
#include <typeindex>
|
| 29 |
+
#include <utility>
|
| 30 |
+
#include <vector>
|
| 31 |
+
|
| 32 |
+
#if defined(PYBIND11_NUMPY_1_ONLY) && !defined(PYBIND11_INTERNAL_NUMPY_1_ONLY_DETECTED)
|
| 33 |
+
# error PYBIND11_NUMPY_1_ONLY must be defined before any pybind11 header is included.
|
| 34 |
+
#endif
|
| 35 |
+
|
| 36 |
+
/* This will be true on all flat address space platforms and allows us to reduce the
|
| 37 |
+
whole npy_intp / ssize_t / Py_intptr_t business down to just ssize_t for all size
|
| 38 |
+
and dimension types (e.g. shape, strides, indexing), instead of inflicting this
|
| 39 |
+
upon the library user.
|
| 40 |
+
Note that NumPy 2 now uses ssize_t for `npy_intp` to simplify this. */
|
| 41 |
+
static_assert(sizeof(::pybind11::ssize_t) == sizeof(Py_intptr_t), "ssize_t != Py_intptr_t");
|
| 42 |
+
static_assert(std::is_signed<Py_intptr_t>::value, "Py_intptr_t must be signed");
|
| 43 |
+
// We now can reinterpret_cast between py::ssize_t and Py_intptr_t (MSVC + PyPy cares)
|
| 44 |
+
|
| 45 |
+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
| 46 |
+
|
| 47 |
+
PYBIND11_WARNING_DISABLE_MSVC(4127)
|
| 48 |
+
|
| 49 |
+
class dtype; // Forward declaration
|
| 50 |
+
class array; // Forward declaration
|
| 51 |
+
|
| 52 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 53 |
+
|
| 54 |
+
template <>
|
| 55 |
+
struct handle_type_name<dtype> {
|
| 56 |
+
static constexpr auto name = const_name("numpy.dtype");
|
| 57 |
+
};
|
| 58 |
+
|
| 59 |
+
template <>
|
| 60 |
+
struct handle_type_name<array> {
|
| 61 |
+
static constexpr auto name = const_name("numpy.ndarray");
|
| 62 |
+
};
|
| 63 |
+
|
| 64 |
+
template <typename type, typename SFINAE = void>
|
| 65 |
+
struct npy_format_descriptor;
|
| 66 |
+
|
| 67 |
+
/* NumPy 1 proxy (always includes legacy fields) */
|
| 68 |
+
struct PyArrayDescr1_Proxy {
|
| 69 |
+
PyObject_HEAD
|
| 70 |
+
PyObject *typeobj;
|
| 71 |
+
char kind;
|
| 72 |
+
char type;
|
| 73 |
+
char byteorder;
|
| 74 |
+
char flags;
|
| 75 |
+
int type_num;
|
| 76 |
+
int elsize;
|
| 77 |
+
int alignment;
|
| 78 |
+
char *subarray;
|
| 79 |
+
PyObject *fields;
|
| 80 |
+
PyObject *names;
|
| 81 |
+
};
|
| 82 |
+
|
| 83 |
+
#ifndef PYBIND11_NUMPY_1_ONLY
|
| 84 |
+
struct PyArrayDescr_Proxy {
|
| 85 |
+
PyObject_HEAD
|
| 86 |
+
PyObject *typeobj;
|
| 87 |
+
char kind;
|
| 88 |
+
char type;
|
| 89 |
+
char byteorder;
|
| 90 |
+
char _former_flags;
|
| 91 |
+
int type_num;
|
| 92 |
+
/* Additional fields are NumPy version specific. */
|
| 93 |
+
};
|
| 94 |
+
#else
|
| 95 |
+
/* NumPy 1.x only, we can expose all fields */
|
| 96 |
+
using PyArrayDescr_Proxy = PyArrayDescr1_Proxy;
|
| 97 |
+
#endif
|
| 98 |
+
|
| 99 |
+
/* NumPy 2 proxy, including legacy fields */
|
| 100 |
+
struct PyArrayDescr2_Proxy {
|
| 101 |
+
PyObject_HEAD
|
| 102 |
+
PyObject *typeobj;
|
| 103 |
+
char kind;
|
| 104 |
+
char type;
|
| 105 |
+
char byteorder;
|
| 106 |
+
char _former_flags;
|
| 107 |
+
int type_num;
|
| 108 |
+
std::uint64_t flags;
|
| 109 |
+
ssize_t elsize;
|
| 110 |
+
ssize_t alignment;
|
| 111 |
+
PyObject *metadata;
|
| 112 |
+
Py_hash_t hash;
|
| 113 |
+
void *reserved_null[2];
|
| 114 |
+
/* The following fields only exist if 0 <= type_num < 2056 */
|
| 115 |
+
char *subarray;
|
| 116 |
+
PyObject *fields;
|
| 117 |
+
PyObject *names;
|
| 118 |
+
};
|
| 119 |
+
|
| 120 |
+
struct PyArray_Proxy {
|
| 121 |
+
PyObject_HEAD
|
| 122 |
+
char *data;
|
| 123 |
+
int nd;
|
| 124 |
+
ssize_t *dimensions;
|
| 125 |
+
ssize_t *strides;
|
| 126 |
+
PyObject *base;
|
| 127 |
+
PyObject *descr;
|
| 128 |
+
int flags;
|
| 129 |
+
};
|
| 130 |
+
|
| 131 |
+
struct PyVoidScalarObject_Proxy {
|
| 132 |
+
PyObject_VAR_HEAD char *obval;
|
| 133 |
+
PyArrayDescr_Proxy *descr;
|
| 134 |
+
int flags;
|
| 135 |
+
PyObject *base;
|
| 136 |
+
};
|
| 137 |
+
|
| 138 |
+
struct numpy_type_info {
|
| 139 |
+
PyObject *dtype_ptr;
|
| 140 |
+
std::string format_str;
|
| 141 |
+
};
|
| 142 |
+
|
| 143 |
+
struct numpy_internals {
|
| 144 |
+
std::unordered_map<std::type_index, numpy_type_info> registered_dtypes;
|
| 145 |
+
|
| 146 |
+
numpy_type_info *get_type_info(const std::type_info &tinfo, bool throw_if_missing = true) {
|
| 147 |
+
auto it = registered_dtypes.find(std::type_index(tinfo));
|
| 148 |
+
if (it != registered_dtypes.end()) {
|
| 149 |
+
return &(it->second);
|
| 150 |
+
}
|
| 151 |
+
if (throw_if_missing) {
|
| 152 |
+
pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name());
|
| 153 |
+
}
|
| 154 |
+
return nullptr;
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
template <typename T>
|
| 158 |
+
numpy_type_info *get_type_info(bool throw_if_missing = true) {
|
| 159 |
+
return get_type_info(typeid(typename std::remove_cv<T>::type), throw_if_missing);
|
| 160 |
+
}
|
| 161 |
+
};
|
| 162 |
+
|
| 163 |
+
PYBIND11_NOINLINE void load_numpy_internals(numpy_internals *&ptr) {
|
| 164 |
+
ptr = &get_or_create_shared_data<numpy_internals>("_numpy_internals");
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
inline numpy_internals &get_numpy_internals() {
|
| 168 |
+
static numpy_internals *ptr = nullptr;
|
| 169 |
+
if (!ptr) {
|
| 170 |
+
load_numpy_internals(ptr);
|
| 171 |
+
}
|
| 172 |
+
return *ptr;
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
PYBIND11_NOINLINE module_ import_numpy_core_submodule(const char *submodule_name) {
|
| 176 |
+
module_ numpy = module_::import("numpy");
|
| 177 |
+
str version_string = numpy.attr("__version__");
|
| 178 |
+
|
| 179 |
+
module_ numpy_lib = module_::import("numpy.lib");
|
| 180 |
+
object numpy_version = numpy_lib.attr("NumpyVersion")(version_string);
|
| 181 |
+
int major_version = numpy_version.attr("major").cast<int>();
|
| 182 |
+
|
| 183 |
+
#ifdef PYBIND11_NUMPY_1_ONLY
|
| 184 |
+
if (major_version >= 2) {
|
| 185 |
+
throw std::runtime_error(
|
| 186 |
+
"This extension was built with PYBIND11_NUMPY_1_ONLY defined, "
|
| 187 |
+
"but NumPy 2 is used in this process. For NumPy2 compatibility, "
|
| 188 |
+
"this extension needs to be rebuilt without the PYBIND11_NUMPY_1_ONLY define.");
|
| 189 |
+
}
|
| 190 |
+
#endif
|
| 191 |
+
/* `numpy.core` was renamed to `numpy._core` in NumPy 2.0 as it officially
|
| 192 |
+
became a private module. */
|
| 193 |
+
std::string numpy_core_path = major_version >= 2 ? "numpy._core" : "numpy.core";
|
| 194 |
+
return module_::import((numpy_core_path + "." + submodule_name).c_str());
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
template <typename T>
|
| 198 |
+
struct same_size {
|
| 199 |
+
template <typename U>
|
| 200 |
+
using as = bool_constant<sizeof(T) == sizeof(U)>;
|
| 201 |
+
};
|
| 202 |
+
|
| 203 |
+
template <typename Concrete>
|
| 204 |
+
constexpr int platform_lookup() {
|
| 205 |
+
return -1;
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
// Lookup a type according to its size, and return a value corresponding to the NumPy typenum.
|
| 209 |
+
template <typename Concrete, typename T, typename... Ts, typename... Ints>
|
| 210 |
+
constexpr int platform_lookup(int I, Ints... Is) {
|
| 211 |
+
return sizeof(Concrete) == sizeof(T) ? I : platform_lookup<Concrete, Ts...>(Is...);
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
struct npy_api {
|
| 215 |
+
enum constants {
|
| 216 |
+
NPY_ARRAY_C_CONTIGUOUS_ = 0x0001,
|
| 217 |
+
NPY_ARRAY_F_CONTIGUOUS_ = 0x0002,
|
| 218 |
+
NPY_ARRAY_OWNDATA_ = 0x0004,
|
| 219 |
+
NPY_ARRAY_FORCECAST_ = 0x0010,
|
| 220 |
+
NPY_ARRAY_ENSUREARRAY_ = 0x0040,
|
| 221 |
+
NPY_ARRAY_ALIGNED_ = 0x0100,
|
| 222 |
+
NPY_ARRAY_WRITEABLE_ = 0x0400,
|
| 223 |
+
NPY_BOOL_ = 0,
|
| 224 |
+
NPY_BYTE_,
|
| 225 |
+
NPY_UBYTE_,
|
| 226 |
+
NPY_SHORT_,
|
| 227 |
+
NPY_USHORT_,
|
| 228 |
+
NPY_INT_,
|
| 229 |
+
NPY_UINT_,
|
| 230 |
+
NPY_LONG_,
|
| 231 |
+
NPY_ULONG_,
|
| 232 |
+
NPY_LONGLONG_,
|
| 233 |
+
NPY_ULONGLONG_,
|
| 234 |
+
NPY_FLOAT_,
|
| 235 |
+
NPY_DOUBLE_,
|
| 236 |
+
NPY_LONGDOUBLE_,
|
| 237 |
+
NPY_CFLOAT_,
|
| 238 |
+
NPY_CDOUBLE_,
|
| 239 |
+
NPY_CLONGDOUBLE_,
|
| 240 |
+
NPY_OBJECT_ = 17,
|
| 241 |
+
NPY_STRING_,
|
| 242 |
+
NPY_UNICODE_,
|
| 243 |
+
NPY_VOID_,
|
| 244 |
+
// Platform-dependent normalization
|
| 245 |
+
NPY_INT8_ = NPY_BYTE_,
|
| 246 |
+
NPY_UINT8_ = NPY_UBYTE_,
|
| 247 |
+
NPY_INT16_ = NPY_SHORT_,
|
| 248 |
+
NPY_UINT16_ = NPY_USHORT_,
|
| 249 |
+
// `npy_common.h` defines the integer aliases. In order, it checks:
|
| 250 |
+
// NPY_BITSOF_LONG, NPY_BITSOF_LONGLONG, NPY_BITSOF_INT, NPY_BITSOF_SHORT, NPY_BITSOF_CHAR
|
| 251 |
+
// and assigns the alias to the first matching size, so we should check in this order.
|
| 252 |
+
NPY_INT32_
|
| 253 |
+
= platform_lookup<std::int32_t, long, int, short>(NPY_LONG_, NPY_INT_, NPY_SHORT_),
|
| 254 |
+
NPY_UINT32_ = platform_lookup<std::uint32_t, unsigned long, unsigned int, unsigned short>(
|
| 255 |
+
NPY_ULONG_, NPY_UINT_, NPY_USHORT_),
|
| 256 |
+
NPY_INT64_
|
| 257 |
+
= platform_lookup<std::int64_t, long, long long, int>(NPY_LONG_, NPY_LONGLONG_, NPY_INT_),
|
| 258 |
+
NPY_UINT64_
|
| 259 |
+
= platform_lookup<std::uint64_t, unsigned long, unsigned long long, unsigned int>(
|
| 260 |
+
NPY_ULONG_, NPY_ULONGLONG_, NPY_UINT_),
|
| 261 |
+
};
|
| 262 |
+
|
| 263 |
+
unsigned int PyArray_RUNTIME_VERSION_;
|
| 264 |
+
|
| 265 |
+
struct PyArray_Dims {
|
| 266 |
+
Py_intptr_t *ptr;
|
| 267 |
+
int len;
|
| 268 |
+
};
|
| 269 |
+
|
| 270 |
+
static npy_api &get() {
|
| 271 |
+
PYBIND11_CONSTINIT static gil_safe_call_once_and_store<npy_api> storage;
|
| 272 |
+
return storage.call_once_and_store_result(lookup).get_stored();
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
bool PyArray_Check_(PyObject *obj) const {
|
| 276 |
+
return PyObject_TypeCheck(obj, PyArray_Type_) != 0;
|
| 277 |
+
}
|
| 278 |
+
bool PyArrayDescr_Check_(PyObject *obj) const {
|
| 279 |
+
return PyObject_TypeCheck(obj, PyArrayDescr_Type_) != 0;
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
unsigned int (*PyArray_GetNDArrayCFeatureVersion_)();
|
| 283 |
+
PyObject *(*PyArray_DescrFromType_)(int);
|
| 284 |
+
PyObject *(*PyArray_NewFromDescr_)(PyTypeObject *,
|
| 285 |
+
PyObject *,
|
| 286 |
+
int,
|
| 287 |
+
Py_intptr_t const *,
|
| 288 |
+
Py_intptr_t const *,
|
| 289 |
+
void *,
|
| 290 |
+
int,
|
| 291 |
+
PyObject *);
|
| 292 |
+
// Unused. Not removed because that affects ABI of the class.
|
| 293 |
+
PyObject *(*PyArray_DescrNewFromType_)(int);
|
| 294 |
+
int (*PyArray_CopyInto_)(PyObject *, PyObject *);
|
| 295 |
+
PyObject *(*PyArray_NewCopy_)(PyObject *, int);
|
| 296 |
+
PyTypeObject *PyArray_Type_;
|
| 297 |
+
PyTypeObject *PyVoidArrType_Type_;
|
| 298 |
+
PyTypeObject *PyArrayDescr_Type_;
|
| 299 |
+
PyObject *(*PyArray_DescrFromScalar_)(PyObject *);
|
| 300 |
+
PyObject *(*PyArray_FromAny_)(PyObject *, PyObject *, int, int, int, PyObject *);
|
| 301 |
+
int (*PyArray_DescrConverter_)(PyObject *, PyObject **);
|
| 302 |
+
bool (*PyArray_EquivTypes_)(PyObject *, PyObject *);
|
| 303 |
+
#ifdef PYBIND11_NUMPY_1_ONLY
|
| 304 |
+
int (*PyArray_GetArrayParamsFromObject_)(PyObject *,
|
| 305 |
+
PyObject *,
|
| 306 |
+
unsigned char,
|
| 307 |
+
PyObject **,
|
| 308 |
+
int *,
|
| 309 |
+
Py_intptr_t *,
|
| 310 |
+
PyObject **,
|
| 311 |
+
PyObject *);
|
| 312 |
+
#endif
|
| 313 |
+
PyObject *(*PyArray_Squeeze_)(PyObject *);
|
| 314 |
+
// Unused. Not removed because that affects ABI of the class.
|
| 315 |
+
int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
|
| 316 |
+
PyObject *(*PyArray_Resize_)(PyObject *, PyArray_Dims *, int, int);
|
| 317 |
+
PyObject *(*PyArray_Newshape_)(PyObject *, PyArray_Dims *, int);
|
| 318 |
+
PyObject *(*PyArray_View_)(PyObject *, PyObject *, PyObject *);
|
| 319 |
+
|
| 320 |
+
private:
|
| 321 |
+
enum functions {
|
| 322 |
+
API_PyArray_GetNDArrayCFeatureVersion = 211,
|
| 323 |
+
API_PyArray_Type = 2,
|
| 324 |
+
API_PyArrayDescr_Type = 3,
|
| 325 |
+
API_PyVoidArrType_Type = 39,
|
| 326 |
+
API_PyArray_DescrFromType = 45,
|
| 327 |
+
API_PyArray_DescrFromScalar = 57,
|
| 328 |
+
API_PyArray_FromAny = 69,
|
| 329 |
+
API_PyArray_Resize = 80,
|
| 330 |
+
// CopyInto was slot 82 and 50 was effectively an alias. NumPy 2 removed 82.
|
| 331 |
+
API_PyArray_CopyInto = 50,
|
| 332 |
+
API_PyArray_NewCopy = 85,
|
| 333 |
+
API_PyArray_NewFromDescr = 94,
|
| 334 |
+
API_PyArray_DescrNewFromType = 96,
|
| 335 |
+
API_PyArray_Newshape = 135,
|
| 336 |
+
API_PyArray_Squeeze = 136,
|
| 337 |
+
API_PyArray_View = 137,
|
| 338 |
+
API_PyArray_DescrConverter = 174,
|
| 339 |
+
API_PyArray_EquivTypes = 182,
|
| 340 |
+
#ifdef PYBIND11_NUMPY_1_ONLY
|
| 341 |
+
API_PyArray_GetArrayParamsFromObject = 278,
|
| 342 |
+
#endif
|
| 343 |
+
API_PyArray_SetBaseObject = 282
|
| 344 |
+
};
|
| 345 |
+
|
| 346 |
+
static npy_api lookup() {
|
| 347 |
+
module_ m = detail::import_numpy_core_submodule("multiarray");
|
| 348 |
+
auto c = m.attr("_ARRAY_API");
|
| 349 |
+
void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), nullptr);
|
| 350 |
+
if (api_ptr == nullptr) {
|
| 351 |
+
raise_from(PyExc_SystemError, "FAILURE obtaining numpy _ARRAY_API pointer.");
|
| 352 |
+
throw error_already_set();
|
| 353 |
+
}
|
| 354 |
+
npy_api api;
|
| 355 |
+
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
|
| 356 |
+
DECL_NPY_API(PyArray_GetNDArrayCFeatureVersion);
|
| 357 |
+
api.PyArray_RUNTIME_VERSION_ = api.PyArray_GetNDArrayCFeatureVersion_();
|
| 358 |
+
if (api.PyArray_RUNTIME_VERSION_ < 0x7) {
|
| 359 |
+
pybind11_fail("pybind11 numpy support requires numpy >= 1.7.0");
|
| 360 |
+
}
|
| 361 |
+
DECL_NPY_API(PyArray_Type);
|
| 362 |
+
DECL_NPY_API(PyVoidArrType_Type);
|
| 363 |
+
DECL_NPY_API(PyArrayDescr_Type);
|
| 364 |
+
DECL_NPY_API(PyArray_DescrFromType);
|
| 365 |
+
DECL_NPY_API(PyArray_DescrFromScalar);
|
| 366 |
+
DECL_NPY_API(PyArray_FromAny);
|
| 367 |
+
DECL_NPY_API(PyArray_Resize);
|
| 368 |
+
DECL_NPY_API(PyArray_CopyInto);
|
| 369 |
+
DECL_NPY_API(PyArray_NewCopy);
|
| 370 |
+
DECL_NPY_API(PyArray_NewFromDescr);
|
| 371 |
+
DECL_NPY_API(PyArray_DescrNewFromType);
|
| 372 |
+
DECL_NPY_API(PyArray_Newshape);
|
| 373 |
+
DECL_NPY_API(PyArray_Squeeze);
|
| 374 |
+
DECL_NPY_API(PyArray_View);
|
| 375 |
+
DECL_NPY_API(PyArray_DescrConverter);
|
| 376 |
+
DECL_NPY_API(PyArray_EquivTypes);
|
| 377 |
+
#ifdef PYBIND11_NUMPY_1_ONLY
|
| 378 |
+
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
|
| 379 |
+
#endif
|
| 380 |
+
DECL_NPY_API(PyArray_SetBaseObject);
|
| 381 |
+
|
| 382 |
+
#undef DECL_NPY_API
|
| 383 |
+
return api;
|
| 384 |
+
}
|
| 385 |
+
};
|
| 386 |
+
|
| 387 |
+
inline PyArray_Proxy *array_proxy(void *ptr) { return reinterpret_cast<PyArray_Proxy *>(ptr); }
|
| 388 |
+
|
| 389 |
+
inline const PyArray_Proxy *array_proxy(const void *ptr) {
|
| 390 |
+
return reinterpret_cast<const PyArray_Proxy *>(ptr);
|
| 391 |
+
}
|
| 392 |
+
|
| 393 |
+
inline PyArrayDescr_Proxy *array_descriptor_proxy(PyObject *ptr) {
|
| 394 |
+
return reinterpret_cast<PyArrayDescr_Proxy *>(ptr);
|
| 395 |
+
}
|
| 396 |
+
|
| 397 |
+
inline const PyArrayDescr_Proxy *array_descriptor_proxy(const PyObject *ptr) {
|
| 398 |
+
return reinterpret_cast<const PyArrayDescr_Proxy *>(ptr);
|
| 399 |
+
}
|
| 400 |
+
|
| 401 |
+
inline const PyArrayDescr1_Proxy *array_descriptor1_proxy(const PyObject *ptr) {
|
| 402 |
+
return reinterpret_cast<const PyArrayDescr1_Proxy *>(ptr);
|
| 403 |
+
}
|
| 404 |
+
|
| 405 |
+
inline const PyArrayDescr2_Proxy *array_descriptor2_proxy(const PyObject *ptr) {
|
| 406 |
+
return reinterpret_cast<const PyArrayDescr2_Proxy *>(ptr);
|
| 407 |
+
}
|
| 408 |
+
|
| 409 |
+
inline bool check_flags(const void *ptr, int flag) {
|
| 410 |
+
return (flag == (array_proxy(ptr)->flags & flag));
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
template <typename T>
|
| 414 |
+
struct is_std_array : std::false_type {};
|
| 415 |
+
template <typename T, size_t N>
|
| 416 |
+
struct is_std_array<std::array<T, N>> : std::true_type {};
|
| 417 |
+
template <typename T>
|
| 418 |
+
struct is_complex : std::false_type {};
|
| 419 |
+
template <typename T>
|
| 420 |
+
struct is_complex<std::complex<T>> : std::true_type {};
|
| 421 |
+
|
| 422 |
+
template <typename T>
|
| 423 |
+
struct array_info_scalar {
|
| 424 |
+
using type = T;
|
| 425 |
+
static constexpr bool is_array = false;
|
| 426 |
+
static constexpr bool is_empty = false;
|
| 427 |
+
static constexpr auto extents = const_name("");
|
| 428 |
+
static void append_extents(list & /* shape */) {}
|
| 429 |
+
};
|
| 430 |
+
// Computes underlying type and a comma-separated list of extents for array
|
| 431 |
+
// types (any mix of std::array and built-in arrays). An array of char is
|
| 432 |
+
// treated as scalar because it gets special handling.
|
| 433 |
+
template <typename T>
|
| 434 |
+
struct array_info : array_info_scalar<T> {};
|
| 435 |
+
template <typename T, size_t N>
|
| 436 |
+
struct array_info<std::array<T, N>> {
|
| 437 |
+
using type = typename array_info<T>::type;
|
| 438 |
+
static constexpr bool is_array = true;
|
| 439 |
+
static constexpr bool is_empty = (N == 0) || array_info<T>::is_empty;
|
| 440 |
+
static constexpr size_t extent = N;
|
| 441 |
+
|
| 442 |
+
// appends the extents to shape
|
| 443 |
+
static void append_extents(list &shape) {
|
| 444 |
+
shape.append(N);
|
| 445 |
+
array_info<T>::append_extents(shape);
|
| 446 |
+
}
|
| 447 |
+
|
| 448 |
+
static constexpr auto extents = const_name<array_info<T>::is_array>(
|
| 449 |
+
::pybind11::detail::concat(const_name<N>(), array_info<T>::extents), const_name<N>());
|
| 450 |
+
};
|
| 451 |
+
// For numpy we have special handling for arrays of characters, so we don't include
|
| 452 |
+
// the size in the array extents.
|
| 453 |
+
template <size_t N>
|
| 454 |
+
struct array_info<char[N]> : array_info_scalar<char[N]> {};
|
| 455 |
+
template <size_t N>
|
| 456 |
+
struct array_info<std::array<char, N>> : array_info_scalar<std::array<char, N>> {};
|
| 457 |
+
template <typename T, size_t N>
|
| 458 |
+
struct array_info<T[N]> : array_info<std::array<T, N>> {};
|
| 459 |
+
template <typename T>
|
| 460 |
+
using remove_all_extents_t = typename array_info<T>::type;
|
| 461 |
+
|
| 462 |
+
template <typename T>
|
| 463 |
+
using is_pod_struct
|
| 464 |
+
= all_of<std::is_standard_layout<T>, // since we're accessing directly in memory
|
| 465 |
+
// we need a standard layout type
|
| 466 |
+
#if defined(__GLIBCXX__) \
|
| 467 |
+
&& (__GLIBCXX__ < 20150422 || __GLIBCXX__ == 20150426 || __GLIBCXX__ == 20150623 \
|
| 468 |
+
|| __GLIBCXX__ == 20150626 || __GLIBCXX__ == 20160803)
|
| 469 |
+
// libstdc++ < 5 (including versions 4.8.5, 4.9.3 and 4.9.4 which were released after
|
| 470 |
+
// 5) don't implement is_trivially_copyable, so approximate it
|
| 471 |
+
std::is_trivially_destructible<T>,
|
| 472 |
+
satisfies_any_of<T, std::has_trivial_copy_constructor, std::has_trivial_copy_assign>,
|
| 473 |
+
#else
|
| 474 |
+
std::is_trivially_copyable<T>,
|
| 475 |
+
#endif
|
| 476 |
+
satisfies_none_of<T,
|
| 477 |
+
std::is_reference,
|
| 478 |
+
std::is_array,
|
| 479 |
+
is_std_array,
|
| 480 |
+
std::is_arithmetic,
|
| 481 |
+
is_complex,
|
| 482 |
+
std::is_enum>>;
|
| 483 |
+
|
| 484 |
+
// Replacement for std::is_pod (deprecated in C++20)
|
| 485 |
+
template <typename T>
|
| 486 |
+
using is_pod = all_of<std::is_standard_layout<T>, std::is_trivial<T>>;
|
| 487 |
+
|
| 488 |
+
template <ssize_t Dim = 0, typename Strides>
|
| 489 |
+
ssize_t byte_offset_unsafe(const Strides &) {
|
| 490 |
+
return 0;
|
| 491 |
+
}
|
| 492 |
+
template <ssize_t Dim = 0, typename Strides, typename... Ix>
|
| 493 |
+
ssize_t byte_offset_unsafe(const Strides &strides, ssize_t i, Ix... index) {
|
| 494 |
+
return i * strides[Dim] + byte_offset_unsafe<Dim + 1>(strides, index...);
|
| 495 |
+
}
|
| 496 |
+
|
| 497 |
+
/**
|
| 498 |
+
* Proxy class providing unsafe, unchecked const access to array data. This is constructed through
|
| 499 |
+
* the `unchecked<T, N>()` method of `array` or the `unchecked<N>()` method of `array_t<T>`. `Dims`
|
| 500 |
+
* will be -1 for dimensions determined at runtime.
|
| 501 |
+
*/
|
| 502 |
+
template <typename T, ssize_t Dims>
|
| 503 |
+
class unchecked_reference {
|
| 504 |
+
protected:
|
| 505 |
+
static constexpr bool Dynamic = Dims < 0;
|
| 506 |
+
const unsigned char *data_;
|
| 507 |
+
// Storing the shape & strides in local variables (i.e. these arrays) allows the compiler to
|
| 508 |
+
// make large performance gains on big, nested loops, but requires compile-time dimensions
|
| 509 |
+
conditional_t<Dynamic, const ssize_t *, std::array<ssize_t, (size_t) Dims>> shape_, strides_;
|
| 510 |
+
const ssize_t dims_;
|
| 511 |
+
|
| 512 |
+
friend class pybind11::array;
|
| 513 |
+
// Constructor for compile-time dimensions:
|
| 514 |
+
template <bool Dyn = Dynamic>
|
| 515 |
+
unchecked_reference(const void *data,
|
| 516 |
+
const ssize_t *shape,
|
| 517 |
+
const ssize_t *strides,
|
| 518 |
+
enable_if_t<!Dyn, ssize_t>)
|
| 519 |
+
: data_{reinterpret_cast<const unsigned char *>(data)}, dims_{Dims} {
|
| 520 |
+
for (size_t i = 0; i < (size_t) dims_; i++) {
|
| 521 |
+
shape_[i] = shape[i];
|
| 522 |
+
strides_[i] = strides[i];
|
| 523 |
+
}
|
| 524 |
+
}
|
| 525 |
+
// Constructor for runtime dimensions:
|
| 526 |
+
template <bool Dyn = Dynamic>
|
| 527 |
+
unchecked_reference(const void *data,
|
| 528 |
+
const ssize_t *shape,
|
| 529 |
+
const ssize_t *strides,
|
| 530 |
+
enable_if_t<Dyn, ssize_t> dims)
|
| 531 |
+
: data_{reinterpret_cast<const unsigned char *>(data)}, shape_{shape}, strides_{strides},
|
| 532 |
+
dims_{dims} {}
|
| 533 |
+
|
| 534 |
+
public:
|
| 535 |
+
/**
|
| 536 |
+
* Unchecked const reference access to data at the given indices. For a compile-time known
|
| 537 |
+
* number of dimensions, this requires the correct number of arguments; for run-time
|
| 538 |
+
* dimensionality, this is not checked (and so is up to the caller to use safely).
|
| 539 |
+
*/
|
| 540 |
+
template <typename... Ix>
|
| 541 |
+
const T &operator()(Ix... index) const {
|
| 542 |
+
static_assert(ssize_t{sizeof...(Ix)} == Dims || Dynamic,
|
| 543 |
+
"Invalid number of indices for unchecked array reference");
|
| 544 |
+
return *reinterpret_cast<const T *>(data_
|
| 545 |
+
+ byte_offset_unsafe(strides_, ssize_t(index)...));
|
| 546 |
+
}
|
| 547 |
+
/**
|
| 548 |
+
* Unchecked const reference access to data; this operator only participates if the reference
|
| 549 |
+
* is to a 1-dimensional array. When present, this is exactly equivalent to `obj(index)`.
|
| 550 |
+
*/
|
| 551 |
+
template <ssize_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>>
|
| 552 |
+
const T &operator[](ssize_t index) const {
|
| 553 |
+
return operator()(index);
|
| 554 |
+
}
|
| 555 |
+
|
| 556 |
+
/// Pointer access to the data at the given indices.
|
| 557 |
+
template <typename... Ix>
|
| 558 |
+
const T *data(Ix... ix) const {
|
| 559 |
+
return &operator()(ssize_t(ix)...);
|
| 560 |
+
}
|
| 561 |
+
|
| 562 |
+
/// Returns the item size, i.e. sizeof(T)
|
| 563 |
+
constexpr static ssize_t itemsize() { return sizeof(T); }
|
| 564 |
+
|
| 565 |
+
/// Returns the shape (i.e. size) of dimension `dim`
|
| 566 |
+
ssize_t shape(ssize_t dim) const { return shape_[(size_t) dim]; }
|
| 567 |
+
|
| 568 |
+
/// Returns the number of dimensions of the array
|
| 569 |
+
ssize_t ndim() const { return dims_; }
|
| 570 |
+
|
| 571 |
+
/// Returns the total number of elements in the referenced array, i.e. the product of the
|
| 572 |
+
/// shapes
|
| 573 |
+
template <bool Dyn = Dynamic>
|
| 574 |
+
enable_if_t<!Dyn, ssize_t> size() const {
|
| 575 |
+
return std::accumulate(
|
| 576 |
+
shape_.begin(), shape_.end(), (ssize_t) 1, std::multiplies<ssize_t>());
|
| 577 |
+
}
|
| 578 |
+
template <bool Dyn = Dynamic>
|
| 579 |
+
enable_if_t<Dyn, ssize_t> size() const {
|
| 580 |
+
return std::accumulate(shape_, shape_ + ndim(), (ssize_t) 1, std::multiplies<ssize_t>());
|
| 581 |
+
}
|
| 582 |
+
|
| 583 |
+
/// Returns the total number of bytes used by the referenced data. Note that the actual span
|
| 584 |
+
/// in memory may be larger if the referenced array has non-contiguous strides (e.g. for a
|
| 585 |
+
/// slice).
|
| 586 |
+
ssize_t nbytes() const { return size() * itemsize(); }
|
| 587 |
+
};
|
| 588 |
+
|
| 589 |
+
template <typename T, ssize_t Dims>
|
| 590 |
+
class unchecked_mutable_reference : public unchecked_reference<T, Dims> {
|
| 591 |
+
friend class pybind11::array;
|
| 592 |
+
using ConstBase = unchecked_reference<T, Dims>;
|
| 593 |
+
using ConstBase::ConstBase;
|
| 594 |
+
using ConstBase::Dynamic;
|
| 595 |
+
|
| 596 |
+
public:
|
| 597 |
+
// Bring in const-qualified versions from base class
|
| 598 |
+
using ConstBase::operator();
|
| 599 |
+
using ConstBase::operator[];
|
| 600 |
+
|
| 601 |
+
/// Mutable, unchecked access to data at the given indices.
|
| 602 |
+
template <typename... Ix>
|
| 603 |
+
T &operator()(Ix... index) {
|
| 604 |
+
static_assert(ssize_t{sizeof...(Ix)} == Dims || Dynamic,
|
| 605 |
+
"Invalid number of indices for unchecked array reference");
|
| 606 |
+
return const_cast<T &>(ConstBase::operator()(index...));
|
| 607 |
+
}
|
| 608 |
+
/**
|
| 609 |
+
* Mutable, unchecked access data at the given index; this operator only participates if the
|
| 610 |
+
* reference is to a 1-dimensional array (or has runtime dimensions). When present, this is
|
| 611 |
+
* exactly equivalent to `obj(index)`.
|
| 612 |
+
*/
|
| 613 |
+
template <ssize_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>>
|
| 614 |
+
T &operator[](ssize_t index) {
|
| 615 |
+
return operator()(index);
|
| 616 |
+
}
|
| 617 |
+
|
| 618 |
+
/// Mutable pointer access to the data at the given indices.
|
| 619 |
+
template <typename... Ix>
|
| 620 |
+
T *mutable_data(Ix... ix) {
|
| 621 |
+
return &operator()(ssize_t(ix)...);
|
| 622 |
+
}
|
| 623 |
+
};
|
| 624 |
+
|
| 625 |
+
template <typename T, ssize_t Dim>
|
| 626 |
+
struct type_caster<unchecked_reference<T, Dim>> {
|
| 627 |
+
static_assert(Dim == 0 && Dim > 0 /* always fail */,
|
| 628 |
+
"unchecked array proxy object is not castable");
|
| 629 |
+
};
|
| 630 |
+
template <typename T, ssize_t Dim>
|
| 631 |
+
struct type_caster<unchecked_mutable_reference<T, Dim>>
|
| 632 |
+
: type_caster<unchecked_reference<T, Dim>> {};
|
| 633 |
+
|
| 634 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 635 |
+
|
| 636 |
+
class dtype : public object {
|
| 637 |
+
public:
|
| 638 |
+
PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_)
|
| 639 |
+
|
| 640 |
+
explicit dtype(const buffer_info &info) {
|
| 641 |
+
dtype descr(_dtype_from_pep3118()(pybind11::str(info.format)));
|
| 642 |
+
// If info.itemsize == 0, use the value calculated from the format string
|
| 643 |
+
m_ptr = descr.strip_padding(info.itemsize != 0 ? info.itemsize : descr.itemsize())
|
| 644 |
+
.release()
|
| 645 |
+
.ptr();
|
| 646 |
+
}
|
| 647 |
+
|
| 648 |
+
explicit dtype(const pybind11::str &format) : dtype(from_args(format)) {}
|
| 649 |
+
|
| 650 |
+
explicit dtype(const std::string &format) : dtype(pybind11::str(format)) {}
|
| 651 |
+
|
| 652 |
+
explicit dtype(const char *format) : dtype(pybind11::str(format)) {}
|
| 653 |
+
|
| 654 |
+
dtype(list names, list formats, list offsets, ssize_t itemsize) {
|
| 655 |
+
dict args;
|
| 656 |
+
args["names"] = std::move(names);
|
| 657 |
+
args["formats"] = std::move(formats);
|
| 658 |
+
args["offsets"] = std::move(offsets);
|
| 659 |
+
args["itemsize"] = pybind11::int_(itemsize);
|
| 660 |
+
m_ptr = from_args(args).release().ptr();
|
| 661 |
+
}
|
| 662 |
+
|
| 663 |
+
/// Return dtype for the given typenum (one of the NPY_TYPES).
|
| 664 |
+
/// https://numpy.org/devdocs/reference/c-api/array.html#c.PyArray_DescrFromType
|
| 665 |
+
explicit dtype(int typenum)
|
| 666 |
+
: object(detail::npy_api::get().PyArray_DescrFromType_(typenum), stolen_t{}) {
|
| 667 |
+
if (m_ptr == nullptr) {
|
| 668 |
+
throw error_already_set();
|
| 669 |
+
}
|
| 670 |
+
}
|
| 671 |
+
|
| 672 |
+
/// This is essentially the same as calling numpy.dtype(args) in Python.
|
| 673 |
+
static dtype from_args(const object &args) {
|
| 674 |
+
PyObject *ptr = nullptr;
|
| 675 |
+
if ((detail::npy_api::get().PyArray_DescrConverter_(args.ptr(), &ptr) == 0) || !ptr) {
|
| 676 |
+
throw error_already_set();
|
| 677 |
+
}
|
| 678 |
+
return reinterpret_steal<dtype>(ptr);
|
| 679 |
+
}
|
| 680 |
+
|
| 681 |
+
/// Return dtype associated with a C++ type.
|
| 682 |
+
template <typename T>
|
| 683 |
+
static dtype of() {
|
| 684 |
+
return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::dtype();
|
| 685 |
+
}
|
| 686 |
+
|
| 687 |
+
/// Size of the data type in bytes.
|
| 688 |
+
#ifdef PYBIND11_NUMPY_1_ONLY
|
| 689 |
+
ssize_t itemsize() const { return detail::array_descriptor_proxy(m_ptr)->elsize; }
|
| 690 |
+
#else
|
| 691 |
+
ssize_t itemsize() const {
|
| 692 |
+
if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
|
| 693 |
+
return detail::array_descriptor1_proxy(m_ptr)->elsize;
|
| 694 |
+
}
|
| 695 |
+
return detail::array_descriptor2_proxy(m_ptr)->elsize;
|
| 696 |
+
}
|
| 697 |
+
#endif
|
| 698 |
+
|
| 699 |
+
/// Returns true for structured data types.
|
| 700 |
+
#ifdef PYBIND11_NUMPY_1_ONLY
|
| 701 |
+
bool has_fields() const { return detail::array_descriptor_proxy(m_ptr)->names != nullptr; }
|
| 702 |
+
#else
|
| 703 |
+
bool has_fields() const {
|
| 704 |
+
if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
|
| 705 |
+
return detail::array_descriptor1_proxy(m_ptr)->names != nullptr;
|
| 706 |
+
}
|
| 707 |
+
const auto *proxy = detail::array_descriptor2_proxy(m_ptr);
|
| 708 |
+
if (proxy->type_num < 0 || proxy->type_num >= 2056) {
|
| 709 |
+
return false;
|
| 710 |
+
}
|
| 711 |
+
return proxy->names != nullptr;
|
| 712 |
+
}
|
| 713 |
+
#endif
|
| 714 |
+
|
| 715 |
+
/// Single-character code for dtype's kind.
|
| 716 |
+
/// For example, floating point types are 'f' and integral types are 'i'.
|
| 717 |
+
char kind() const { return detail::array_descriptor_proxy(m_ptr)->kind; }
|
| 718 |
+
|
| 719 |
+
/// Single-character for dtype's type.
|
| 720 |
+
/// For example, ``float`` is 'f', ``double`` 'd', ``int`` 'i', and ``long`` 'l'.
|
| 721 |
+
char char_() const {
|
| 722 |
+
// Note: The signature, `dtype::char_` follows the naming of NumPy's
|
| 723 |
+
// public Python API (i.e., ``dtype.char``), rather than its internal
|
| 724 |
+
// C API (``PyArray_Descr::type``).
|
| 725 |
+
return detail::array_descriptor_proxy(m_ptr)->type;
|
| 726 |
+
}
|
| 727 |
+
|
| 728 |
+
/// type number of dtype.
|
| 729 |
+
int num() const {
|
| 730 |
+
// Note: The signature, `dtype::num` follows the naming of NumPy's public
|
| 731 |
+
// Python API (i.e., ``dtype.num``), rather than its internal
|
| 732 |
+
// C API (``PyArray_Descr::type_num``).
|
| 733 |
+
return detail::array_descriptor_proxy(m_ptr)->type_num;
|
| 734 |
+
}
|
| 735 |
+
|
| 736 |
+
/// Single character for byteorder
|
| 737 |
+
char byteorder() const { return detail::array_descriptor_proxy(m_ptr)->byteorder; }
|
| 738 |
+
|
| 739 |
+
/// Alignment of the data type
|
| 740 |
+
#ifdef PYBIND11_NUMPY_1_ONLY
|
| 741 |
+
int alignment() const { return detail::array_descriptor_proxy(m_ptr)->alignment; }
|
| 742 |
+
#else
|
| 743 |
+
ssize_t alignment() const {
|
| 744 |
+
if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
|
| 745 |
+
return detail::array_descriptor1_proxy(m_ptr)->alignment;
|
| 746 |
+
}
|
| 747 |
+
return detail::array_descriptor2_proxy(m_ptr)->alignment;
|
| 748 |
+
}
|
| 749 |
+
#endif
|
| 750 |
+
|
| 751 |
+
/// Flags for the array descriptor
|
| 752 |
+
#ifdef PYBIND11_NUMPY_1_ONLY
|
| 753 |
+
char flags() const { return detail::array_descriptor_proxy(m_ptr)->flags; }
|
| 754 |
+
#else
|
| 755 |
+
std::uint64_t flags() const {
|
| 756 |
+
if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
|
| 757 |
+
return (unsigned char) detail::array_descriptor1_proxy(m_ptr)->flags;
|
| 758 |
+
}
|
| 759 |
+
return detail::array_descriptor2_proxy(m_ptr)->flags;
|
| 760 |
+
}
|
| 761 |
+
#endif
|
| 762 |
+
|
| 763 |
+
private:
|
| 764 |
+
static object &_dtype_from_pep3118() {
|
| 765 |
+
PYBIND11_CONSTINIT static gil_safe_call_once_and_store<object> storage;
|
| 766 |
+
return storage
|
| 767 |
+
.call_once_and_store_result([]() {
|
| 768 |
+
return detail::import_numpy_core_submodule("_internal")
|
| 769 |
+
.attr("_dtype_from_pep3118");
|
| 770 |
+
})
|
| 771 |
+
.get_stored();
|
| 772 |
+
}
|
| 773 |
+
|
| 774 |
+
dtype strip_padding(ssize_t itemsize) {
|
| 775 |
+
// Recursively strip all void fields with empty names that are generated for
|
| 776 |
+
// padding fields (as of NumPy v1.11).
|
| 777 |
+
if (!has_fields()) {
|
| 778 |
+
return *this;
|
| 779 |
+
}
|
| 780 |
+
|
| 781 |
+
struct field_descr {
|
| 782 |
+
pybind11::str name;
|
| 783 |
+
object format;
|
| 784 |
+
pybind11::int_ offset;
|
| 785 |
+
field_descr(pybind11::str &&name, object &&format, pybind11::int_ &&offset)
|
| 786 |
+
: name{std::move(name)}, format{std::move(format)}, offset{std::move(offset)} {};
|
| 787 |
+
};
|
| 788 |
+
auto field_dict = attr("fields").cast<dict>();
|
| 789 |
+
std::vector<field_descr> field_descriptors;
|
| 790 |
+
field_descriptors.reserve(field_dict.size());
|
| 791 |
+
|
| 792 |
+
for (auto field : field_dict.attr("items")()) {
|
| 793 |
+
auto spec = field.cast<tuple>();
|
| 794 |
+
auto name = spec[0].cast<pybind11::str>();
|
| 795 |
+
auto spec_fo = spec[1].cast<tuple>();
|
| 796 |
+
auto format = spec_fo[0].cast<dtype>();
|
| 797 |
+
auto offset = spec_fo[1].cast<pybind11::int_>();
|
| 798 |
+
if ((len(name) == 0u) && format.kind() == 'V') {
|
| 799 |
+
continue;
|
| 800 |
+
}
|
| 801 |
+
field_descriptors.emplace_back(
|
| 802 |
+
std::move(name), format.strip_padding(format.itemsize()), std::move(offset));
|
| 803 |
+
}
|
| 804 |
+
|
| 805 |
+
std::sort(field_descriptors.begin(),
|
| 806 |
+
field_descriptors.end(),
|
| 807 |
+
[](const field_descr &a, const field_descr &b) {
|
| 808 |
+
return a.offset.cast<int>() < b.offset.cast<int>();
|
| 809 |
+
});
|
| 810 |
+
|
| 811 |
+
list names, formats, offsets;
|
| 812 |
+
for (auto &descr : field_descriptors) {
|
| 813 |
+
names.append(std::move(descr.name));
|
| 814 |
+
formats.append(std::move(descr.format));
|
| 815 |
+
offsets.append(std::move(descr.offset));
|
| 816 |
+
}
|
| 817 |
+
return dtype(std::move(names), std::move(formats), std::move(offsets), itemsize);
|
| 818 |
+
}
|
| 819 |
+
};
|
| 820 |
+
|
| 821 |
+
class array : public buffer {
|
| 822 |
+
public:
|
| 823 |
+
PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
|
| 824 |
+
|
| 825 |
+
enum {
|
| 826 |
+
c_style = detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_,
|
| 827 |
+
f_style = detail::npy_api::NPY_ARRAY_F_CONTIGUOUS_,
|
| 828 |
+
forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
|
| 829 |
+
};
|
| 830 |
+
|
| 831 |
+
array() : array(0, static_cast<const double *>(nullptr)) {}
|
| 832 |
+
|
| 833 |
+
using ShapeContainer = detail::any_container<ssize_t>;
|
| 834 |
+
using StridesContainer = detail::any_container<ssize_t>;
|
| 835 |
+
|
| 836 |
+
// Constructs an array taking shape/strides from arbitrary container types
|
| 837 |
+
array(const pybind11::dtype &dt,
|
| 838 |
+
ShapeContainer shape,
|
| 839 |
+
StridesContainer strides,
|
| 840 |
+
const void *ptr = nullptr,
|
| 841 |
+
handle base = handle()) {
|
| 842 |
+
|
| 843 |
+
if (strides->empty()) {
|
| 844 |
+
*strides = detail::c_strides(*shape, dt.itemsize());
|
| 845 |
+
}
|
| 846 |
+
|
| 847 |
+
auto ndim = shape->size();
|
| 848 |
+
if (ndim != strides->size()) {
|
| 849 |
+
pybind11_fail("NumPy: shape ndim doesn't match strides ndim");
|
| 850 |
+
}
|
| 851 |
+
auto descr = dt;
|
| 852 |
+
|
| 853 |
+
int flags = 0;
|
| 854 |
+
if (base && ptr) {
|
| 855 |
+
if (isinstance<array>(base)) {
|
| 856 |
+
/* Copy flags from base (except ownership bit) */
|
| 857 |
+
flags = reinterpret_borrow<array>(base).flags()
|
| 858 |
+
& ~detail::npy_api::NPY_ARRAY_OWNDATA_;
|
| 859 |
+
} else {
|
| 860 |
+
/* Writable by default, easy to downgrade later on if needed */
|
| 861 |
+
flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
|
| 862 |
+
}
|
| 863 |
+
}
|
| 864 |
+
|
| 865 |
+
auto &api = detail::npy_api::get();
|
| 866 |
+
auto tmp = reinterpret_steal<object>(api.PyArray_NewFromDescr_(
|
| 867 |
+
api.PyArray_Type_,
|
| 868 |
+
descr.release().ptr(),
|
| 869 |
+
(int) ndim,
|
| 870 |
+
// Use reinterpret_cast for PyPy on Windows (remove if fixed, checked on 7.3.1)
|
| 871 |
+
reinterpret_cast<Py_intptr_t *>(shape->data()),
|
| 872 |
+
reinterpret_cast<Py_intptr_t *>(strides->data()),
|
| 873 |
+
const_cast<void *>(ptr),
|
| 874 |
+
flags,
|
| 875 |
+
nullptr));
|
| 876 |
+
if (!tmp) {
|
| 877 |
+
throw error_already_set();
|
| 878 |
+
}
|
| 879 |
+
if (ptr) {
|
| 880 |
+
if (base) {
|
| 881 |
+
api.PyArray_SetBaseObject_(tmp.ptr(), base.inc_ref().ptr());
|
| 882 |
+
} else {
|
| 883 |
+
tmp = reinterpret_steal<object>(
|
| 884 |
+
api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */));
|
| 885 |
+
}
|
| 886 |
+
}
|
| 887 |
+
m_ptr = tmp.release().ptr();
|
| 888 |
+
}
|
| 889 |
+
|
| 890 |
+
array(const pybind11::dtype &dt,
|
| 891 |
+
ShapeContainer shape,
|
| 892 |
+
const void *ptr = nullptr,
|
| 893 |
+
handle base = handle())
|
| 894 |
+
: array(dt, std::move(shape), {}, ptr, base) {}
|
| 895 |
+
|
| 896 |
+
template <typename T,
|
| 897 |
+
typename
|
| 898 |
+
= detail::enable_if_t<std::is_integral<T>::value && !std::is_same<bool, T>::value>>
|
| 899 |
+
array(const pybind11::dtype &dt, T count, const void *ptr = nullptr, handle base = handle())
|
| 900 |
+
: array(dt, {{count}}, ptr, base) {}
|
| 901 |
+
|
| 902 |
+
template <typename T>
|
| 903 |
+
array(ShapeContainer shape, StridesContainer strides, const T *ptr, handle base = handle())
|
| 904 |
+
: array(pybind11::dtype::of<T>(),
|
| 905 |
+
std::move(shape),
|
| 906 |
+
std::move(strides),
|
| 907 |
+
reinterpret_cast<const void *>(ptr),
|
| 908 |
+
base) {}
|
| 909 |
+
|
| 910 |
+
template <typename T>
|
| 911 |
+
array(ShapeContainer shape, const T *ptr, handle base = handle())
|
| 912 |
+
: array(std::move(shape), {}, ptr, base) {}
|
| 913 |
+
|
| 914 |
+
template <typename T>
|
| 915 |
+
explicit array(ssize_t count, const T *ptr, handle base = handle())
|
| 916 |
+
: array({count}, {}, ptr, base) {}
|
| 917 |
+
|
| 918 |
+
explicit array(const buffer_info &info, handle base = handle())
|
| 919 |
+
: array(pybind11::dtype(info), info.shape, info.strides, info.ptr, base) {}
|
| 920 |
+
|
| 921 |
+
/// Array descriptor (dtype)
|
| 922 |
+
pybind11::dtype dtype() const {
|
| 923 |
+
return reinterpret_borrow<pybind11::dtype>(detail::array_proxy(m_ptr)->descr);
|
| 924 |
+
}
|
| 925 |
+
|
| 926 |
+
/// Total number of elements
|
| 927 |
+
ssize_t size() const {
|
| 928 |
+
return std::accumulate(shape(), shape() + ndim(), (ssize_t) 1, std::multiplies<ssize_t>());
|
| 929 |
+
}
|
| 930 |
+
|
| 931 |
+
/// Byte size of a single element
|
| 932 |
+
ssize_t itemsize() const { return dtype().itemsize(); }
|
| 933 |
+
|
| 934 |
+
/// Total number of bytes
|
| 935 |
+
ssize_t nbytes() const { return size() * itemsize(); }
|
| 936 |
+
|
| 937 |
+
/// Number of dimensions
|
| 938 |
+
ssize_t ndim() const { return detail::array_proxy(m_ptr)->nd; }
|
| 939 |
+
|
| 940 |
+
/// Base object
|
| 941 |
+
object base() const { return reinterpret_borrow<object>(detail::array_proxy(m_ptr)->base); }
|
| 942 |
+
|
| 943 |
+
/// Dimensions of the array
|
| 944 |
+
const ssize_t *shape() const { return detail::array_proxy(m_ptr)->dimensions; }
|
| 945 |
+
|
| 946 |
+
/// Dimension along a given axis
|
| 947 |
+
ssize_t shape(ssize_t dim) const {
|
| 948 |
+
if (dim >= ndim()) {
|
| 949 |
+
fail_dim_check(dim, "invalid axis");
|
| 950 |
+
}
|
| 951 |
+
return shape()[dim];
|
| 952 |
+
}
|
| 953 |
+
|
| 954 |
+
/// Strides of the array
|
| 955 |
+
const ssize_t *strides() const { return detail::array_proxy(m_ptr)->strides; }
|
| 956 |
+
|
| 957 |
+
/// Stride along a given axis
|
| 958 |
+
ssize_t strides(ssize_t dim) const {
|
| 959 |
+
if (dim >= ndim()) {
|
| 960 |
+
fail_dim_check(dim, "invalid axis");
|
| 961 |
+
}
|
| 962 |
+
return strides()[dim];
|
| 963 |
+
}
|
| 964 |
+
|
| 965 |
+
/// Return the NumPy array flags
|
| 966 |
+
int flags() const { return detail::array_proxy(m_ptr)->flags; }
|
| 967 |
+
|
| 968 |
+
/// If set, the array is writeable (otherwise the buffer is read-only)
|
| 969 |
+
bool writeable() const {
|
| 970 |
+
return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
|
| 971 |
+
}
|
| 972 |
+
|
| 973 |
+
/// If set, the array owns the data (will be freed when the array is deleted)
|
| 974 |
+
bool owndata() const {
|
| 975 |
+
return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
|
| 976 |
+
}
|
| 977 |
+
|
| 978 |
+
/// Pointer to the contained data. If index is not provided, points to the
|
| 979 |
+
/// beginning of the buffer. May throw if the index would lead to out of bounds access.
|
| 980 |
+
template <typename... Ix>
|
| 981 |
+
const void *data(Ix... index) const {
|
| 982 |
+
return static_cast<const void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
|
| 983 |
+
}
|
| 984 |
+
|
| 985 |
+
/// Mutable pointer to the contained data. If index is not provided, points to the
|
| 986 |
+
/// beginning of the buffer. May throw if the index would lead to out of bounds access.
|
| 987 |
+
/// May throw if the array is not writeable.
|
| 988 |
+
template <typename... Ix>
|
| 989 |
+
void *mutable_data(Ix... index) {
|
| 990 |
+
check_writeable();
|
| 991 |
+
return static_cast<void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
|
| 992 |
+
}
|
| 993 |
+
|
| 994 |
+
/// Byte offset from beginning of the array to a given index (full or partial).
|
| 995 |
+
/// May throw if the index would lead to out of bounds access.
|
| 996 |
+
template <typename... Ix>
|
| 997 |
+
ssize_t offset_at(Ix... index) const {
|
| 998 |
+
if ((ssize_t) sizeof...(index) > ndim()) {
|
| 999 |
+
fail_dim_check(sizeof...(index), "too many indices for an array");
|
| 1000 |
+
}
|
| 1001 |
+
return byte_offset(ssize_t(index)...);
|
| 1002 |
+
}
|
| 1003 |
+
|
| 1004 |
+
ssize_t offset_at() const { return 0; }
|
| 1005 |
+
|
| 1006 |
+
/// Item count from beginning of the array to a given index (full or partial).
|
| 1007 |
+
/// May throw if the index would lead to out of bounds access.
|
| 1008 |
+
template <typename... Ix>
|
| 1009 |
+
ssize_t index_at(Ix... index) const {
|
| 1010 |
+
return offset_at(index...) / itemsize();
|
| 1011 |
+
}
|
| 1012 |
+
|
| 1013 |
+
/**
|
| 1014 |
+
* Returns a proxy object that provides access to the array's data without bounds or
|
| 1015 |
+
* dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with
|
| 1016 |
+
* care: the array must not be destroyed or reshaped for the duration of the returned object,
|
| 1017 |
+
* and the caller must take care not to access invalid dimensions or dimension indices.
|
| 1018 |
+
*/
|
| 1019 |
+
template <typename T, ssize_t Dims = -1>
|
| 1020 |
+
detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() & {
|
| 1021 |
+
if (Dims >= 0 && ndim() != Dims) {
|
| 1022 |
+
throw std::domain_error("array has incorrect number of dimensions: "
|
| 1023 |
+
+ std::to_string(ndim()) + "; expected "
|
| 1024 |
+
+ std::to_string(Dims));
|
| 1025 |
+
}
|
| 1026 |
+
return detail::unchecked_mutable_reference<T, Dims>(
|
| 1027 |
+
mutable_data(), shape(), strides(), ndim());
|
| 1028 |
+
}
|
| 1029 |
+
|
| 1030 |
+
/**
|
| 1031 |
+
* Returns a proxy object that provides const access to the array's data without bounds or
|
| 1032 |
+
* dimensionality checking. Unlike `mutable_unchecked()`, this does not require that the
|
| 1033 |
+
* underlying array have the `writable` flag. Use with care: the array must not be destroyed
|
| 1034 |
+
* or reshaped for the duration of the returned object, and the caller must take care not to
|
| 1035 |
+
* access invalid dimensions or dimension indices.
|
| 1036 |
+
*/
|
| 1037 |
+
template <typename T, ssize_t Dims = -1>
|
| 1038 |
+
detail::unchecked_reference<T, Dims> unchecked() const & {
|
| 1039 |
+
if (Dims >= 0 && ndim() != Dims) {
|
| 1040 |
+
throw std::domain_error("array has incorrect number of dimensions: "
|
| 1041 |
+
+ std::to_string(ndim()) + "; expected "
|
| 1042 |
+
+ std::to_string(Dims));
|
| 1043 |
+
}
|
| 1044 |
+
return detail::unchecked_reference<T, Dims>(data(), shape(), strides(), ndim());
|
| 1045 |
+
}
|
| 1046 |
+
|
| 1047 |
+
/// Return a new view with all of the dimensions of length 1 removed
|
| 1048 |
+
array squeeze() {
|
| 1049 |
+
auto &api = detail::npy_api::get();
|
| 1050 |
+
return reinterpret_steal<array>(api.PyArray_Squeeze_(m_ptr));
|
| 1051 |
+
}
|
| 1052 |
+
|
| 1053 |
+
/// Resize array to given shape
|
| 1054 |
+
/// If refcheck is true and more that one reference exist to this array
|
| 1055 |
+
/// then resize will succeed only if it makes a reshape, i.e. original size doesn't change
|
| 1056 |
+
void resize(ShapeContainer new_shape, bool refcheck = true) {
|
| 1057 |
+
detail::npy_api::PyArray_Dims d
|
| 1058 |
+
= {// Use reinterpret_cast for PyPy on Windows (remove if fixed, checked on 7.3.1)
|
| 1059 |
+
reinterpret_cast<Py_intptr_t *>(new_shape->data()),
|
| 1060 |
+
int(new_shape->size())};
|
| 1061 |
+
// try to resize, set ordering param to -1 cause it's not used anyway
|
| 1062 |
+
auto new_array = reinterpret_steal<object>(
|
| 1063 |
+
detail::npy_api::get().PyArray_Resize_(m_ptr, &d, int(refcheck), -1));
|
| 1064 |
+
if (!new_array) {
|
| 1065 |
+
throw error_already_set();
|
| 1066 |
+
}
|
| 1067 |
+
if (isinstance<array>(new_array)) {
|
| 1068 |
+
*this = std::move(new_array);
|
| 1069 |
+
}
|
| 1070 |
+
}
|
| 1071 |
+
|
| 1072 |
+
/// Optional `order` parameter omitted, to be added as needed.
|
| 1073 |
+
array reshape(ShapeContainer new_shape) {
|
| 1074 |
+
detail::npy_api::PyArray_Dims d
|
| 1075 |
+
= {reinterpret_cast<Py_intptr_t *>(new_shape->data()), int(new_shape->size())};
|
| 1076 |
+
auto new_array
|
| 1077 |
+
= reinterpret_steal<array>(detail::npy_api::get().PyArray_Newshape_(m_ptr, &d, 0));
|
| 1078 |
+
if (!new_array) {
|
| 1079 |
+
throw error_already_set();
|
| 1080 |
+
}
|
| 1081 |
+
return new_array;
|
| 1082 |
+
}
|
| 1083 |
+
|
| 1084 |
+
/// Create a view of an array in a different data type.
|
| 1085 |
+
/// This function may fundamentally reinterpret the data in the array.
|
| 1086 |
+
/// It is the responsibility of the caller to ensure that this is safe.
|
| 1087 |
+
/// Only supports the `dtype` argument, the `type` argument is omitted,
|
| 1088 |
+
/// to be added as needed.
|
| 1089 |
+
array view(const std::string &dtype) {
|
| 1090 |
+
auto &api = detail::npy_api::get();
|
| 1091 |
+
auto new_view = reinterpret_steal<array>(api.PyArray_View_(
|
| 1092 |
+
m_ptr, dtype::from_args(pybind11::str(dtype)).release().ptr(), nullptr));
|
| 1093 |
+
if (!new_view) {
|
| 1094 |
+
throw error_already_set();
|
| 1095 |
+
}
|
| 1096 |
+
return new_view;
|
| 1097 |
+
}
|
| 1098 |
+
|
| 1099 |
+
/// Ensure that the argument is a NumPy array
|
| 1100 |
+
/// In case of an error, nullptr is returned and the Python error is cleared.
|
| 1101 |
+
static array ensure(handle h, int ExtraFlags = 0) {
|
| 1102 |
+
auto result = reinterpret_steal<array>(raw_array(h.ptr(), ExtraFlags));
|
| 1103 |
+
if (!result) {
|
| 1104 |
+
PyErr_Clear();
|
| 1105 |
+
}
|
| 1106 |
+
return result;
|
| 1107 |
+
}
|
| 1108 |
+
|
| 1109 |
+
protected:
|
| 1110 |
+
template <typename, typename>
|
| 1111 |
+
friend struct detail::npy_format_descriptor;
|
| 1112 |
+
|
| 1113 |
+
void fail_dim_check(ssize_t dim, const std::string &msg) const {
|
| 1114 |
+
throw index_error(msg + ": " + std::to_string(dim) + " (ndim = " + std::to_string(ndim())
|
| 1115 |
+
+ ')');
|
| 1116 |
+
}
|
| 1117 |
+
|
| 1118 |
+
template <typename... Ix>
|
| 1119 |
+
ssize_t byte_offset(Ix... index) const {
|
| 1120 |
+
check_dimensions(index...);
|
| 1121 |
+
return detail::byte_offset_unsafe(strides(), ssize_t(index)...);
|
| 1122 |
+
}
|
| 1123 |
+
|
| 1124 |
+
void check_writeable() const {
|
| 1125 |
+
if (!writeable()) {
|
| 1126 |
+
throw std::domain_error("array is not writeable");
|
| 1127 |
+
}
|
| 1128 |
+
}
|
| 1129 |
+
|
| 1130 |
+
template <typename... Ix>
|
| 1131 |
+
void check_dimensions(Ix... index) const {
|
| 1132 |
+
check_dimensions_impl(ssize_t(0), shape(), ssize_t(index)...);
|
| 1133 |
+
}
|
| 1134 |
+
|
| 1135 |
+
void check_dimensions_impl(ssize_t, const ssize_t *) const {}
|
| 1136 |
+
|
| 1137 |
+
template <typename... Ix>
|
| 1138 |
+
void check_dimensions_impl(ssize_t axis, const ssize_t *shape, ssize_t i, Ix... index) const {
|
| 1139 |
+
if (i >= *shape) {
|
| 1140 |
+
throw index_error(std::string("index ") + std::to_string(i)
|
| 1141 |
+
+ " is out of bounds for axis " + std::to_string(axis)
|
| 1142 |
+
+ " with size " + std::to_string(*shape));
|
| 1143 |
+
}
|
| 1144 |
+
check_dimensions_impl(axis + 1, shape + 1, index...);
|
| 1145 |
+
}
|
| 1146 |
+
|
| 1147 |
+
/// Create array from any object -- always returns a new reference
|
| 1148 |
+
static PyObject *raw_array(PyObject *ptr, int ExtraFlags = 0) {
|
| 1149 |
+
if (ptr == nullptr) {
|
| 1150 |
+
set_error(PyExc_ValueError, "cannot create a pybind11::array from a nullptr");
|
| 1151 |
+
return nullptr;
|
| 1152 |
+
}
|
| 1153 |
+
return detail::npy_api::get().PyArray_FromAny_(
|
| 1154 |
+
ptr, nullptr, 0, 0, detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr);
|
| 1155 |
+
}
|
| 1156 |
+
};
|
| 1157 |
+
|
| 1158 |
+
template <typename T, int ExtraFlags = array::forcecast>
|
| 1159 |
+
class array_t : public array {
|
| 1160 |
+
private:
|
| 1161 |
+
struct private_ctor {};
|
| 1162 |
+
// Delegating constructor needed when both moving and accessing in the same constructor
|
| 1163 |
+
array_t(private_ctor,
|
| 1164 |
+
ShapeContainer &&shape,
|
| 1165 |
+
StridesContainer &&strides,
|
| 1166 |
+
const T *ptr,
|
| 1167 |
+
handle base)
|
| 1168 |
+
: array(std::move(shape), std::move(strides), ptr, base) {}
|
| 1169 |
+
|
| 1170 |
+
public:
|
| 1171 |
+
static_assert(!detail::array_info<T>::is_array, "Array types cannot be used with array_t");
|
| 1172 |
+
|
| 1173 |
+
using value_type = T;
|
| 1174 |
+
|
| 1175 |
+
array_t() : array(0, static_cast<const T *>(nullptr)) {}
|
| 1176 |
+
array_t(handle h, borrowed_t) : array(h, borrowed_t{}) {}
|
| 1177 |
+
array_t(handle h, stolen_t) : array(h, stolen_t{}) {}
|
| 1178 |
+
|
| 1179 |
+
PYBIND11_DEPRECATED("Use array_t<T>::ensure() instead")
|
| 1180 |
+
array_t(handle h, bool is_borrowed) : array(raw_array_t(h.ptr()), stolen_t{}) {
|
| 1181 |
+
if (!m_ptr) {
|
| 1182 |
+
PyErr_Clear();
|
| 1183 |
+
}
|
| 1184 |
+
if (!is_borrowed) {
|
| 1185 |
+
Py_XDECREF(h.ptr());
|
| 1186 |
+
}
|
| 1187 |
+
}
|
| 1188 |
+
|
| 1189 |
+
// NOLINTNEXTLINE(google-explicit-constructor)
|
| 1190 |
+
array_t(const object &o) : array(raw_array_t(o.ptr()), stolen_t{}) {
|
| 1191 |
+
if (!m_ptr) {
|
| 1192 |
+
throw error_already_set();
|
| 1193 |
+
}
|
| 1194 |
+
}
|
| 1195 |
+
|
| 1196 |
+
explicit array_t(const buffer_info &info, handle base = handle()) : array(info, base) {}
|
| 1197 |
+
|
| 1198 |
+
array_t(ShapeContainer shape,
|
| 1199 |
+
StridesContainer strides,
|
| 1200 |
+
const T *ptr = nullptr,
|
| 1201 |
+
handle base = handle())
|
| 1202 |
+
: array(std::move(shape), std::move(strides), ptr, base) {}
|
| 1203 |
+
|
| 1204 |
+
explicit array_t(ShapeContainer shape, const T *ptr = nullptr, handle base = handle())
|
| 1205 |
+
: array_t(private_ctor{},
|
| 1206 |
+
std::move(shape),
|
| 1207 |
+
(ExtraFlags & f_style) != 0 ? detail::f_strides(*shape, itemsize())
|
| 1208 |
+
: detail::c_strides(*shape, itemsize()),
|
| 1209 |
+
ptr,
|
| 1210 |
+
base) {}
|
| 1211 |
+
|
| 1212 |
+
explicit array_t(ssize_t count, const T *ptr = nullptr, handle base = handle())
|
| 1213 |
+
: array({count}, {}, ptr, base) {}
|
| 1214 |
+
|
| 1215 |
+
constexpr ssize_t itemsize() const { return sizeof(T); }
|
| 1216 |
+
|
| 1217 |
+
template <typename... Ix>
|
| 1218 |
+
ssize_t index_at(Ix... index) const {
|
| 1219 |
+
return offset_at(index...) / itemsize();
|
| 1220 |
+
}
|
| 1221 |
+
|
| 1222 |
+
template <typename... Ix>
|
| 1223 |
+
const T *data(Ix... index) const {
|
| 1224 |
+
return static_cast<const T *>(array::data(index...));
|
| 1225 |
+
}
|
| 1226 |
+
|
| 1227 |
+
template <typename... Ix>
|
| 1228 |
+
T *mutable_data(Ix... index) {
|
| 1229 |
+
return static_cast<T *>(array::mutable_data(index...));
|
| 1230 |
+
}
|
| 1231 |
+
|
| 1232 |
+
// Reference to element at a given index
|
| 1233 |
+
template <typename... Ix>
|
| 1234 |
+
const T &at(Ix... index) const {
|
| 1235 |
+
if ((ssize_t) sizeof...(index) != ndim()) {
|
| 1236 |
+
fail_dim_check(sizeof...(index), "index dimension mismatch");
|
| 1237 |
+
}
|
| 1238 |
+
return *(static_cast<const T *>(array::data())
|
| 1239 |
+
+ byte_offset(ssize_t(index)...) / itemsize());
|
| 1240 |
+
}
|
| 1241 |
+
|
| 1242 |
+
// Mutable reference to element at a given index
|
| 1243 |
+
template <typename... Ix>
|
| 1244 |
+
T &mutable_at(Ix... index) {
|
| 1245 |
+
if ((ssize_t) sizeof...(index) != ndim()) {
|
| 1246 |
+
fail_dim_check(sizeof...(index), "index dimension mismatch");
|
| 1247 |
+
}
|
| 1248 |
+
return *(static_cast<T *>(array::mutable_data())
|
| 1249 |
+
+ byte_offset(ssize_t(index)...) / itemsize());
|
| 1250 |
+
}
|
| 1251 |
+
|
| 1252 |
+
/**
|
| 1253 |
+
* Returns a proxy object that provides access to the array's data without bounds or
|
| 1254 |
+
* dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with
|
| 1255 |
+
* care: the array must not be destroyed or reshaped for the duration of the returned object,
|
| 1256 |
+
* and the caller must take care not to access invalid dimensions or dimension indices.
|
| 1257 |
+
*/
|
| 1258 |
+
template <ssize_t Dims = -1>
|
| 1259 |
+
detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() & {
|
| 1260 |
+
return array::mutable_unchecked<T, Dims>();
|
| 1261 |
+
}
|
| 1262 |
+
|
| 1263 |
+
/**
|
| 1264 |
+
* Returns a proxy object that provides const access to the array's data without bounds or
|
| 1265 |
+
* dimensionality checking. Unlike `mutable_unchecked()`, this does not require that the
|
| 1266 |
+
* underlying array have the `writable` flag. Use with care: the array must not be destroyed
|
| 1267 |
+
* or reshaped for the duration of the returned object, and the caller must take care not to
|
| 1268 |
+
* access invalid dimensions or dimension indices.
|
| 1269 |
+
*/
|
| 1270 |
+
template <ssize_t Dims = -1>
|
| 1271 |
+
detail::unchecked_reference<T, Dims> unchecked() const & {
|
| 1272 |
+
return array::unchecked<T, Dims>();
|
| 1273 |
+
}
|
| 1274 |
+
|
| 1275 |
+
/// Ensure that the argument is a NumPy array of the correct dtype (and if not, try to convert
|
| 1276 |
+
/// it). In case of an error, nullptr is returned and the Python error is cleared.
|
| 1277 |
+
static array_t ensure(handle h) {
|
| 1278 |
+
auto result = reinterpret_steal<array_t>(raw_array_t(h.ptr()));
|
| 1279 |
+
if (!result) {
|
| 1280 |
+
PyErr_Clear();
|
| 1281 |
+
}
|
| 1282 |
+
return result;
|
| 1283 |
+
}
|
| 1284 |
+
|
| 1285 |
+
static bool check_(handle h) {
|
| 1286 |
+
const auto &api = detail::npy_api::get();
|
| 1287 |
+
return api.PyArray_Check_(h.ptr())
|
| 1288 |
+
&& api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr,
|
| 1289 |
+
dtype::of<T>().ptr())
|
| 1290 |
+
&& detail::check_flags(h.ptr(), ExtraFlags & (array::c_style | array::f_style));
|
| 1291 |
+
}
|
| 1292 |
+
|
| 1293 |
+
protected:
|
| 1294 |
+
/// Create array from any object -- always returns a new reference
|
| 1295 |
+
static PyObject *raw_array_t(PyObject *ptr) {
|
| 1296 |
+
if (ptr == nullptr) {
|
| 1297 |
+
set_error(PyExc_ValueError, "cannot create a pybind11::array_t from a nullptr");
|
| 1298 |
+
return nullptr;
|
| 1299 |
+
}
|
| 1300 |
+
return detail::npy_api::get().PyArray_FromAny_(ptr,
|
| 1301 |
+
dtype::of<T>().release().ptr(),
|
| 1302 |
+
0,
|
| 1303 |
+
0,
|
| 1304 |
+
detail::npy_api::NPY_ARRAY_ENSUREARRAY_
|
| 1305 |
+
| ExtraFlags,
|
| 1306 |
+
nullptr);
|
| 1307 |
+
}
|
| 1308 |
+
};
|
| 1309 |
+
|
| 1310 |
+
template <typename T>
|
| 1311 |
+
struct format_descriptor<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
|
| 1312 |
+
static std::string format() {
|
| 1313 |
+
return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::format();
|
| 1314 |
+
}
|
| 1315 |
+
};
|
| 1316 |
+
|
| 1317 |
+
template <size_t N>
|
| 1318 |
+
struct format_descriptor<char[N]> {
|
| 1319 |
+
static std::string format() { return std::to_string(N) + 's'; }
|
| 1320 |
+
};
|
| 1321 |
+
template <size_t N>
|
| 1322 |
+
struct format_descriptor<std::array<char, N>> {
|
| 1323 |
+
static std::string format() { return std::to_string(N) + 's'; }
|
| 1324 |
+
};
|
| 1325 |
+
|
| 1326 |
+
template <typename T>
|
| 1327 |
+
struct format_descriptor<T, detail::enable_if_t<std::is_enum<T>::value>> {
|
| 1328 |
+
static std::string format() {
|
| 1329 |
+
return format_descriptor<
|
| 1330 |
+
typename std::remove_cv<typename std::underlying_type<T>::type>::type>::format();
|
| 1331 |
+
}
|
| 1332 |
+
};
|
| 1333 |
+
|
| 1334 |
+
template <typename T>
|
| 1335 |
+
struct format_descriptor<T, detail::enable_if_t<detail::array_info<T>::is_array>> {
|
| 1336 |
+
static std::string format() {
|
| 1337 |
+
using namespace detail;
|
| 1338 |
+
static constexpr auto extents = const_name("(") + array_info<T>::extents + const_name(")");
|
| 1339 |
+
return extents.text + format_descriptor<remove_all_extents_t<T>>::format();
|
| 1340 |
+
}
|
| 1341 |
+
};
|
| 1342 |
+
|
| 1343 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 1344 |
+
template <typename T, int ExtraFlags>
|
| 1345 |
+
struct pyobject_caster<array_t<T, ExtraFlags>> {
|
| 1346 |
+
using type = array_t<T, ExtraFlags>;
|
| 1347 |
+
|
| 1348 |
+
bool load(handle src, bool convert) {
|
| 1349 |
+
if (!convert && !type::check_(src)) {
|
| 1350 |
+
return false;
|
| 1351 |
+
}
|
| 1352 |
+
value = type::ensure(src);
|
| 1353 |
+
return static_cast<bool>(value);
|
| 1354 |
+
}
|
| 1355 |
+
|
| 1356 |
+
static handle cast(const handle &src, return_value_policy /* policy */, handle /* parent */) {
|
| 1357 |
+
return src.inc_ref();
|
| 1358 |
+
}
|
| 1359 |
+
PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name);
|
| 1360 |
+
};
|
| 1361 |
+
|
| 1362 |
+
template <typename T>
|
| 1363 |
+
struct compare_buffer_info<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
|
| 1364 |
+
static bool compare(const buffer_info &b) {
|
| 1365 |
+
return npy_api::get().PyArray_EquivTypes_(dtype::of<T>().ptr(), dtype(b).ptr());
|
| 1366 |
+
}
|
| 1367 |
+
};
|
| 1368 |
+
|
| 1369 |
+
template <typename T, typename = void>
|
| 1370 |
+
struct npy_format_descriptor_name;
|
| 1371 |
+
|
| 1372 |
+
template <typename T>
|
| 1373 |
+
struct npy_format_descriptor_name<T, enable_if_t<std::is_integral<T>::value>> {
|
| 1374 |
+
static constexpr auto name = const_name<std::is_same<T, bool>::value>(
|
| 1375 |
+
const_name("bool"),
|
| 1376 |
+
const_name<std::is_signed<T>::value>("numpy.int", "numpy.uint")
|
| 1377 |
+
+ const_name<sizeof(T) * 8>());
|
| 1378 |
+
};
|
| 1379 |
+
|
| 1380 |
+
template <typename T>
|
| 1381 |
+
struct npy_format_descriptor_name<T, enable_if_t<std::is_floating_point<T>::value>> {
|
| 1382 |
+
static constexpr auto name = const_name < std::is_same<T, float>::value
|
| 1383 |
+
|| std::is_same<T, const float>::value
|
| 1384 |
+
|| std::is_same<T, double>::value
|
| 1385 |
+
|| std::is_same<T, const double>::value
|
| 1386 |
+
> (const_name("numpy.float") + const_name<sizeof(T) * 8>(),
|
| 1387 |
+
const_name("numpy.longdouble"));
|
| 1388 |
+
};
|
| 1389 |
+
|
| 1390 |
+
template <typename T>
|
| 1391 |
+
struct npy_format_descriptor_name<T, enable_if_t<is_complex<T>::value>> {
|
| 1392 |
+
static constexpr auto name = const_name < std::is_same<typename T::value_type, float>::value
|
| 1393 |
+
|| std::is_same<typename T::value_type, const float>::value
|
| 1394 |
+
|| std::is_same<typename T::value_type, double>::value
|
| 1395 |
+
|| std::is_same<typename T::value_type, const double>::value
|
| 1396 |
+
> (const_name("numpy.complex")
|
| 1397 |
+
+ const_name<sizeof(typename T::value_type) * 16>(),
|
| 1398 |
+
const_name("numpy.longcomplex"));
|
| 1399 |
+
};
|
| 1400 |
+
|
| 1401 |
+
template <typename T>
|
| 1402 |
+
struct npy_format_descriptor<
|
| 1403 |
+
T,
|
| 1404 |
+
enable_if_t<satisfies_any_of<T, std::is_arithmetic, is_complex>::value>>
|
| 1405 |
+
: npy_format_descriptor_name<T> {
|
| 1406 |
+
private:
|
| 1407 |
+
// NB: the order here must match the one in common.h
|
| 1408 |
+
constexpr static const int values[15] = {npy_api::NPY_BOOL_,
|
| 1409 |
+
npy_api::NPY_BYTE_,
|
| 1410 |
+
npy_api::NPY_UBYTE_,
|
| 1411 |
+
npy_api::NPY_INT16_,
|
| 1412 |
+
npy_api::NPY_UINT16_,
|
| 1413 |
+
npy_api::NPY_INT32_,
|
| 1414 |
+
npy_api::NPY_UINT32_,
|
| 1415 |
+
npy_api::NPY_INT64_,
|
| 1416 |
+
npy_api::NPY_UINT64_,
|
| 1417 |
+
npy_api::NPY_FLOAT_,
|
| 1418 |
+
npy_api::NPY_DOUBLE_,
|
| 1419 |
+
npy_api::NPY_LONGDOUBLE_,
|
| 1420 |
+
npy_api::NPY_CFLOAT_,
|
| 1421 |
+
npy_api::NPY_CDOUBLE_,
|
| 1422 |
+
npy_api::NPY_CLONGDOUBLE_};
|
| 1423 |
+
|
| 1424 |
+
public:
|
| 1425 |
+
static constexpr int value = values[detail::is_fmt_numeric<T>::index];
|
| 1426 |
+
|
| 1427 |
+
static pybind11::dtype dtype() { return pybind11::dtype(/*typenum*/ value); }
|
| 1428 |
+
};
|
| 1429 |
+
|
| 1430 |
+
template <typename T>
|
| 1431 |
+
struct npy_format_descriptor<T, enable_if_t<is_same_ignoring_cvref<T, PyObject *>::value>> {
|
| 1432 |
+
static constexpr auto name = const_name("object");
|
| 1433 |
+
|
| 1434 |
+
static constexpr int value = npy_api::NPY_OBJECT_;
|
| 1435 |
+
|
| 1436 |
+
static pybind11::dtype dtype() { return pybind11::dtype(/*typenum*/ value); }
|
| 1437 |
+
};
|
| 1438 |
+
|
| 1439 |
+
#define PYBIND11_DECL_CHAR_FMT \
|
| 1440 |
+
static constexpr auto name = const_name("S") + const_name<N>(); \
|
| 1441 |
+
static pybind11::dtype dtype() { \
|
| 1442 |
+
return pybind11::dtype(std::string("S") + std::to_string(N)); \
|
| 1443 |
+
}
|
| 1444 |
+
template <size_t N>
|
| 1445 |
+
struct npy_format_descriptor<char[N]> {
|
| 1446 |
+
PYBIND11_DECL_CHAR_FMT
|
| 1447 |
+
};
|
| 1448 |
+
template <size_t N>
|
| 1449 |
+
struct npy_format_descriptor<std::array<char, N>> {
|
| 1450 |
+
PYBIND11_DECL_CHAR_FMT
|
| 1451 |
+
};
|
| 1452 |
+
#undef PYBIND11_DECL_CHAR_FMT
|
| 1453 |
+
|
| 1454 |
+
template <typename T>
|
| 1455 |
+
struct npy_format_descriptor<T, enable_if_t<array_info<T>::is_array>> {
|
| 1456 |
+
private:
|
| 1457 |
+
using base_descr = npy_format_descriptor<typename array_info<T>::type>;
|
| 1458 |
+
|
| 1459 |
+
public:
|
| 1460 |
+
static_assert(!array_info<T>::is_empty, "Zero-sized arrays are not supported");
|
| 1461 |
+
|
| 1462 |
+
static constexpr auto name
|
| 1463 |
+
= const_name("(") + array_info<T>::extents + const_name(")") + base_descr::name;
|
| 1464 |
+
static pybind11::dtype dtype() {
|
| 1465 |
+
list shape;
|
| 1466 |
+
array_info<T>::append_extents(shape);
|
| 1467 |
+
return pybind11::dtype::from_args(
|
| 1468 |
+
pybind11::make_tuple(base_descr::dtype(), std::move(shape)));
|
| 1469 |
+
}
|
| 1470 |
+
};
|
| 1471 |
+
|
| 1472 |
+
template <typename T>
|
| 1473 |
+
struct npy_format_descriptor<T, enable_if_t<std::is_enum<T>::value>> {
|
| 1474 |
+
private:
|
| 1475 |
+
using base_descr = npy_format_descriptor<typename std::underlying_type<T>::type>;
|
| 1476 |
+
|
| 1477 |
+
public:
|
| 1478 |
+
static constexpr auto name = base_descr::name;
|
| 1479 |
+
static pybind11::dtype dtype() { return base_descr::dtype(); }
|
| 1480 |
+
};
|
| 1481 |
+
|
| 1482 |
+
struct field_descriptor {
|
| 1483 |
+
const char *name;
|
| 1484 |
+
ssize_t offset;
|
| 1485 |
+
ssize_t size;
|
| 1486 |
+
std::string format;
|
| 1487 |
+
dtype descr;
|
| 1488 |
+
};
|
| 1489 |
+
|
| 1490 |
+
PYBIND11_NOINLINE void register_structured_dtype(any_container<field_descriptor> fields,
|
| 1491 |
+
const std::type_info &tinfo,
|
| 1492 |
+
ssize_t itemsize,
|
| 1493 |
+
bool (*direct_converter)(PyObject *, void *&)) {
|
| 1494 |
+
|
| 1495 |
+
auto &numpy_internals = get_numpy_internals();
|
| 1496 |
+
if (numpy_internals.get_type_info(tinfo, false)) {
|
| 1497 |
+
pybind11_fail("NumPy: dtype is already registered");
|
| 1498 |
+
}
|
| 1499 |
+
|
| 1500 |
+
// Use ordered fields because order matters as of NumPy 1.14:
|
| 1501 |
+
// https://docs.scipy.org/doc/numpy/release.html#multiple-field-indexing-assignment-of-structured-arrays
|
| 1502 |
+
std::vector<field_descriptor> ordered_fields(std::move(fields));
|
| 1503 |
+
std::sort(
|
| 1504 |
+
ordered_fields.begin(),
|
| 1505 |
+
ordered_fields.end(),
|
| 1506 |
+
[](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });
|
| 1507 |
+
|
| 1508 |
+
list names, formats, offsets;
|
| 1509 |
+
for (auto &field : ordered_fields) {
|
| 1510 |
+
if (!field.descr) {
|
| 1511 |
+
pybind11_fail(std::string("NumPy: unsupported field dtype: `") + field.name + "` @ "
|
| 1512 |
+
+ tinfo.name());
|
| 1513 |
+
}
|
| 1514 |
+
names.append(pybind11::str(field.name));
|
| 1515 |
+
formats.append(field.descr);
|
| 1516 |
+
offsets.append(pybind11::int_(field.offset));
|
| 1517 |
+
}
|
| 1518 |
+
auto *dtype_ptr
|
| 1519 |
+
= pybind11::dtype(std::move(names), std::move(formats), std::move(offsets), itemsize)
|
| 1520 |
+
.release()
|
| 1521 |
+
.ptr();
|
| 1522 |
+
|
| 1523 |
+
// There is an existing bug in NumPy (as of v1.11): trailing bytes are
|
| 1524 |
+
// not encoded explicitly into the format string. This will supposedly
|
| 1525 |
+
// get fixed in v1.12; for further details, see these:
|
| 1526 |
+
// - https://github.com/numpy/numpy/issues/7797
|
| 1527 |
+
// - https://github.com/numpy/numpy/pull/7798
|
| 1528 |
+
// Because of this, we won't use numpy's logic to generate buffer format
|
| 1529 |
+
// strings and will just do it ourselves.
|
| 1530 |
+
ssize_t offset = 0;
|
| 1531 |
+
std::ostringstream oss;
|
| 1532 |
+
// mark the structure as unaligned with '^', because numpy and C++ don't
|
| 1533 |
+
// always agree about alignment (particularly for complex), and we're
|
| 1534 |
+
// explicitly listing all our padding. This depends on none of the fields
|
| 1535 |
+
// overriding the endianness. Putting the ^ in front of individual fields
|
| 1536 |
+
// isn't guaranteed to work due to https://github.com/numpy/numpy/issues/9049
|
| 1537 |
+
oss << "^T{";
|
| 1538 |
+
for (auto &field : ordered_fields) {
|
| 1539 |
+
if (field.offset > offset) {
|
| 1540 |
+
oss << (field.offset - offset) << 'x';
|
| 1541 |
+
}
|
| 1542 |
+
oss << field.format << ':' << field.name << ':';
|
| 1543 |
+
offset = field.offset + field.size;
|
| 1544 |
+
}
|
| 1545 |
+
if (itemsize > offset) {
|
| 1546 |
+
oss << (itemsize - offset) << 'x';
|
| 1547 |
+
}
|
| 1548 |
+
oss << '}';
|
| 1549 |
+
auto format_str = oss.str();
|
| 1550 |
+
|
| 1551 |
+
// Smoke test: verify that NumPy properly parses our buffer format string
|
| 1552 |
+
auto &api = npy_api::get();
|
| 1553 |
+
auto arr = array(buffer_info(nullptr, itemsize, format_str, 1));
|
| 1554 |
+
if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr())) {
|
| 1555 |
+
pybind11_fail("NumPy: invalid buffer descriptor!");
|
| 1556 |
+
}
|
| 1557 |
+
|
| 1558 |
+
auto tindex = std::type_index(tinfo);
|
| 1559 |
+
numpy_internals.registered_dtypes[tindex] = {dtype_ptr, std::move(format_str)};
|
| 1560 |
+
with_internals([tindex, &direct_converter](internals &internals) {
|
| 1561 |
+
internals.direct_conversions[tindex].push_back(direct_converter);
|
| 1562 |
+
});
|
| 1563 |
+
}
|
| 1564 |
+
|
| 1565 |
+
template <typename T, typename SFINAE>
|
| 1566 |
+
struct npy_format_descriptor {
|
| 1567 |
+
static_assert(is_pod_struct<T>::value,
|
| 1568 |
+
"Attempt to use a non-POD or unimplemented POD type as a numpy dtype");
|
| 1569 |
+
|
| 1570 |
+
static constexpr auto name = make_caster<T>::name;
|
| 1571 |
+
|
| 1572 |
+
static pybind11::dtype dtype() { return reinterpret_borrow<pybind11::dtype>(dtype_ptr()); }
|
| 1573 |
+
|
| 1574 |
+
static std::string format() {
|
| 1575 |
+
static auto format_str = get_numpy_internals().get_type_info<T>(true)->format_str;
|
| 1576 |
+
return format_str;
|
| 1577 |
+
}
|
| 1578 |
+
|
| 1579 |
+
static void register_dtype(any_container<field_descriptor> fields) {
|
| 1580 |
+
register_structured_dtype(std::move(fields),
|
| 1581 |
+
typeid(typename std::remove_cv<T>::type),
|
| 1582 |
+
sizeof(T),
|
| 1583 |
+
&direct_converter);
|
| 1584 |
+
}
|
| 1585 |
+
|
| 1586 |
+
private:
|
| 1587 |
+
static PyObject *dtype_ptr() {
|
| 1588 |
+
static PyObject *ptr = get_numpy_internals().get_type_info<T>(true)->dtype_ptr;
|
| 1589 |
+
return ptr;
|
| 1590 |
+
}
|
| 1591 |
+
|
| 1592 |
+
static bool direct_converter(PyObject *obj, void *&value) {
|
| 1593 |
+
auto &api = npy_api::get();
|
| 1594 |
+
if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_)) {
|
| 1595 |
+
return false;
|
| 1596 |
+
}
|
| 1597 |
+
if (auto descr = reinterpret_steal<object>(api.PyArray_DescrFromScalar_(obj))) {
|
| 1598 |
+
if (api.PyArray_EquivTypes_(dtype_ptr(), descr.ptr())) {
|
| 1599 |
+
value = ((PyVoidScalarObject_Proxy *) obj)->obval;
|
| 1600 |
+
return true;
|
| 1601 |
+
}
|
| 1602 |
+
}
|
| 1603 |
+
return false;
|
| 1604 |
+
}
|
| 1605 |
+
};
|
| 1606 |
+
|
| 1607 |
+
#ifdef __CLION_IDE__ // replace heavy macro with dummy code for the IDE (doesn't affect code)
|
| 1608 |
+
# define PYBIND11_NUMPY_DTYPE(Type, ...) ((void) 0)
|
| 1609 |
+
# define PYBIND11_NUMPY_DTYPE_EX(Type, ...) ((void) 0)
|
| 1610 |
+
#else
|
| 1611 |
+
|
| 1612 |
+
# define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name) \
|
| 1613 |
+
::pybind11::detail::field_descriptor { \
|
| 1614 |
+
Name, offsetof(T, Field), sizeof(decltype(std::declval<T>().Field)), \
|
| 1615 |
+
::pybind11::format_descriptor<decltype(std::declval<T>().Field)>::format(), \
|
| 1616 |
+
::pybind11::detail::npy_format_descriptor< \
|
| 1617 |
+
decltype(std::declval<T>().Field)>::dtype() \
|
| 1618 |
+
}
|
| 1619 |
+
|
| 1620 |
+
// Extract name, offset and format descriptor for a struct field
|
| 1621 |
+
# define PYBIND11_FIELD_DESCRIPTOR(T, Field) PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, #Field)
|
| 1622 |
+
|
| 1623 |
+
// The main idea of this macro is borrowed from https://github.com/swansontec/map-macro
|
| 1624 |
+
// (C) William Swanson, Paul Fultz
|
| 1625 |
+
# define PYBIND11_EVAL0(...) __VA_ARGS__
|
| 1626 |
+
# define PYBIND11_EVAL1(...) PYBIND11_EVAL0(PYBIND11_EVAL0(PYBIND11_EVAL0(__VA_ARGS__)))
|
| 1627 |
+
# define PYBIND11_EVAL2(...) PYBIND11_EVAL1(PYBIND11_EVAL1(PYBIND11_EVAL1(__VA_ARGS__)))
|
| 1628 |
+
# define PYBIND11_EVAL3(...) PYBIND11_EVAL2(PYBIND11_EVAL2(PYBIND11_EVAL2(__VA_ARGS__)))
|
| 1629 |
+
# define PYBIND11_EVAL4(...) PYBIND11_EVAL3(PYBIND11_EVAL3(PYBIND11_EVAL3(__VA_ARGS__)))
|
| 1630 |
+
# define PYBIND11_EVAL(...) PYBIND11_EVAL4(PYBIND11_EVAL4(PYBIND11_EVAL4(__VA_ARGS__)))
|
| 1631 |
+
# define PYBIND11_MAP_END(...)
|
| 1632 |
+
# define PYBIND11_MAP_OUT
|
| 1633 |
+
# define PYBIND11_MAP_COMMA ,
|
| 1634 |
+
# define PYBIND11_MAP_GET_END() 0, PYBIND11_MAP_END
|
| 1635 |
+
# define PYBIND11_MAP_NEXT0(test, next, ...) next PYBIND11_MAP_OUT
|
| 1636 |
+
# define PYBIND11_MAP_NEXT1(test, next) PYBIND11_MAP_NEXT0(test, next, 0)
|
| 1637 |
+
# define PYBIND11_MAP_NEXT(test, next) PYBIND11_MAP_NEXT1(PYBIND11_MAP_GET_END test, next)
|
| 1638 |
+
# if defined(_MSC_VER) \
|
| 1639 |
+
&& !defined(__clang__) // MSVC is not as eager to expand macros, hence this workaround
|
| 1640 |
+
# define PYBIND11_MAP_LIST_NEXT1(test, next) \
|
| 1641 |
+
PYBIND11_EVAL0(PYBIND11_MAP_NEXT0(test, PYBIND11_MAP_COMMA next, 0))
|
| 1642 |
+
# else
|
| 1643 |
+
# define PYBIND11_MAP_LIST_NEXT1(test, next) \
|
| 1644 |
+
PYBIND11_MAP_NEXT0(test, PYBIND11_MAP_COMMA next, 0)
|
| 1645 |
+
# endif
|
| 1646 |
+
# define PYBIND11_MAP_LIST_NEXT(test, next) \
|
| 1647 |
+
PYBIND11_MAP_LIST_NEXT1(PYBIND11_MAP_GET_END test, next)
|
| 1648 |
+
# define PYBIND11_MAP_LIST0(f, t, x, peek, ...) \
|
| 1649 |
+
f(t, x) PYBIND11_MAP_LIST_NEXT(peek, PYBIND11_MAP_LIST1)(f, t, peek, __VA_ARGS__)
|
| 1650 |
+
# define PYBIND11_MAP_LIST1(f, t, x, peek, ...) \
|
| 1651 |
+
f(t, x) PYBIND11_MAP_LIST_NEXT(peek, PYBIND11_MAP_LIST0)(f, t, peek, __VA_ARGS__)
|
| 1652 |
+
// PYBIND11_MAP_LIST(f, t, a1, a2, ...) expands to f(t, a1), f(t, a2), ...
|
| 1653 |
+
# define PYBIND11_MAP_LIST(f, t, ...) \
|
| 1654 |
+
PYBIND11_EVAL(PYBIND11_MAP_LIST1(f, t, __VA_ARGS__, (), 0))
|
| 1655 |
+
|
| 1656 |
+
# define PYBIND11_NUMPY_DTYPE(Type, ...) \
|
| 1657 |
+
::pybind11::detail::npy_format_descriptor<Type>::register_dtype( \
|
| 1658 |
+
::std::vector<::pybind11::detail::field_descriptor>{ \
|
| 1659 |
+
PYBIND11_MAP_LIST(PYBIND11_FIELD_DESCRIPTOR, Type, __VA_ARGS__)})
|
| 1660 |
+
|
| 1661 |
+
# if defined(_MSC_VER) && !defined(__clang__)
|
| 1662 |
+
# define PYBIND11_MAP2_LIST_NEXT1(test, next) \
|
| 1663 |
+
PYBIND11_EVAL0(PYBIND11_MAP_NEXT0(test, PYBIND11_MAP_COMMA next, 0))
|
| 1664 |
+
# else
|
| 1665 |
+
# define PYBIND11_MAP2_LIST_NEXT1(test, next) \
|
| 1666 |
+
PYBIND11_MAP_NEXT0(test, PYBIND11_MAP_COMMA next, 0)
|
| 1667 |
+
# endif
|
| 1668 |
+
# define PYBIND11_MAP2_LIST_NEXT(test, next) \
|
| 1669 |
+
PYBIND11_MAP2_LIST_NEXT1(PYBIND11_MAP_GET_END test, next)
|
| 1670 |
+
# define PYBIND11_MAP2_LIST0(f, t, x1, x2, peek, ...) \
|
| 1671 |
+
f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT(peek, PYBIND11_MAP2_LIST1)(f, t, peek, __VA_ARGS__)
|
| 1672 |
+
# define PYBIND11_MAP2_LIST1(f, t, x1, x2, peek, ...) \
|
| 1673 |
+
f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT(peek, PYBIND11_MAP2_LIST0)(f, t, peek, __VA_ARGS__)
|
| 1674 |
+
// PYBIND11_MAP2_LIST(f, t, a1, a2, ...) expands to f(t, a1, a2), f(t, a3, a4), ...
|
| 1675 |
+
# define PYBIND11_MAP2_LIST(f, t, ...) \
|
| 1676 |
+
PYBIND11_EVAL(PYBIND11_MAP2_LIST1(f, t, __VA_ARGS__, (), 0))
|
| 1677 |
+
|
| 1678 |
+
# define PYBIND11_NUMPY_DTYPE_EX(Type, ...) \
|
| 1679 |
+
::pybind11::detail::npy_format_descriptor<Type>::register_dtype( \
|
| 1680 |
+
::std::vector<::pybind11::detail::field_descriptor>{ \
|
| 1681 |
+
PYBIND11_MAP2_LIST(PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)})
|
| 1682 |
+
|
| 1683 |
+
#endif // __CLION_IDE__
|
| 1684 |
+
|
| 1685 |
+
class common_iterator {
|
| 1686 |
+
public:
|
| 1687 |
+
using container_type = std::vector<ssize_t>;
|
| 1688 |
+
using value_type = container_type::value_type;
|
| 1689 |
+
using size_type = container_type::size_type;
|
| 1690 |
+
|
| 1691 |
+
common_iterator() : m_strides() {}
|
| 1692 |
+
|
| 1693 |
+
common_iterator(void *ptr, const container_type &strides, const container_type &shape)
|
| 1694 |
+
: p_ptr(reinterpret_cast<char *>(ptr)), m_strides(strides.size()) {
|
| 1695 |
+
m_strides.back() = static_cast<value_type>(strides.back());
|
| 1696 |
+
for (size_type i = m_strides.size() - 1; i != 0; --i) {
|
| 1697 |
+
size_type j = i - 1;
|
| 1698 |
+
auto s = static_cast<value_type>(shape[i]);
|
| 1699 |
+
m_strides[j] = strides[j] + m_strides[i] - strides[i] * s;
|
| 1700 |
+
}
|
| 1701 |
+
}
|
| 1702 |
+
|
| 1703 |
+
void increment(size_type dim) { p_ptr += m_strides[dim]; }
|
| 1704 |
+
|
| 1705 |
+
void *data() const { return p_ptr; }
|
| 1706 |
+
|
| 1707 |
+
private:
|
| 1708 |
+
char *p_ptr{nullptr};
|
| 1709 |
+
container_type m_strides;
|
| 1710 |
+
};
|
| 1711 |
+
|
| 1712 |
+
template <size_t N>
|
| 1713 |
+
class multi_array_iterator {
|
| 1714 |
+
public:
|
| 1715 |
+
using container_type = std::vector<ssize_t>;
|
| 1716 |
+
|
| 1717 |
+
multi_array_iterator(const std::array<buffer_info, N> &buffers, const container_type &shape)
|
| 1718 |
+
: m_shape(shape.size()), m_index(shape.size(), 0), m_common_iterator() {
|
| 1719 |
+
|
| 1720 |
+
// Manual copy to avoid conversion warning if using std::copy
|
| 1721 |
+
for (size_t i = 0; i < shape.size(); ++i) {
|
| 1722 |
+
m_shape[i] = shape[i];
|
| 1723 |
+
}
|
| 1724 |
+
|
| 1725 |
+
container_type strides(shape.size());
|
| 1726 |
+
for (size_t i = 0; i < N; ++i) {
|
| 1727 |
+
init_common_iterator(buffers[i], shape, m_common_iterator[i], strides);
|
| 1728 |
+
}
|
| 1729 |
+
}
|
| 1730 |
+
|
| 1731 |
+
multi_array_iterator &operator++() {
|
| 1732 |
+
for (size_t j = m_index.size(); j != 0; --j) {
|
| 1733 |
+
size_t i = j - 1;
|
| 1734 |
+
if (++m_index[i] != m_shape[i]) {
|
| 1735 |
+
increment_common_iterator(i);
|
| 1736 |
+
break;
|
| 1737 |
+
}
|
| 1738 |
+
m_index[i] = 0;
|
| 1739 |
+
}
|
| 1740 |
+
return *this;
|
| 1741 |
+
}
|
| 1742 |
+
|
| 1743 |
+
template <size_t K, class T = void>
|
| 1744 |
+
T *data() const {
|
| 1745 |
+
return reinterpret_cast<T *>(m_common_iterator[K].data());
|
| 1746 |
+
}
|
| 1747 |
+
|
| 1748 |
+
private:
|
| 1749 |
+
using common_iter = common_iterator;
|
| 1750 |
+
|
| 1751 |
+
void init_common_iterator(const buffer_info &buffer,
|
| 1752 |
+
const container_type &shape,
|
| 1753 |
+
common_iter &iterator,
|
| 1754 |
+
container_type &strides) {
|
| 1755 |
+
auto buffer_shape_iter = buffer.shape.rbegin();
|
| 1756 |
+
auto buffer_strides_iter = buffer.strides.rbegin();
|
| 1757 |
+
auto shape_iter = shape.rbegin();
|
| 1758 |
+
auto strides_iter = strides.rbegin();
|
| 1759 |
+
|
| 1760 |
+
while (buffer_shape_iter != buffer.shape.rend()) {
|
| 1761 |
+
if (*shape_iter == *buffer_shape_iter) {
|
| 1762 |
+
*strides_iter = *buffer_strides_iter;
|
| 1763 |
+
} else {
|
| 1764 |
+
*strides_iter = 0;
|
| 1765 |
+
}
|
| 1766 |
+
|
| 1767 |
+
++buffer_shape_iter;
|
| 1768 |
+
++buffer_strides_iter;
|
| 1769 |
+
++shape_iter;
|
| 1770 |
+
++strides_iter;
|
| 1771 |
+
}
|
| 1772 |
+
|
| 1773 |
+
std::fill(strides_iter, strides.rend(), 0);
|
| 1774 |
+
iterator = common_iter(buffer.ptr, strides, shape);
|
| 1775 |
+
}
|
| 1776 |
+
|
| 1777 |
+
void increment_common_iterator(size_t dim) {
|
| 1778 |
+
for (auto &iter : m_common_iterator) {
|
| 1779 |
+
iter.increment(dim);
|
| 1780 |
+
}
|
| 1781 |
+
}
|
| 1782 |
+
|
| 1783 |
+
container_type m_shape;
|
| 1784 |
+
container_type m_index;
|
| 1785 |
+
std::array<common_iter, N> m_common_iterator;
|
| 1786 |
+
};
|
| 1787 |
+
|
| 1788 |
+
enum class broadcast_trivial { non_trivial, c_trivial, f_trivial };
|
| 1789 |
+
|
| 1790 |
+
// Populates the shape and number of dimensions for the set of buffers. Returns a
|
| 1791 |
+
// broadcast_trivial enum value indicating whether the broadcast is "trivial"--that is, has each
|
| 1792 |
+
// buffer being either a singleton or a full-size, C-contiguous (`c_trivial`) or Fortran-contiguous
|
| 1793 |
+
// (`f_trivial`) storage buffer; returns `non_trivial` otherwise.
|
| 1794 |
+
template <size_t N>
|
| 1795 |
+
broadcast_trivial
|
| 1796 |
+
broadcast(const std::array<buffer_info, N> &buffers, ssize_t &ndim, std::vector<ssize_t> &shape) {
|
| 1797 |
+
ndim = std::accumulate(
|
| 1798 |
+
buffers.begin(), buffers.end(), ssize_t(0), [](ssize_t res, const buffer_info &buf) {
|
| 1799 |
+
return std::max(res, buf.ndim);
|
| 1800 |
+
});
|
| 1801 |
+
|
| 1802 |
+
shape.clear();
|
| 1803 |
+
shape.resize((size_t) ndim, 1);
|
| 1804 |
+
|
| 1805 |
+
// Figure out the output size, and make sure all input arrays conform (i.e. are either size 1
|
| 1806 |
+
// or the full size).
|
| 1807 |
+
for (size_t i = 0; i < N; ++i) {
|
| 1808 |
+
auto res_iter = shape.rbegin();
|
| 1809 |
+
auto end = buffers[i].shape.rend();
|
| 1810 |
+
for (auto shape_iter = buffers[i].shape.rbegin(); shape_iter != end;
|
| 1811 |
+
++shape_iter, ++res_iter) {
|
| 1812 |
+
const auto &dim_size_in = *shape_iter;
|
| 1813 |
+
auto &dim_size_out = *res_iter;
|
| 1814 |
+
|
| 1815 |
+
// Each input dimension can either be 1 or `n`, but `n` values must match across
|
| 1816 |
+
// buffers
|
| 1817 |
+
if (dim_size_out == 1) {
|
| 1818 |
+
dim_size_out = dim_size_in;
|
| 1819 |
+
} else if (dim_size_in != 1 && dim_size_in != dim_size_out) {
|
| 1820 |
+
pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!");
|
| 1821 |
+
}
|
| 1822 |
+
}
|
| 1823 |
+
}
|
| 1824 |
+
|
| 1825 |
+
bool trivial_broadcast_c = true;
|
| 1826 |
+
bool trivial_broadcast_f = true;
|
| 1827 |
+
for (size_t i = 0; i < N && (trivial_broadcast_c || trivial_broadcast_f); ++i) {
|
| 1828 |
+
if (buffers[i].size == 1) {
|
| 1829 |
+
continue;
|
| 1830 |
+
}
|
| 1831 |
+
|
| 1832 |
+
// Require the same number of dimensions:
|
| 1833 |
+
if (buffers[i].ndim != ndim) {
|
| 1834 |
+
return broadcast_trivial::non_trivial;
|
| 1835 |
+
}
|
| 1836 |
+
|
| 1837 |
+
// Require all dimensions be full-size:
|
| 1838 |
+
if (!std::equal(buffers[i].shape.cbegin(), buffers[i].shape.cend(), shape.cbegin())) {
|
| 1839 |
+
return broadcast_trivial::non_trivial;
|
| 1840 |
+
}
|
| 1841 |
+
|
| 1842 |
+
// Check for C contiguity (but only if previous inputs were also C contiguous)
|
| 1843 |
+
if (trivial_broadcast_c) {
|
| 1844 |
+
ssize_t expect_stride = buffers[i].itemsize;
|
| 1845 |
+
auto end = buffers[i].shape.crend();
|
| 1846 |
+
for (auto shape_iter = buffers[i].shape.crbegin(),
|
| 1847 |
+
stride_iter = buffers[i].strides.crbegin();
|
| 1848 |
+
trivial_broadcast_c && shape_iter != end;
|
| 1849 |
+
++shape_iter, ++stride_iter) {
|
| 1850 |
+
if (expect_stride == *stride_iter) {
|
| 1851 |
+
expect_stride *= *shape_iter;
|
| 1852 |
+
} else {
|
| 1853 |
+
trivial_broadcast_c = false;
|
| 1854 |
+
}
|
| 1855 |
+
}
|
| 1856 |
+
}
|
| 1857 |
+
|
| 1858 |
+
// Check for Fortran contiguity (if previous inputs were also F contiguous)
|
| 1859 |
+
if (trivial_broadcast_f) {
|
| 1860 |
+
ssize_t expect_stride = buffers[i].itemsize;
|
| 1861 |
+
auto end = buffers[i].shape.cend();
|
| 1862 |
+
for (auto shape_iter = buffers[i].shape.cbegin(),
|
| 1863 |
+
stride_iter = buffers[i].strides.cbegin();
|
| 1864 |
+
trivial_broadcast_f && shape_iter != end;
|
| 1865 |
+
++shape_iter, ++stride_iter) {
|
| 1866 |
+
if (expect_stride == *stride_iter) {
|
| 1867 |
+
expect_stride *= *shape_iter;
|
| 1868 |
+
} else {
|
| 1869 |
+
trivial_broadcast_f = false;
|
| 1870 |
+
}
|
| 1871 |
+
}
|
| 1872 |
+
}
|
| 1873 |
+
}
|
| 1874 |
+
|
| 1875 |
+
return trivial_broadcast_c ? broadcast_trivial::c_trivial
|
| 1876 |
+
: trivial_broadcast_f ? broadcast_trivial::f_trivial
|
| 1877 |
+
: broadcast_trivial::non_trivial;
|
| 1878 |
+
}
|
| 1879 |
+
|
| 1880 |
+
template <typename T>
|
| 1881 |
+
struct vectorize_arg {
|
| 1882 |
+
static_assert(!std::is_rvalue_reference<T>::value,
|
| 1883 |
+
"Functions with rvalue reference arguments cannot be vectorized");
|
| 1884 |
+
// The wrapped function gets called with this type:
|
| 1885 |
+
using call_type = remove_reference_t<T>;
|
| 1886 |
+
// Is this a vectorized argument?
|
| 1887 |
+
static constexpr bool vectorize
|
| 1888 |
+
= satisfies_any_of<call_type, std::is_arithmetic, is_complex, is_pod>::value
|
| 1889 |
+
&& satisfies_none_of<call_type,
|
| 1890 |
+
std::is_pointer,
|
| 1891 |
+
std::is_array,
|
| 1892 |
+
is_std_array,
|
| 1893 |
+
std::is_enum>::value
|
| 1894 |
+
&& (!std::is_reference<T>::value
|
| 1895 |
+
|| (std::is_lvalue_reference<T>::value && std::is_const<call_type>::value));
|
| 1896 |
+
// Accept this type: an array for vectorized types, otherwise the type as-is:
|
| 1897 |
+
using type = conditional_t<vectorize, array_t<remove_cv_t<call_type>, array::forcecast>, T>;
|
| 1898 |
+
};
|
| 1899 |
+
|
| 1900 |
+
// py::vectorize when a return type is present
|
| 1901 |
+
template <typename Func, typename Return, typename... Args>
|
| 1902 |
+
struct vectorize_returned_array {
|
| 1903 |
+
using Type = array_t<Return>;
|
| 1904 |
+
|
| 1905 |
+
static Type create(broadcast_trivial trivial, const std::vector<ssize_t> &shape) {
|
| 1906 |
+
if (trivial == broadcast_trivial::f_trivial) {
|
| 1907 |
+
return array_t<Return, array::f_style>(shape);
|
| 1908 |
+
}
|
| 1909 |
+
return array_t<Return>(shape);
|
| 1910 |
+
}
|
| 1911 |
+
|
| 1912 |
+
static Return *mutable_data(Type &array) { return array.mutable_data(); }
|
| 1913 |
+
|
| 1914 |
+
static Return call(Func &f, Args &...args) { return f(args...); }
|
| 1915 |
+
|
| 1916 |
+
static void call(Return *out, size_t i, Func &f, Args &...args) { out[i] = f(args...); }
|
| 1917 |
+
};
|
| 1918 |
+
|
| 1919 |
+
// py::vectorize when a return type is not present
|
| 1920 |
+
template <typename Func, typename... Args>
|
| 1921 |
+
struct vectorize_returned_array<Func, void, Args...> {
|
| 1922 |
+
using Type = none;
|
| 1923 |
+
|
| 1924 |
+
static Type create(broadcast_trivial, const std::vector<ssize_t> &) { return none(); }
|
| 1925 |
+
|
| 1926 |
+
static void *mutable_data(Type &) { return nullptr; }
|
| 1927 |
+
|
| 1928 |
+
static detail::void_type call(Func &f, Args &...args) {
|
| 1929 |
+
f(args...);
|
| 1930 |
+
return {};
|
| 1931 |
+
}
|
| 1932 |
+
|
| 1933 |
+
static void call(void *, size_t, Func &f, Args &...args) { f(args...); }
|
| 1934 |
+
};
|
| 1935 |
+
|
| 1936 |
+
template <typename Func, typename Return, typename... Args>
|
| 1937 |
+
struct vectorize_helper {
|
| 1938 |
+
|
| 1939 |
+
// NVCC for some reason breaks if NVectorized is private
|
| 1940 |
+
#ifdef __CUDACC__
|
| 1941 |
+
public:
|
| 1942 |
+
#else
|
| 1943 |
+
private:
|
| 1944 |
+
#endif
|
| 1945 |
+
|
| 1946 |
+
static constexpr size_t N = sizeof...(Args);
|
| 1947 |
+
static constexpr size_t NVectorized = constexpr_sum(vectorize_arg<Args>::vectorize...);
|
| 1948 |
+
static_assert(
|
| 1949 |
+
NVectorized >= 1,
|
| 1950 |
+
"pybind11::vectorize(...) requires a function with at least one vectorizable argument");
|
| 1951 |
+
|
| 1952 |
+
public:
|
| 1953 |
+
template <typename T,
|
| 1954 |
+
// SFINAE to prevent shadowing the copy constructor.
|
| 1955 |
+
typename = detail::enable_if_t<
|
| 1956 |
+
!std::is_same<vectorize_helper, typename std::decay<T>::type>::value>>
|
| 1957 |
+
explicit vectorize_helper(T &&f) : f(std::forward<T>(f)) {}
|
| 1958 |
+
|
| 1959 |
+
object operator()(typename vectorize_arg<Args>::type... args) {
|
| 1960 |
+
return run(args...,
|
| 1961 |
+
make_index_sequence<N>(),
|
| 1962 |
+
select_indices<vectorize_arg<Args>::vectorize...>(),
|
| 1963 |
+
make_index_sequence<NVectorized>());
|
| 1964 |
+
}
|
| 1965 |
+
|
| 1966 |
+
private:
|
| 1967 |
+
remove_reference_t<Func> f;
|
| 1968 |
+
|
| 1969 |
+
// Internal compiler error in MSVC 19.16.27025.1 (Visual Studio 2017 15.9.4), when compiling
|
| 1970 |
+
// with "/permissive-" flag when arg_call_types is manually inlined.
|
| 1971 |
+
using arg_call_types = std::tuple<typename vectorize_arg<Args>::call_type...>;
|
| 1972 |
+
template <size_t Index>
|
| 1973 |
+
using param_n_t = typename std::tuple_element<Index, arg_call_types>::type;
|
| 1974 |
+
|
| 1975 |
+
using returned_array = vectorize_returned_array<Func, Return, Args...>;
|
| 1976 |
+
|
| 1977 |
+
// Runs a vectorized function given arguments tuple and three index sequences:
|
| 1978 |
+
// - Index is the full set of 0 ... (N-1) argument indices;
|
| 1979 |
+
// - VIndex is the subset of argument indices with vectorized parameters, letting us access
|
| 1980 |
+
// vectorized arguments (anything not in this sequence is passed through)
|
| 1981 |
+
// - BIndex is a incremental sequence (beginning at 0) of the same size as VIndex, so that
|
| 1982 |
+
// we can store vectorized buffer_infos in an array (argument VIndex has its buffer at
|
| 1983 |
+
// index BIndex in the array).
|
| 1984 |
+
template <size_t... Index, size_t... VIndex, size_t... BIndex>
|
| 1985 |
+
object run(typename vectorize_arg<Args>::type &...args,
|
| 1986 |
+
index_sequence<Index...> i_seq,
|
| 1987 |
+
index_sequence<VIndex...> vi_seq,
|
| 1988 |
+
index_sequence<BIndex...> bi_seq) {
|
| 1989 |
+
|
| 1990 |
+
// Pointers to values the function was called with; the vectorized ones set here will start
|
| 1991 |
+
// out as array_t<T> pointers, but they will be changed them to T pointers before we make
|
| 1992 |
+
// call the wrapped function. Non-vectorized pointers are left as-is.
|
| 1993 |
+
std::array<void *, N> params{{reinterpret_cast<void *>(&args)...}};
|
| 1994 |
+
|
| 1995 |
+
// The array of `buffer_info`s of vectorized arguments:
|
| 1996 |
+
std::array<buffer_info, NVectorized> buffers{
|
| 1997 |
+
{reinterpret_cast<array *>(params[VIndex])->request()...}};
|
| 1998 |
+
|
| 1999 |
+
/* Determine dimensions parameters of output array */
|
| 2000 |
+
ssize_t nd = 0;
|
| 2001 |
+
std::vector<ssize_t> shape(0);
|
| 2002 |
+
auto trivial = broadcast(buffers, nd, shape);
|
| 2003 |
+
auto ndim = (size_t) nd;
|
| 2004 |
+
|
| 2005 |
+
size_t size
|
| 2006 |
+
= std::accumulate(shape.begin(), shape.end(), (size_t) 1, std::multiplies<size_t>());
|
| 2007 |
+
|
| 2008 |
+
// If all arguments are 0-dimension arrays (i.e. single values) return a plain value (i.e.
|
| 2009 |
+
// not wrapped in an array).
|
| 2010 |
+
if (size == 1 && ndim == 0) {
|
| 2011 |
+
PYBIND11_EXPAND_SIDE_EFFECTS(params[VIndex] = buffers[BIndex].ptr);
|
| 2012 |
+
return cast(
|
| 2013 |
+
returned_array::call(f, *reinterpret_cast<param_n_t<Index> *>(params[Index])...));
|
| 2014 |
+
}
|
| 2015 |
+
|
| 2016 |
+
auto result = returned_array::create(trivial, shape);
|
| 2017 |
+
|
| 2018 |
+
PYBIND11_WARNING_PUSH
|
| 2019 |
+
#ifdef PYBIND11_DETECTED_CLANG_WITH_MISLEADING_CALL_STD_MOVE_EXPLICITLY_WARNING
|
| 2020 |
+
PYBIND11_WARNING_DISABLE_CLANG("-Wreturn-std-move")
|
| 2021 |
+
#endif
|
| 2022 |
+
|
| 2023 |
+
if (size == 0) {
|
| 2024 |
+
return result;
|
| 2025 |
+
}
|
| 2026 |
+
|
| 2027 |
+
/* Call the function */
|
| 2028 |
+
auto *mutable_data = returned_array::mutable_data(result);
|
| 2029 |
+
if (trivial == broadcast_trivial::non_trivial) {
|
| 2030 |
+
apply_broadcast(buffers, params, mutable_data, size, shape, i_seq, vi_seq, bi_seq);
|
| 2031 |
+
} else {
|
| 2032 |
+
apply_trivial(buffers, params, mutable_data, size, i_seq, vi_seq, bi_seq);
|
| 2033 |
+
}
|
| 2034 |
+
|
| 2035 |
+
return result;
|
| 2036 |
+
PYBIND11_WARNING_POP
|
| 2037 |
+
}
|
| 2038 |
+
|
| 2039 |
+
template <size_t... Index, size_t... VIndex, size_t... BIndex>
|
| 2040 |
+
void apply_trivial(std::array<buffer_info, NVectorized> &buffers,
|
| 2041 |
+
std::array<void *, N> ¶ms,
|
| 2042 |
+
Return *out,
|
| 2043 |
+
size_t size,
|
| 2044 |
+
index_sequence<Index...>,
|
| 2045 |
+
index_sequence<VIndex...>,
|
| 2046 |
+
index_sequence<BIndex...>) {
|
| 2047 |
+
|
| 2048 |
+
// Initialize an array of mutable byte references and sizes with references set to the
|
| 2049 |
+
// appropriate pointer in `params`; as we iterate, we'll increment each pointer by its size
|
| 2050 |
+
// (except for singletons, which get an increment of 0).
|
| 2051 |
+
std::array<std::pair<unsigned char *&, const size_t>, NVectorized> vecparams{
|
| 2052 |
+
{std::pair<unsigned char *&, const size_t>(
|
| 2053 |
+
reinterpret_cast<unsigned char *&>(params[VIndex] = buffers[BIndex].ptr),
|
| 2054 |
+
buffers[BIndex].size == 1 ? 0 : sizeof(param_n_t<VIndex>))...}};
|
| 2055 |
+
|
| 2056 |
+
for (size_t i = 0; i < size; ++i) {
|
| 2057 |
+
returned_array::call(
|
| 2058 |
+
out, i, f, *reinterpret_cast<param_n_t<Index> *>(params[Index])...);
|
| 2059 |
+
for (auto &x : vecparams) {
|
| 2060 |
+
x.first += x.second;
|
| 2061 |
+
}
|
| 2062 |
+
}
|
| 2063 |
+
}
|
| 2064 |
+
|
| 2065 |
+
template <size_t... Index, size_t... VIndex, size_t... BIndex>
|
| 2066 |
+
void apply_broadcast(std::array<buffer_info, NVectorized> &buffers,
|
| 2067 |
+
std::array<void *, N> ¶ms,
|
| 2068 |
+
Return *out,
|
| 2069 |
+
size_t size,
|
| 2070 |
+
const std::vector<ssize_t> &output_shape,
|
| 2071 |
+
index_sequence<Index...>,
|
| 2072 |
+
index_sequence<VIndex...>,
|
| 2073 |
+
index_sequence<BIndex...>) {
|
| 2074 |
+
|
| 2075 |
+
multi_array_iterator<NVectorized> input_iter(buffers, output_shape);
|
| 2076 |
+
|
| 2077 |
+
for (size_t i = 0; i < size; ++i, ++input_iter) {
|
| 2078 |
+
PYBIND11_EXPAND_SIDE_EFFECTS((params[VIndex] = input_iter.template data<BIndex>()));
|
| 2079 |
+
returned_array::call(
|
| 2080 |
+
out, i, f, *reinterpret_cast<param_n_t<Index> *>(std::get<Index>(params))...);
|
| 2081 |
+
}
|
| 2082 |
+
}
|
| 2083 |
+
};
|
| 2084 |
+
|
| 2085 |
+
template <typename Func, typename Return, typename... Args>
|
| 2086 |
+
vectorize_helper<Func, Return, Args...> vectorize_extractor(const Func &f, Return (*)(Args...)) {
|
| 2087 |
+
return detail::vectorize_helper<Func, Return, Args...>(f);
|
| 2088 |
+
}
|
| 2089 |
+
|
| 2090 |
+
template <typename T, int Flags>
|
| 2091 |
+
struct handle_type_name<array_t<T, Flags>> {
|
| 2092 |
+
static constexpr auto name
|
| 2093 |
+
= const_name("numpy.ndarray[") + npy_format_descriptor<T>::name + const_name("]");
|
| 2094 |
+
};
|
| 2095 |
+
|
| 2096 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 2097 |
+
|
| 2098 |
+
// Vanilla pointer vectorizer:
|
| 2099 |
+
template <typename Return, typename... Args>
|
| 2100 |
+
detail::vectorize_helper<Return (*)(Args...), Return, Args...> vectorize(Return (*f)(Args...)) {
|
| 2101 |
+
return detail::vectorize_helper<Return (*)(Args...), Return, Args...>(f);
|
| 2102 |
+
}
|
| 2103 |
+
|
| 2104 |
+
// lambda vectorizer:
|
| 2105 |
+
template <typename Func, detail::enable_if_t<detail::is_lambda<Func>::value, int> = 0>
|
| 2106 |
+
auto vectorize(Func &&f)
|
| 2107 |
+
-> decltype(detail::vectorize_extractor(std::forward<Func>(f),
|
| 2108 |
+
(detail::function_signature_t<Func> *) nullptr)) {
|
| 2109 |
+
return detail::vectorize_extractor(std::forward<Func>(f),
|
| 2110 |
+
(detail::function_signature_t<Func> *) nullptr);
|
| 2111 |
+
}
|
| 2112 |
+
|
| 2113 |
+
// Vectorize a class method (non-const):
|
| 2114 |
+
template <typename Return,
|
| 2115 |
+
typename Class,
|
| 2116 |
+
typename... Args,
|
| 2117 |
+
typename Helper = detail::vectorize_helper<
|
| 2118 |
+
decltype(std::mem_fn(std::declval<Return (Class::*)(Args...)>())),
|
| 2119 |
+
Return,
|
| 2120 |
+
Class *,
|
| 2121 |
+
Args...>>
|
| 2122 |
+
Helper vectorize(Return (Class::*f)(Args...)) {
|
| 2123 |
+
return Helper(std::mem_fn(f));
|
| 2124 |
+
}
|
| 2125 |
+
|
| 2126 |
+
// Vectorize a class method (const):
|
| 2127 |
+
template <typename Return,
|
| 2128 |
+
typename Class,
|
| 2129 |
+
typename... Args,
|
| 2130 |
+
typename Helper = detail::vectorize_helper<
|
| 2131 |
+
decltype(std::mem_fn(std::declval<Return (Class::*)(Args...) const>())),
|
| 2132 |
+
Return,
|
| 2133 |
+
const Class *,
|
| 2134 |
+
Args...>>
|
| 2135 |
+
Helper vectorize(Return (Class::*f)(Args...) const) {
|
| 2136 |
+
return Helper(std::mem_fn(f));
|
| 2137 |
+
}
|
| 2138 |
+
|
| 2139 |
+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
|
phivenv/Lib/site-packages/torch/include/pybind11/operators.h
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
pybind11/operator.h: Metatemplates for operator overloading
|
| 3 |
+
|
| 4 |
+
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
| 5 |
+
|
| 6 |
+
All rights reserved. Use of this source code is governed by a
|
| 7 |
+
BSD-style license that can be found in the LICENSE file.
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include "pybind11.h"
|
| 13 |
+
|
| 14 |
+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
| 15 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 16 |
+
|
| 17 |
+
/// Enumeration with all supported operator types
|
| 18 |
+
enum op_id : int {
|
| 19 |
+
op_add,
|
| 20 |
+
op_sub,
|
| 21 |
+
op_mul,
|
| 22 |
+
op_div,
|
| 23 |
+
op_mod,
|
| 24 |
+
op_divmod,
|
| 25 |
+
op_pow,
|
| 26 |
+
op_lshift,
|
| 27 |
+
op_rshift,
|
| 28 |
+
op_and,
|
| 29 |
+
op_xor,
|
| 30 |
+
op_or,
|
| 31 |
+
op_neg,
|
| 32 |
+
op_pos,
|
| 33 |
+
op_abs,
|
| 34 |
+
op_invert,
|
| 35 |
+
op_int,
|
| 36 |
+
op_long,
|
| 37 |
+
op_float,
|
| 38 |
+
op_str,
|
| 39 |
+
op_cmp,
|
| 40 |
+
op_gt,
|
| 41 |
+
op_ge,
|
| 42 |
+
op_lt,
|
| 43 |
+
op_le,
|
| 44 |
+
op_eq,
|
| 45 |
+
op_ne,
|
| 46 |
+
op_iadd,
|
| 47 |
+
op_isub,
|
| 48 |
+
op_imul,
|
| 49 |
+
op_idiv,
|
| 50 |
+
op_imod,
|
| 51 |
+
op_ilshift,
|
| 52 |
+
op_irshift,
|
| 53 |
+
op_iand,
|
| 54 |
+
op_ixor,
|
| 55 |
+
op_ior,
|
| 56 |
+
op_complex,
|
| 57 |
+
op_bool,
|
| 58 |
+
op_nonzero,
|
| 59 |
+
op_repr,
|
| 60 |
+
op_truediv,
|
| 61 |
+
op_itruediv,
|
| 62 |
+
op_hash
|
| 63 |
+
};
|
| 64 |
+
|
| 65 |
+
enum op_type : int {
|
| 66 |
+
op_l, /* base type on left */
|
| 67 |
+
op_r, /* base type on right */
|
| 68 |
+
op_u /* unary operator */
|
| 69 |
+
};
|
| 70 |
+
|
| 71 |
+
struct self_t {};
|
| 72 |
+
static const self_t self = self_t();
|
| 73 |
+
|
| 74 |
+
/// Type for an unused type slot
|
| 75 |
+
struct undefined_t {};
|
| 76 |
+
|
| 77 |
+
/// Don't warn about an unused variable
|
| 78 |
+
inline self_t __self() { return self; }
|
| 79 |
+
|
| 80 |
+
/// base template of operator implementations
|
| 81 |
+
template <op_id, op_type, typename B, typename L, typename R>
|
| 82 |
+
struct op_impl {};
|
| 83 |
+
|
| 84 |
+
/// Operator implementation generator
|
| 85 |
+
template <op_id id, op_type ot, typename L, typename R>
|
| 86 |
+
struct op_ {
|
| 87 |
+
static constexpr bool op_enable_if_hook = true;
|
| 88 |
+
template <typename Class, typename... Extra>
|
| 89 |
+
void execute(Class &cl, const Extra &...extra) const {
|
| 90 |
+
using Base = typename Class::type;
|
| 91 |
+
using L_type = conditional_t<std::is_same<L, self_t>::value, Base, L>;
|
| 92 |
+
using R_type = conditional_t<std::is_same<R, self_t>::value, Base, R>;
|
| 93 |
+
using op = op_impl<id, ot, Base, L_type, R_type>;
|
| 94 |
+
cl.def(op::name(), &op::execute, is_operator(), extra...);
|
| 95 |
+
}
|
| 96 |
+
template <typename Class, typename... Extra>
|
| 97 |
+
void execute_cast(Class &cl, const Extra &...extra) const {
|
| 98 |
+
using Base = typename Class::type;
|
| 99 |
+
using L_type = conditional_t<std::is_same<L, self_t>::value, Base, L>;
|
| 100 |
+
using R_type = conditional_t<std::is_same<R, self_t>::value, Base, R>;
|
| 101 |
+
using op = op_impl<id, ot, Base, L_type, R_type>;
|
| 102 |
+
cl.def(op::name(), &op::execute_cast, is_operator(), extra...);
|
| 103 |
+
}
|
| 104 |
+
};
|
| 105 |
+
|
| 106 |
+
#define PYBIND11_BINARY_OPERATOR(id, rid, op, expr) \
|
| 107 |
+
template <typename B, typename L, typename R> \
|
| 108 |
+
struct op_impl<op_##id, op_l, B, L, R> { \
|
| 109 |
+
static char const *name() { return "__" #id "__"; } \
|
| 110 |
+
static auto execute(const L &l, const R &r) -> decltype(expr) { return (expr); } \
|
| 111 |
+
static B execute_cast(const L &l, const R &r) { return B(expr); } \
|
| 112 |
+
}; \
|
| 113 |
+
template <typename B, typename L, typename R> \
|
| 114 |
+
struct op_impl<op_##id, op_r, B, L, R> { \
|
| 115 |
+
static char const *name() { return "__" #rid "__"; } \
|
| 116 |
+
static auto execute(const R &r, const L &l) -> decltype(expr) { return (expr); } \
|
| 117 |
+
static B execute_cast(const R &r, const L &l) { return B(expr); } \
|
| 118 |
+
}; \
|
| 119 |
+
inline op_<op_##id, op_l, self_t, self_t> op(const self_t &, const self_t &) { \
|
| 120 |
+
return op_<op_##id, op_l, self_t, self_t>(); \
|
| 121 |
+
} \
|
| 122 |
+
template <typename T> \
|
| 123 |
+
op_<op_##id, op_l, self_t, T> op(const self_t &, const T &) { \
|
| 124 |
+
return op_<op_##id, op_l, self_t, T>(); \
|
| 125 |
+
} \
|
| 126 |
+
template <typename T> \
|
| 127 |
+
op_<op_##id, op_r, T, self_t> op(const T &, const self_t &) { \
|
| 128 |
+
return op_<op_##id, op_r, T, self_t>(); \
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
#define PYBIND11_INPLACE_OPERATOR(id, op, expr) \
|
| 132 |
+
template <typename B, typename L, typename R> \
|
| 133 |
+
struct op_impl<op_##id, op_l, B, L, R> { \
|
| 134 |
+
static char const *name() { return "__" #id "__"; } \
|
| 135 |
+
static auto execute(L &l, const R &r) -> decltype(expr) { return expr; } \
|
| 136 |
+
static B execute_cast(L &l, const R &r) { return B(expr); } \
|
| 137 |
+
}; \
|
| 138 |
+
template <typename T> \
|
| 139 |
+
op_<op_##id, op_l, self_t, T> op(const self_t &, const T &) { \
|
| 140 |
+
return op_<op_##id, op_l, self_t, T>(); \
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
#define PYBIND11_UNARY_OPERATOR(id, op, expr) \
|
| 144 |
+
template <typename B, typename L> \
|
| 145 |
+
struct op_impl<op_##id, op_u, B, L, undefined_t> { \
|
| 146 |
+
static char const *name() { return "__" #id "__"; } \
|
| 147 |
+
static auto execute(const L &l) -> decltype(expr) { return expr; } \
|
| 148 |
+
static B execute_cast(const L &l) { return B(expr); } \
|
| 149 |
+
}; \
|
| 150 |
+
inline op_<op_##id, op_u, self_t, undefined_t> op(const self_t &) { \
|
| 151 |
+
return op_<op_##id, op_u, self_t, undefined_t>(); \
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
PYBIND11_BINARY_OPERATOR(sub, rsub, operator-, l - r)
|
| 155 |
+
PYBIND11_BINARY_OPERATOR(add, radd, operator+, l + r)
|
| 156 |
+
PYBIND11_BINARY_OPERATOR(mul, rmul, operator*, l *r)
|
| 157 |
+
PYBIND11_BINARY_OPERATOR(truediv, rtruediv, operator/, l / r)
|
| 158 |
+
PYBIND11_BINARY_OPERATOR(mod, rmod, operator%, l % r)
|
| 159 |
+
PYBIND11_BINARY_OPERATOR(lshift, rlshift, operator<<, l << r)
|
| 160 |
+
PYBIND11_BINARY_OPERATOR(rshift, rrshift, operator>>, l >> r)
|
| 161 |
+
PYBIND11_BINARY_OPERATOR(and, rand, operator&, l &r)
|
| 162 |
+
PYBIND11_BINARY_OPERATOR(xor, rxor, operator^, l ^ r)
|
| 163 |
+
PYBIND11_BINARY_OPERATOR(eq, eq, operator==, l == r)
|
| 164 |
+
PYBIND11_BINARY_OPERATOR(ne, ne, operator!=, l != r)
|
| 165 |
+
PYBIND11_BINARY_OPERATOR(or, ror, operator|, l | r)
|
| 166 |
+
PYBIND11_BINARY_OPERATOR(gt, lt, operator>, l > r)
|
| 167 |
+
PYBIND11_BINARY_OPERATOR(ge, le, operator>=, l >= r)
|
| 168 |
+
PYBIND11_BINARY_OPERATOR(lt, gt, operator<, l < r)
|
| 169 |
+
PYBIND11_BINARY_OPERATOR(le, ge, operator<=, l <= r)
|
| 170 |
+
// PYBIND11_BINARY_OPERATOR(pow, rpow, pow, std::pow(l, r))
|
| 171 |
+
PYBIND11_INPLACE_OPERATOR(iadd, operator+=, l += r)
|
| 172 |
+
PYBIND11_INPLACE_OPERATOR(isub, operator-=, l -= r)
|
| 173 |
+
PYBIND11_INPLACE_OPERATOR(imul, operator*=, l *= r)
|
| 174 |
+
PYBIND11_INPLACE_OPERATOR(itruediv, operator/=, l /= r)
|
| 175 |
+
PYBIND11_INPLACE_OPERATOR(imod, operator%=, l %= r)
|
| 176 |
+
PYBIND11_INPLACE_OPERATOR(ilshift, operator<<=, l <<= r)
|
| 177 |
+
PYBIND11_INPLACE_OPERATOR(irshift, operator>>=, l >>= r)
|
| 178 |
+
PYBIND11_INPLACE_OPERATOR(iand, operator&=, l &= r)
|
| 179 |
+
PYBIND11_INPLACE_OPERATOR(ixor, operator^=, l ^= r)
|
| 180 |
+
PYBIND11_INPLACE_OPERATOR(ior, operator|=, l |= r)
|
| 181 |
+
PYBIND11_UNARY_OPERATOR(neg, operator-, -l)
|
| 182 |
+
PYBIND11_UNARY_OPERATOR(pos, operator+, +l)
|
| 183 |
+
// WARNING: This usage of `abs` should only be done for existing STL overloads.
|
| 184 |
+
// Adding overloads directly in to the `std::` namespace is advised against:
|
| 185 |
+
// https://en.cppreference.com/w/cpp/language/extending_std
|
| 186 |
+
PYBIND11_UNARY_OPERATOR(abs, abs, std::abs(l))
|
| 187 |
+
PYBIND11_UNARY_OPERATOR(hash, hash, std::hash<L>()(l))
|
| 188 |
+
PYBIND11_UNARY_OPERATOR(invert, operator~, (~l))
|
| 189 |
+
PYBIND11_UNARY_OPERATOR(bool, operator!, !!l)
|
| 190 |
+
PYBIND11_UNARY_OPERATOR(int, int_, (int) l)
|
| 191 |
+
PYBIND11_UNARY_OPERATOR(float, float_, (double) l)
|
| 192 |
+
|
| 193 |
+
#undef PYBIND11_BINARY_OPERATOR
|
| 194 |
+
#undef PYBIND11_INPLACE_OPERATOR
|
| 195 |
+
#undef PYBIND11_UNARY_OPERATOR
|
| 196 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 197 |
+
|
| 198 |
+
using detail::self;
|
| 199 |
+
// Add named operators so that they are accessible via `py::`.
|
| 200 |
+
using detail::hash;
|
| 201 |
+
|
| 202 |
+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
|
phivenv/Lib/site-packages/torch/include/pybind11/options.h
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
pybind11/options.h: global settings that are configurable at runtime.
|
| 3 |
+
|
| 4 |
+
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
| 5 |
+
|
| 6 |
+
All rights reserved. Use of this source code is governed by a
|
| 7 |
+
BSD-style license that can be found in the LICENSE file.
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include "detail/common.h"
|
| 13 |
+
|
| 14 |
+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
| 15 |
+
|
| 16 |
+
class options {
|
| 17 |
+
public:
|
| 18 |
+
// Default RAII constructor, which leaves settings as they currently are.
|
| 19 |
+
options() : previous_state(global_state()) {}
|
| 20 |
+
|
| 21 |
+
// Class is non-copyable.
|
| 22 |
+
options(const options &) = delete;
|
| 23 |
+
options &operator=(const options &) = delete;
|
| 24 |
+
|
| 25 |
+
// Destructor, which restores settings that were in effect before.
|
| 26 |
+
~options() { global_state() = previous_state; }
|
| 27 |
+
|
| 28 |
+
// Setter methods (affect the global state):
|
| 29 |
+
|
| 30 |
+
options &disable_user_defined_docstrings() & {
|
| 31 |
+
global_state().show_user_defined_docstrings = false;
|
| 32 |
+
return *this;
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
options &enable_user_defined_docstrings() & {
|
| 36 |
+
global_state().show_user_defined_docstrings = true;
|
| 37 |
+
return *this;
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
options &disable_function_signatures() & {
|
| 41 |
+
global_state().show_function_signatures = false;
|
| 42 |
+
return *this;
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
options &enable_function_signatures() & {
|
| 46 |
+
global_state().show_function_signatures = true;
|
| 47 |
+
return *this;
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
options &disable_enum_members_docstring() & {
|
| 51 |
+
global_state().show_enum_members_docstring = false;
|
| 52 |
+
return *this;
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
options &enable_enum_members_docstring() & {
|
| 56 |
+
global_state().show_enum_members_docstring = true;
|
| 57 |
+
return *this;
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
// Getter methods (return the global state):
|
| 61 |
+
|
| 62 |
+
static bool show_user_defined_docstrings() {
|
| 63 |
+
return global_state().show_user_defined_docstrings;
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
static bool show_function_signatures() { return global_state().show_function_signatures; }
|
| 67 |
+
|
| 68 |
+
static bool show_enum_members_docstring() {
|
| 69 |
+
return global_state().show_enum_members_docstring;
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
// This type is not meant to be allocated on the heap.
|
| 73 |
+
void *operator new(size_t) = delete;
|
| 74 |
+
|
| 75 |
+
private:
|
| 76 |
+
struct state {
|
| 77 |
+
bool show_user_defined_docstrings = true; //< Include user-supplied texts in docstrings.
|
| 78 |
+
bool show_function_signatures = true; //< Include auto-generated function signatures
|
| 79 |
+
// in docstrings.
|
| 80 |
+
bool show_enum_members_docstring = true; //< Include auto-generated member list in enum
|
| 81 |
+
// docstrings.
|
| 82 |
+
};
|
| 83 |
+
|
| 84 |
+
static state &global_state() {
|
| 85 |
+
static state instance;
|
| 86 |
+
return instance;
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
state previous_state;
|
| 90 |
+
};
|
| 91 |
+
|
| 92 |
+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
|
phivenv/Lib/site-packages/torch/include/pybind11/pybind11.h
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
phivenv/Lib/site-packages/torch/include/pybind11/pytypes.h
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
phivenv/Lib/site-packages/torch/include/pybind11/stl.h
ADDED
|
@@ -0,0 +1,448 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
pybind11/stl.h: Transparent conversion for STL data types
|
| 3 |
+
|
| 4 |
+
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
| 5 |
+
|
| 6 |
+
All rights reserved. Use of this source code is governed by a
|
| 7 |
+
BSD-style license that can be found in the LICENSE file.
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include "pybind11.h"
|
| 13 |
+
#include "detail/common.h"
|
| 14 |
+
|
| 15 |
+
#include <deque>
|
| 16 |
+
#include <list>
|
| 17 |
+
#include <map>
|
| 18 |
+
#include <ostream>
|
| 19 |
+
#include <set>
|
| 20 |
+
#include <unordered_map>
|
| 21 |
+
#include <unordered_set>
|
| 22 |
+
#include <valarray>
|
| 23 |
+
|
| 24 |
+
// See `detail/common.h` for implementation of these guards.
|
| 25 |
+
#if defined(PYBIND11_HAS_OPTIONAL)
|
| 26 |
+
# include <optional>
|
| 27 |
+
#elif defined(PYBIND11_HAS_EXP_OPTIONAL)
|
| 28 |
+
# include <experimental/optional>
|
| 29 |
+
#endif
|
| 30 |
+
|
| 31 |
+
#if defined(PYBIND11_HAS_VARIANT)
|
| 32 |
+
# include <variant>
|
| 33 |
+
#endif
|
| 34 |
+
|
| 35 |
+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
| 36 |
+
PYBIND11_NAMESPACE_BEGIN(detail)
|
| 37 |
+
|
| 38 |
+
/// Extracts an const lvalue reference or rvalue reference for U based on the type of T (e.g. for
|
| 39 |
+
/// forwarding a container element). Typically used indirect via forwarded_type(), below.
|
| 40 |
+
template <typename T, typename U>
|
| 41 |
+
using forwarded_type = conditional_t<std::is_lvalue_reference<T>::value,
|
| 42 |
+
remove_reference_t<U> &,
|
| 43 |
+
remove_reference_t<U> &&>;
|
| 44 |
+
|
| 45 |
+
/// Forwards a value U as rvalue or lvalue according to whether T is rvalue or lvalue; typically
|
| 46 |
+
/// used for forwarding a container's elements.
|
| 47 |
+
template <typename T, typename U>
|
| 48 |
+
constexpr forwarded_type<T, U> forward_like(U &&u) {
|
| 49 |
+
return std::forward<detail::forwarded_type<T, U>>(std::forward<U>(u));
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
// Checks if a container has a STL style reserve method.
|
| 53 |
+
// This will only return true for a `reserve()` with a `void` return.
|
| 54 |
+
template <typename C>
|
| 55 |
+
using has_reserve_method = std::is_same<decltype(std::declval<C>().reserve(0)), void>;
|
| 56 |
+
|
| 57 |
+
template <typename Type, typename Key>
|
| 58 |
+
struct set_caster {
|
| 59 |
+
using type = Type;
|
| 60 |
+
using key_conv = make_caster<Key>;
|
| 61 |
+
|
| 62 |
+
private:
|
| 63 |
+
template <typename T = Type, enable_if_t<has_reserve_method<T>::value, int> = 0>
|
| 64 |
+
void reserve_maybe(const anyset &s, Type *) {
|
| 65 |
+
value.reserve(s.size());
|
| 66 |
+
}
|
| 67 |
+
void reserve_maybe(const anyset &, void *) {}
|
| 68 |
+
|
| 69 |
+
public:
|
| 70 |
+
bool load(handle src, bool convert) {
|
| 71 |
+
if (!isinstance<anyset>(src)) {
|
| 72 |
+
return false;
|
| 73 |
+
}
|
| 74 |
+
auto s = reinterpret_borrow<anyset>(src);
|
| 75 |
+
value.clear();
|
| 76 |
+
reserve_maybe(s, &value);
|
| 77 |
+
for (auto entry : s) {
|
| 78 |
+
key_conv conv;
|
| 79 |
+
if (!conv.load(entry, convert)) {
|
| 80 |
+
return false;
|
| 81 |
+
}
|
| 82 |
+
value.insert(cast_op<Key &&>(std::move(conv)));
|
| 83 |
+
}
|
| 84 |
+
return true;
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
template <typename T>
|
| 88 |
+
static handle cast(T &&src, return_value_policy policy, handle parent) {
|
| 89 |
+
if (!std::is_lvalue_reference<T>::value) {
|
| 90 |
+
policy = return_value_policy_override<Key>::policy(policy);
|
| 91 |
+
}
|
| 92 |
+
pybind11::set s;
|
| 93 |
+
for (auto &&value : src) {
|
| 94 |
+
auto value_ = reinterpret_steal<object>(
|
| 95 |
+
key_conv::cast(detail::forward_like<T>(value), policy, parent));
|
| 96 |
+
if (!value_ || !s.add(std::move(value_))) {
|
| 97 |
+
return handle();
|
| 98 |
+
}
|
| 99 |
+
}
|
| 100 |
+
return s.release();
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
PYBIND11_TYPE_CASTER(type, const_name("set[") + key_conv::name + const_name("]"));
|
| 104 |
+
};
|
| 105 |
+
|
| 106 |
+
template <typename Type, typename Key, typename Value>
|
| 107 |
+
struct map_caster {
|
| 108 |
+
using key_conv = make_caster<Key>;
|
| 109 |
+
using value_conv = make_caster<Value>;
|
| 110 |
+
|
| 111 |
+
private:
|
| 112 |
+
template <typename T = Type, enable_if_t<has_reserve_method<T>::value, int> = 0>
|
| 113 |
+
void reserve_maybe(const dict &d, Type *) {
|
| 114 |
+
value.reserve(d.size());
|
| 115 |
+
}
|
| 116 |
+
void reserve_maybe(const dict &, void *) {}
|
| 117 |
+
|
| 118 |
+
public:
|
| 119 |
+
bool load(handle src, bool convert) {
|
| 120 |
+
if (!isinstance<dict>(src)) {
|
| 121 |
+
return false;
|
| 122 |
+
}
|
| 123 |
+
auto d = reinterpret_borrow<dict>(src);
|
| 124 |
+
value.clear();
|
| 125 |
+
reserve_maybe(d, &value);
|
| 126 |
+
for (auto it : d) {
|
| 127 |
+
key_conv kconv;
|
| 128 |
+
value_conv vconv;
|
| 129 |
+
if (!kconv.load(it.first.ptr(), convert) || !vconv.load(it.second.ptr(), convert)) {
|
| 130 |
+
return false;
|
| 131 |
+
}
|
| 132 |
+
value.emplace(cast_op<Key &&>(std::move(kconv)), cast_op<Value &&>(std::move(vconv)));
|
| 133 |
+
}
|
| 134 |
+
return true;
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
template <typename T>
|
| 138 |
+
static handle cast(T &&src, return_value_policy policy, handle parent) {
|
| 139 |
+
dict d;
|
| 140 |
+
return_value_policy policy_key = policy;
|
| 141 |
+
return_value_policy policy_value = policy;
|
| 142 |
+
if (!std::is_lvalue_reference<T>::value) {
|
| 143 |
+
policy_key = return_value_policy_override<Key>::policy(policy_key);
|
| 144 |
+
policy_value = return_value_policy_override<Value>::policy(policy_value);
|
| 145 |
+
}
|
| 146 |
+
for (auto &&kv : src) {
|
| 147 |
+
auto key = reinterpret_steal<object>(
|
| 148 |
+
key_conv::cast(detail::forward_like<T>(kv.first), policy_key, parent));
|
| 149 |
+
auto value = reinterpret_steal<object>(
|
| 150 |
+
value_conv::cast(detail::forward_like<T>(kv.second), policy_value, parent));
|
| 151 |
+
if (!key || !value) {
|
| 152 |
+
return handle();
|
| 153 |
+
}
|
| 154 |
+
d[std::move(key)] = std::move(value);
|
| 155 |
+
}
|
| 156 |
+
return d.release();
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
PYBIND11_TYPE_CASTER(Type,
|
| 160 |
+
const_name("dict[") + key_conv::name + const_name(", ") + value_conv::name
|
| 161 |
+
+ const_name("]"));
|
| 162 |
+
};
|
| 163 |
+
|
| 164 |
+
template <typename Type, typename Value>
|
| 165 |
+
struct list_caster {
|
| 166 |
+
using value_conv = make_caster<Value>;
|
| 167 |
+
|
| 168 |
+
bool load(handle src, bool convert) {
|
| 169 |
+
if (!isinstance<sequence>(src) || isinstance<bytes>(src) || isinstance<str>(src)) {
|
| 170 |
+
return false;
|
| 171 |
+
}
|
| 172 |
+
auto s = reinterpret_borrow<sequence>(src);
|
| 173 |
+
value.clear();
|
| 174 |
+
reserve_maybe(s, &value);
|
| 175 |
+
for (const auto &it : s) {
|
| 176 |
+
value_conv conv;
|
| 177 |
+
if (!conv.load(it, convert)) {
|
| 178 |
+
return false;
|
| 179 |
+
}
|
| 180 |
+
value.push_back(cast_op<Value &&>(std::move(conv)));
|
| 181 |
+
}
|
| 182 |
+
return true;
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
private:
|
| 186 |
+
template <typename T = Type, enable_if_t<has_reserve_method<T>::value, int> = 0>
|
| 187 |
+
void reserve_maybe(const sequence &s, Type *) {
|
| 188 |
+
value.reserve(s.size());
|
| 189 |
+
}
|
| 190 |
+
void reserve_maybe(const sequence &, void *) {}
|
| 191 |
+
|
| 192 |
+
public:
|
| 193 |
+
template <typename T>
|
| 194 |
+
static handle cast(T &&src, return_value_policy policy, handle parent) {
|
| 195 |
+
if (!std::is_lvalue_reference<T>::value) {
|
| 196 |
+
policy = return_value_policy_override<Value>::policy(policy);
|
| 197 |
+
}
|
| 198 |
+
list l(src.size());
|
| 199 |
+
ssize_t index = 0;
|
| 200 |
+
for (auto &&value : src) {
|
| 201 |
+
auto value_ = reinterpret_steal<object>(
|
| 202 |
+
value_conv::cast(detail::forward_like<T>(value), policy, parent));
|
| 203 |
+
if (!value_) {
|
| 204 |
+
return handle();
|
| 205 |
+
}
|
| 206 |
+
PyList_SET_ITEM(l.ptr(), index++, value_.release().ptr()); // steals a reference
|
| 207 |
+
}
|
| 208 |
+
return l.release();
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
PYBIND11_TYPE_CASTER(Type, const_name("list[") + value_conv::name + const_name("]"));
|
| 212 |
+
};
|
| 213 |
+
|
| 214 |
+
template <typename Type, typename Alloc>
|
| 215 |
+
struct type_caster<std::vector<Type, Alloc>> : list_caster<std::vector<Type, Alloc>, Type> {};
|
| 216 |
+
|
| 217 |
+
template <typename Type, typename Alloc>
|
| 218 |
+
struct type_caster<std::deque<Type, Alloc>> : list_caster<std::deque<Type, Alloc>, Type> {};
|
| 219 |
+
|
| 220 |
+
template <typename Type, typename Alloc>
|
| 221 |
+
struct type_caster<std::list<Type, Alloc>> : list_caster<std::list<Type, Alloc>, Type> {};
|
| 222 |
+
|
| 223 |
+
template <typename ArrayType, typename Value, bool Resizable, size_t Size = 0>
|
| 224 |
+
struct array_caster {
|
| 225 |
+
using value_conv = make_caster<Value>;
|
| 226 |
+
|
| 227 |
+
private:
|
| 228 |
+
template <bool R = Resizable>
|
| 229 |
+
bool require_size(enable_if_t<R, size_t> size) {
|
| 230 |
+
if (value.size() != size) {
|
| 231 |
+
value.resize(size);
|
| 232 |
+
}
|
| 233 |
+
return true;
|
| 234 |
+
}
|
| 235 |
+
template <bool R = Resizable>
|
| 236 |
+
bool require_size(enable_if_t<!R, size_t> size) {
|
| 237 |
+
return size == Size;
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
public:
|
| 241 |
+
bool load(handle src, bool convert) {
|
| 242 |
+
if (!isinstance<sequence>(src)) {
|
| 243 |
+
return false;
|
| 244 |
+
}
|
| 245 |
+
auto l = reinterpret_borrow<sequence>(src);
|
| 246 |
+
if (!require_size(l.size())) {
|
| 247 |
+
return false;
|
| 248 |
+
}
|
| 249 |
+
size_t ctr = 0;
|
| 250 |
+
for (const auto &it : l) {
|
| 251 |
+
value_conv conv;
|
| 252 |
+
if (!conv.load(it, convert)) {
|
| 253 |
+
return false;
|
| 254 |
+
}
|
| 255 |
+
value[ctr++] = cast_op<Value &&>(std::move(conv));
|
| 256 |
+
}
|
| 257 |
+
return true;
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
template <typename T>
|
| 261 |
+
static handle cast(T &&src, return_value_policy policy, handle parent) {
|
| 262 |
+
list l(src.size());
|
| 263 |
+
ssize_t index = 0;
|
| 264 |
+
for (auto &&value : src) {
|
| 265 |
+
auto value_ = reinterpret_steal<object>(
|
| 266 |
+
value_conv::cast(detail::forward_like<T>(value), policy, parent));
|
| 267 |
+
if (!value_) {
|
| 268 |
+
return handle();
|
| 269 |
+
}
|
| 270 |
+
PyList_SET_ITEM(l.ptr(), index++, value_.release().ptr()); // steals a reference
|
| 271 |
+
}
|
| 272 |
+
return l.release();
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
PYBIND11_TYPE_CASTER(ArrayType,
|
| 276 |
+
const_name<Resizable>(const_name(""), const_name("Annotated["))
|
| 277 |
+
+ const_name("list[") + value_conv::name + const_name("]")
|
| 278 |
+
+ const_name<Resizable>(const_name(""),
|
| 279 |
+
const_name(", FixedSize(")
|
| 280 |
+
+ const_name<Size>() + const_name(")]")));
|
| 281 |
+
};
|
| 282 |
+
|
| 283 |
+
template <typename Type, size_t Size>
|
| 284 |
+
struct type_caster<std::array<Type, Size>>
|
| 285 |
+
: array_caster<std::array<Type, Size>, Type, false, Size> {};
|
| 286 |
+
|
| 287 |
+
template <typename Type>
|
| 288 |
+
struct type_caster<std::valarray<Type>> : array_caster<std::valarray<Type>, Type, true> {};
|
| 289 |
+
|
| 290 |
+
template <typename Key, typename Compare, typename Alloc>
|
| 291 |
+
struct type_caster<std::set<Key, Compare, Alloc>>
|
| 292 |
+
: set_caster<std::set<Key, Compare, Alloc>, Key> {};
|
| 293 |
+
|
| 294 |
+
template <typename Key, typename Hash, typename Equal, typename Alloc>
|
| 295 |
+
struct type_caster<std::unordered_set<Key, Hash, Equal, Alloc>>
|
| 296 |
+
: set_caster<std::unordered_set<Key, Hash, Equal, Alloc>, Key> {};
|
| 297 |
+
|
| 298 |
+
template <typename Key, typename Value, typename Compare, typename Alloc>
|
| 299 |
+
struct type_caster<std::map<Key, Value, Compare, Alloc>>
|
| 300 |
+
: map_caster<std::map<Key, Value, Compare, Alloc>, Key, Value> {};
|
| 301 |
+
|
| 302 |
+
template <typename Key, typename Value, typename Hash, typename Equal, typename Alloc>
|
| 303 |
+
struct type_caster<std::unordered_map<Key, Value, Hash, Equal, Alloc>>
|
| 304 |
+
: map_caster<std::unordered_map<Key, Value, Hash, Equal, Alloc>, Key, Value> {};
|
| 305 |
+
|
| 306 |
+
// This type caster is intended to be used for std::optional and std::experimental::optional
|
| 307 |
+
template <typename Type, typename Value = typename Type::value_type>
|
| 308 |
+
struct optional_caster {
|
| 309 |
+
using value_conv = make_caster<Value>;
|
| 310 |
+
|
| 311 |
+
template <typename T>
|
| 312 |
+
static handle cast(T &&src, return_value_policy policy, handle parent) {
|
| 313 |
+
if (!src) {
|
| 314 |
+
return none().release();
|
| 315 |
+
}
|
| 316 |
+
if (!std::is_lvalue_reference<T>::value) {
|
| 317 |
+
policy = return_value_policy_override<Value>::policy(policy);
|
| 318 |
+
}
|
| 319 |
+
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
|
| 320 |
+
return value_conv::cast(*std::forward<T>(src), policy, parent);
|
| 321 |
+
}
|
| 322 |
+
|
| 323 |
+
bool load(handle src, bool convert) {
|
| 324 |
+
if (!src) {
|
| 325 |
+
return false;
|
| 326 |
+
}
|
| 327 |
+
if (src.is_none()) {
|
| 328 |
+
return true; // default-constructed value is already empty
|
| 329 |
+
}
|
| 330 |
+
value_conv inner_caster;
|
| 331 |
+
if (!inner_caster.load(src, convert)) {
|
| 332 |
+
return false;
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
value.emplace(cast_op<Value &&>(std::move(inner_caster)));
|
| 336 |
+
return true;
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
PYBIND11_TYPE_CASTER(Type, const_name("Optional[") + value_conv::name + const_name("]"));
|
| 340 |
+
};
|
| 341 |
+
|
| 342 |
+
#if defined(PYBIND11_HAS_OPTIONAL)
|
| 343 |
+
template <typename T>
|
| 344 |
+
struct type_caster<std::optional<T>> : public optional_caster<std::optional<T>> {};
|
| 345 |
+
|
| 346 |
+
template <>
|
| 347 |
+
struct type_caster<std::nullopt_t> : public void_caster<std::nullopt_t> {};
|
| 348 |
+
#endif
|
| 349 |
+
|
| 350 |
+
#if defined(PYBIND11_HAS_EXP_OPTIONAL)
|
| 351 |
+
template <typename T>
|
| 352 |
+
struct type_caster<std::experimental::optional<T>>
|
| 353 |
+
: public optional_caster<std::experimental::optional<T>> {};
|
| 354 |
+
|
| 355 |
+
template <>
|
| 356 |
+
struct type_caster<std::experimental::nullopt_t>
|
| 357 |
+
: public void_caster<std::experimental::nullopt_t> {};
|
| 358 |
+
#endif
|
| 359 |
+
|
| 360 |
+
/// Visit a variant and cast any found type to Python
|
| 361 |
+
struct variant_caster_visitor {
|
| 362 |
+
return_value_policy policy;
|
| 363 |
+
handle parent;
|
| 364 |
+
|
| 365 |
+
using result_type = handle; // required by boost::variant in C++11
|
| 366 |
+
|
| 367 |
+
template <typename T>
|
| 368 |
+
result_type operator()(T &&src) const {
|
| 369 |
+
return make_caster<T>::cast(std::forward<T>(src), policy, parent);
|
| 370 |
+
}
|
| 371 |
+
};
|
| 372 |
+
|
| 373 |
+
/// Helper class which abstracts away variant's `visit` function. `std::variant` and similar
|
| 374 |
+
/// `namespace::variant` types which provide a `namespace::visit()` function are handled here
|
| 375 |
+
/// automatically using argument-dependent lookup. Users can provide specializations for other
|
| 376 |
+
/// variant-like classes, e.g. `boost::variant` and `boost::apply_visitor`.
|
| 377 |
+
template <template <typename...> class Variant>
|
| 378 |
+
struct visit_helper {
|
| 379 |
+
template <typename... Args>
|
| 380 |
+
static auto call(Args &&...args) -> decltype(visit(std::forward<Args>(args)...)) {
|
| 381 |
+
return visit(std::forward<Args>(args)...);
|
| 382 |
+
}
|
| 383 |
+
};
|
| 384 |
+
|
| 385 |
+
/// Generic variant caster
|
| 386 |
+
template <typename Variant>
|
| 387 |
+
struct variant_caster;
|
| 388 |
+
|
| 389 |
+
template <template <typename...> class V, typename... Ts>
|
| 390 |
+
struct variant_caster<V<Ts...>> {
|
| 391 |
+
static_assert(sizeof...(Ts) > 0, "Variant must consist of at least one alternative.");
|
| 392 |
+
|
| 393 |
+
template <typename U, typename... Us>
|
| 394 |
+
bool load_alternative(handle src, bool convert, type_list<U, Us...>) {
|
| 395 |
+
auto caster = make_caster<U>();
|
| 396 |
+
if (caster.load(src, convert)) {
|
| 397 |
+
value = cast_op<U>(std::move(caster));
|
| 398 |
+
return true;
|
| 399 |
+
}
|
| 400 |
+
return load_alternative(src, convert, type_list<Us...>{});
|
| 401 |
+
}
|
| 402 |
+
|
| 403 |
+
bool load_alternative(handle, bool, type_list<>) { return false; }
|
| 404 |
+
|
| 405 |
+
bool load(handle src, bool convert) {
|
| 406 |
+
// Do a first pass without conversions to improve constructor resolution.
|
| 407 |
+
// E.g. `py::int_(1).cast<variant<double, int>>()` needs to fill the `int`
|
| 408 |
+
// slot of the variant. Without two-pass loading `double` would be filled
|
| 409 |
+
// because it appears first and a conversion is possible.
|
| 410 |
+
if (convert && load_alternative(src, false, type_list<Ts...>{})) {
|
| 411 |
+
return true;
|
| 412 |
+
}
|
| 413 |
+
return load_alternative(src, convert, type_list<Ts...>{});
|
| 414 |
+
}
|
| 415 |
+
|
| 416 |
+
template <typename Variant>
|
| 417 |
+
static handle cast(Variant &&src, return_value_policy policy, handle parent) {
|
| 418 |
+
return visit_helper<V>::call(variant_caster_visitor{policy, parent},
|
| 419 |
+
std::forward<Variant>(src));
|
| 420 |
+
}
|
| 421 |
+
|
| 422 |
+
using Type = V<Ts...>;
|
| 423 |
+
PYBIND11_TYPE_CASTER(Type,
|
| 424 |
+
const_name("Union[")
|
| 425 |
+
+ ::pybind11::detail::concat(make_caster<Ts>::name...)
|
| 426 |
+
+ const_name("]"));
|
| 427 |
+
};
|
| 428 |
+
|
| 429 |
+
#if defined(PYBIND11_HAS_VARIANT)
|
| 430 |
+
template <typename... Ts>
|
| 431 |
+
struct type_caster<std::variant<Ts...>> : variant_caster<std::variant<Ts...>> {};
|
| 432 |
+
|
| 433 |
+
template <>
|
| 434 |
+
struct type_caster<std::monostate> : public void_caster<std::monostate> {};
|
| 435 |
+
#endif
|
| 436 |
+
|
| 437 |
+
PYBIND11_NAMESPACE_END(detail)
|
| 438 |
+
|
| 439 |
+
inline std::ostream &operator<<(std::ostream &os, const handle &obj) {
|
| 440 |
+
#ifdef PYBIND11_HAS_STRING_VIEW
|
| 441 |
+
os << str(obj).cast<std::string_view>();
|
| 442 |
+
#else
|
| 443 |
+
os << (std::string) str(obj);
|
| 444 |
+
#endif
|
| 445 |
+
return os;
|
| 446 |
+
}
|
| 447 |
+
|
| 448 |
+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
|