| |
| |
|
|
| #pragma once |
|
|
| #include <stddef.h> |
| #include <iostream> |
| #include <string> |
| #include <vector> |
|
|
| #include "core/common/gsl.h" |
| #include "core/common/common.h" |
| #include "core/framework/allocator.h" |
| #include "core/framework/tensor_shape.h" |
| #include "core/framework/buffer_deleter.h" |
| #include "onnxruntime_config.h" |
| #include "core/framework/data_types.h" |
| #include "core/framework/data_types_internal.h" |
|
|
| struct OrtValue; |
|
|
| namespace onnxruntime { |
|
|
| |
| #ifdef __GNUC__ |
| #pragma GCC diagnostic push |
| #ifdef HAS_NULL_DEREFERENCE |
| #pragma GCC diagnostic ignored "-Wnull-dereference" |
| #endif |
| #endif |
| |
| |
| |
| |
| |
| |
|
|
| class Tensor final { |
| public: |
| |
| |
| |
|
|
| Tensor() = default; |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| Tensor(MLDataType p_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& alloc, |
| ptrdiff_t offset = 0, gsl::span<const int64_t> strides = {}); |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| static void InitOrtValue(MLDataType p_type, const TensorShape& shape, |
| void* p_data, const OrtMemoryInfo& location, |
| OrtValue& ort_value, ptrdiff_t offset = 0, |
| gsl::span<const int64_t> strides = {}); |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| static void InitOrtValue(MLDataType p_type, const TensorShape& shape, |
| void* p_data, std::shared_ptr<IAllocator> allocator, |
| OrtValue& ort_value, ptrdiff_t offset = 0, |
| gsl::span<const int64_t> strides = {}); |
|
|
| static size_t CalculateTensorStorageSize(MLDataType p_type, |
| const TensorShape& shape, |
| gsl::span<const int64_t> strides = {}); |
|
|
| |
| |
| |
| |
| Tensor(MLDataType p_type, const TensorShape& shape, std::shared_ptr<IAllocator> allocator, |
| gsl::span<const int64_t> strides = {}); |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| static void InitOrtValue(MLDataType elt_type, |
| const TensorShape& shape, |
| std::shared_ptr<IAllocator> allocator, |
| OrtValue& ort_value, |
| gsl::span<const int64_t> strides = {}); |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| Tensor(MLDataType p_type, const TensorShape& shape, void* p_data, std::shared_ptr<IAllocator> deleter, |
| ptrdiff_t offset = 0, gsl::span<const int64_t> strides = {}); |
|
|
| ~Tensor(); |
|
|
| |
| ORT_DISALLOW_COPY_AND_ASSIGNMENT(Tensor); |
|
|
| Tensor(Tensor&& other) noexcept; |
|
|
| Tensor& operator=(Tensor&& other) noexcept; |
|
|
| |
| |
| |
| MLDataType DataType() const { return dtype_; } |
|
|
| |
| |
| |
| |
| int32_t GetElementType() const { |
| return dtype_->GetDataType(); |
| } |
|
|
| |
| |
| bool IsDataTypeString() const { |
| return utils::IsPrimitiveDataType<std::string>(dtype_); |
| } |
|
|
| |
| template <class T> |
| bool IsDataType() const { |
| return utils::IsPrimitiveDataType<T>(dtype_); |
| } |
|
|
| |
| |
| |
| const TensorShape& Shape() const noexcept { return shape_; } |
|
|
| |
| |
| |
| const OrtMemoryInfo& Location() const { return alloc_info_; } |
|
|
| |
| |
| |
| template <typename T> |
| T* MutableData() { |
| |
| ORT_ENFORCE(utils::IsPrimitiveDataType<T>(dtype_), "Tensor type mismatch. ", |
| "T ", "!=", dtype_); |
| return reinterpret_cast<T*>(static_cast<char*>(p_data_) + byte_offset_); |
| } |
|
|
| |
| |
| |
| template <typename T> |
| gsl::span<T> MutableDataAsSpan() { |
| |
| ORT_ENFORCE(utils::IsPrimitiveDataType<T>(dtype_), "Tensor type mismatch. ", |
| "T ", "!=", dtype_); |
| T* data = reinterpret_cast<T*>(static_cast<char*>(p_data_) + byte_offset_); |
| return gsl::make_span(data, static_cast<size_t>(shape_.Size())); |
| } |
|
|
| template <typename T> |
| const T* Data() const { |
| |
| ORT_ENFORCE(utils::IsPrimitiveDataType<T>(dtype_), "Tensor type mismatch. ", |
| "T ", "!=", dtype_); |
| return reinterpret_cast<const T*>(static_cast<char*>(p_data_) + byte_offset_); |
| } |
|
|
| template <typename T> |
| gsl::span<const T> DataAsSpan() const { |
| |
| ORT_ENFORCE(utils::IsPrimitiveDataType<T>(dtype_), "Tensor type mismatch. ", |
| "T ", "!=", dtype_); |
| const T* data = reinterpret_cast<const T*>(static_cast<char*>(p_data_) + byte_offset_); |
| return gsl::make_span(data, static_cast<typename gsl::span<T>::size_type>(shape_.Size())); |
| } |
|
|
| void* MutableDataRaw(MLDataType type) { |
| ORT_ENFORCE(type == dtype_, "Tensor type mismatch.", type, "!=", dtype_); |
| return static_cast<char*>(p_data_) + byte_offset_; |
| } |
|
|
| const void* DataRaw(MLDataType type) const { |
| ORT_ENFORCE(type == dtype_, "Tensor type mismatch.", type, "!=", dtype_); |
| return static_cast<char*>(p_data_) + byte_offset_; |
| } |
|
|
| void* MutableDataRaw() noexcept { |
| return static_cast<char*>(p_data_) + byte_offset_; |
| } |
|
|
| const void* DataRaw() const noexcept { |
| return static_cast<char*>(p_data_) + byte_offset_; |
| } |
|
|
| bool OwnsBuffer() const noexcept { |
| return buffer_deleter_ != nullptr; |
| } |
|
|
| |
| |
| |
| |
| |
| inline void Reshape(const TensorShape& new_shape) { |
| ORT_ENFORCE(shape_.Size() == new_shape.Size(), |
| "Tensor size (" + std::to_string(shape_.Size()) + |
| ") != new size (" + std::to_string(new_shape.Size()) + ")"); |
| shape_ = new_shape; |
| } |
|
|
| |
| |
| |
| |
| |
| inline ptrdiff_t ByteOffset() const { |
| return byte_offset_; |
| } |
|
|
| |
| |
| |
| |
| inline void SetByteOffset(ptrdiff_t byte_offset) { |
| byte_offset_ = byte_offset; |
| } |
|
|
| |
| |
| |
| size_t SizeInBytes() const; |
|
|
| #ifdef ENABLE_STRIDED_TENSORS |
| |
| |
| |
| gsl::span<const int64_t> Strides() const; |
|
|
| |
| |
| |
| bool IsContiguous() const noexcept { return is_contiguous_; } |
|
|
| |
| |
| |
| void SetShapeAndStrides(const TensorShape& new_shape, gsl::span<const int64_t> new_strides); |
| #endif |
|
|
| |
| private: |
| void Init(MLDataType p_type, |
| const TensorShape& shape, |
| void* p_raw_data, |
| AllocatorPtr deleter, |
| ptrdiff_t offset = 0, |
| gsl::span<const int64_t> strides = {}); |
|
|
| void ReleaseBuffer(); |
|
|
| #ifdef ENABLE_STRIDED_TENSORS |
| bool CheckIsContiguous() const; |
| #endif |
|
|
| void* p_data_; |
| |
| |
| |
| |
| |
| AllocatorPtr buffer_deleter_; |
|
|
| TensorShape shape_; |
| #ifdef ENABLE_STRIDED_TENSORS |
| mutable TensorShapeVector strides_; |
| bool is_contiguous_ = true; |
| #endif |
|
|
| const PrimitiveDataTypeBase* dtype_; |
| OrtMemoryInfo alloc_info_; |
| ptrdiff_t byte_offset_; |
| }; |
| #ifdef __GNUC__ |
| #pragma GCC diagnostic pop |
| #endif |
| } |
|
|