|
|
|
|
|
#pragma once |
|
|
|
|
|
#include <ATen/ArrayRef.h> |
|
|
#include <ATen/FunctionalStorageImpl.h> |
|
|
#include <ATen/core/IListRef.h> |
|
|
#include <ATen/core/List.h> |
|
|
#include <ATen/core/boxing/BoxedKernel.h> |
|
|
#include <ATen/core/boxing/impl/boxing.h> |
|
|
#include <ATen/core/dispatch/Dispatcher.h> |
|
|
|
|
|
#include <c10/core/DispatchKey.h> |
|
|
|
|
|
namespace at { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// TensorImpl subclass used by functionalization.
///
/// Holds an inner tensor (`value_`, exposed through value()) and derives from
/// c10::TensorImpl so that it can be used wherever a TensorImpl is expected.
/// Every non-trivial member function is defined out of line.
///
/// NOTE(review): the per-method notes below are inferred from names and
/// signatures only; confirm against the implementation file.
struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
  /// Wraps `value`.
  explicit FunctionalTensorWrapper(const Tensor& value);

  /// Wraps `view_value` as a view derived from `base`, recording `meta`.
  /// NOTE(review): presumably the new wrapper shares `base`'s functional
  /// storage - confirm in the .cpp.
  explicit FunctionalTensorWrapper(
      const Tensor& view_value,
      const FunctionalTensorWrapper* base,
      functionalization::ViewMeta meta);

  /// Returns the wrapped tensor (`value_`) by const reference.
  const Tensor& value() const {
    return value_;
  }

  /// Returns `level_`.
  int64_t level() const {
    return level_;
  }

  /// Overwrites `level_`.
  void set_level(int64_t level) {
    level_ = level;
  }

  // NOTE(review): by their names, sync_(), regenerate_from_base(),
  // apply_updates() and is_up_to_date() bring this wrapper back in line with
  // mutations recorded elsewhere (e.g. on a shared base) - confirm.
  void sync_();

  void regenerate_from_base();

  /// Returns a flag; what `true` signals is not visible from this header.
  bool apply_updates();

  void commit_update();

  bool is_up_to_date() const;

  void mutate_view_meta(at::functionalization::ViewMeta meta);

  void replace_(const Tensor& other);

  void maybe_replace_storage(const Tensor& other);

  // TensorImpl::shallow_copy_and_detach overrides, for an lvalue and an
  // rvalue version counter respectively.
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override;

  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const override;

  ~FunctionalTensorWrapper() override = default;

  // TensorImpl "*_custom" metadata hooks overridden by this wrapper.
  at::IntArrayRef sizes_custom() const override;

  at::IntArrayRef strides_custom() const override;

  int64_t dim_custom() const override;

  int64_t numel_custom() const override;

  bool is_contiguous_custom(at::MemoryFormat memory_format) const override;

  c10::SymIntArrayRef sym_sizes_custom() const override;

  c10::SymIntArrayRef sym_strides_custom() const override;

 private:
  const char* tensorimpl_type_name() const override;

  void set_constructor_metadata();

  functionalization::FunctionalStorageImpl* functional_storage_impl() const;

  // Template over the version-counter type; the forwarding reference lets a
  // single body accept both const& and && counters.
  // NOTE(review): presumably both shallow_copy_and_detach overloads forward
  // here - confirm.
  template <typename VariableVersion>
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
      VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const;

  // The wrapped tensor returned by value().
  Tensor value_;

  // Read/written via level()/set_level(). No default member initializer, so
  // the out-of-line constructors must set it before level() is read.
  // NOTE(review): confirm they do.
  int64_t level_;

  // Starts at 0. NOTE(review): presumably compared with the storage's
  // generation to detect pending updates - confirm.
  size_t generation_ = 0;

  // View metadata recorded for this wrapper (see mutate_view_meta()).
  std::vector<at::functionalization::ViewMeta> view_metas_;
};
|
|
|
|
|
|
|
|
|
|
|
namespace functionalization { |
|
|
namespace impl { |
|
|
|
|
|
// Downcasts `tensor`'s TensorImpl to FunctionalTensorWrapper with no type
// check (hence "unsafe"): callers must already know `tensor` is functional.
// Only a debug-build null check is performed before returning.
TORCH_API inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper(
    const Tensor& tensor) {
  FunctionalTensorWrapper* const wrapper_impl =
      static_cast<FunctionalTensorWrapper*>(tensor.unsafeGetTensorImpl());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(wrapper_impl != nullptr);
  return wrapper_impl;
}
|
|
|
|
|
// Functional-tensor predicates. Overloads take a single tensor, an optional
// tensor, a List of optional tensors, or an ITensorListRef.
// NOTE(review): for the container overloads, whether the result means "any"
// or "all" elements are functional (and how undefined entries count) is
// decided in the out-of-line definitions - confirm before relying on it.
TORCH_API bool isFunctionalTensor(const at::Tensor& tensor);
TORCH_API bool isFunctionalTensor(const c10::optional<Tensor>& t);
TORCH_API bool isFunctionalTensor(
    const c10::List<c10::optional<Tensor>>& t_list);
TORCH_API bool isFunctionalTensor(ITensorListRef list);

// Conversions into functional tensors, for the same four input shapes as
// above; the ITensorListRef overload returns a std::vector.
// NOTE(review): presumably each defined input gets wrapped in a new
// FunctionalTensorWrapper - confirm in the .cpp.
TORCH_API Tensor to_functional_tensor(const Tensor& tensor);
TORCH_API c10::optional<Tensor> to_functional_tensor(
    const c10::optional<Tensor>& tensor);
TORCH_API c10::List<c10::optional<Tensor>> to_functional_tensor(
    const c10::List<c10::optional<Tensor>>& t_list);
TORCH_API std::vector<Tensor> to_functional_tensor(ITensorListRef t_list);

// Conversions out of functional form (per the name, the inverse of
// to_functional_tensor).
// `assert_functional` (default true) exists only on the single-tensor and
// optional overloads; the List and ITensorListRef overloads take no flag.
// NOTE(review): presumably passing false lets non-functional inputs through
// instead of asserting - TODO confirm.
TORCH_API Tensor
from_functional_tensor(const Tensor& tensor, bool assert_functional = true);
TORCH_API c10::optional<Tensor> from_functional_tensor(
    const c10::optional<Tensor>& t,
    bool assert_functional = true);
TORCH_API c10::List<c10::optional<Tensor>> from_functional_tensor(
    const c10::List<c10::optional<Tensor>>& t_list);
TORCH_API std::vector<Tensor> from_functional_tensor(ITensorListRef t_list);
|
|
|
|
|
// Overloads for a single tensor, an optional tensor, a List of optional
// tensors, and an ITensorListRef.
// NOTE(review): by the name, presumably brings each functional tensor up to
// date (cf. FunctionalTensorWrapper::sync_) - confirm in the .cpp.
TORCH_API void sync(const at::Tensor& t);
TORCH_API void sync(const c10::optional<Tensor>& t);
// Top-level `const` on a by-value parameter is not part of the signature, so
// it is omitted here; this still matches the out-of-line definition.
// NOTE(review): unlike the other List overloads in this header, this one takes
// the List by value. Switching to `const&` requires changing the definition
// as well.
TORCH_API void sync(c10::List<c10::optional<Tensor>> t_list);
TORCH_API void sync(ITensorListRef t_list);
|
|
|
|
|
// Overloads for a single tensor pair and for a pair of tensor lists.
// NOTE(review): presumably forwards to FunctionalTensorWrapper::replace_ for
// each functional tensor, with the two lists expected to have equal length -
// confirm in the .cpp.
TORCH_API void replace_(const Tensor& functional_tensor, const Tensor& other);
// Both lists are taken by value. The former top-level `const` on
// `functional_tensor` was not part of the signature; it is dropped for
// consistency with `other`.
TORCH_API void replace_(
    ITensorListRef functional_tensor,
    ITensorListRef other);
|
|
|
|
|
// NOTE(review): presumably forwards to FunctionalTensorWrapper::commit_update()
// for the given functional tensor(s) - confirm in the .cpp.
TORCH_API void commit_update(const Tensor& functional_tensor);
TORCH_API void commit_update(ITensorListRef functional_tensor);

// Per the name, creates functional tensor(s) wrapping `view_to_wrap` as a view
// of `base`, tagged with `meta` - confirm in the .cpp. Only the single-tensor
// overload takes `out_idx` (default 0).
// NOTE(review): presumably `out_idx` picks which output of a multi-output view
// op this is. Note that these declarations and the helpers below are not
// marked TORCH_API, unlike those above.
Tensor create_functional_tensor_with_view_meta(
    const Tensor& view_to_wrap,
    const Tensor& base,
    functionalization::ViewMeta meta,
    int64_t out_idx = 0);
std::vector<Tensor> create_functional_tensor_with_view_meta(
    ITensorListRef view_to_wrap,
    const Tensor& base,
    functionalization::ViewMeta meta);

// Free-function form of FunctionalTensorWrapper::mutate_view_meta (per the
// shared name), applied to `self` - confirm in the .cpp.
void mutate_view_meta(const Tensor& self, functionalization::ViewMeta meta);

// NOTE(review): per the name, presumably copies sizes, strides and storage
// offset from `meta_out` onto `out` (pairwise for the vector overload) -
// confirm in the .cpp.
void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out);
void set_sizes_strides_offset(
    const std::vector<Tensor>& outs,
    const std::vector<Tensor>& meta_outs);
|
|
|
|
|
|
|
|
|
|
|
// Read / write the "reapply views" functionalization setting.
// NOTE(review): per the "TLS" suffix this is presumably thread-local state.
// For scoped changes prefer FunctionalizationReapplyViewsGuard, which saves
// the previous value and restores it on destruction.
TORCH_API bool getFunctionalizationReapplyViewsTLS();
TORCH_API void setFunctionalizationReapplyViewsTLS(bool reapply_views);
|
|
|
|
|
class TORCH_API FunctionalizationReapplyViewsGuard { |
|
|
public: |
|
|
FunctionalizationReapplyViewsGuard(bool reapply_views) { |
|
|
prev_ = getFunctionalizationReapplyViewsTLS(); |
|
|
setFunctionalizationReapplyViewsTLS(reapply_views); |
|
|
} |
|
|
|
|
|
~FunctionalizationReapplyViewsGuard() { |
|
|
setFunctionalizationReapplyViewsTLS(prev_); |
|
|
} |
|
|
|
|
|
FunctionalizationReapplyViewsGuard( |
|
|
const FunctionalizationReapplyViewsGuard&) = delete; |
|
|
FunctionalizationReapplyViewsGuard operator=( |
|
|
const FunctionalizationReapplyViewsGuard&) = delete; |
|
|
FunctionalizationReapplyViewsGuard(FunctionalizationReapplyViewsGuard&&) = |
|
|
delete; |
|
|
FunctionalizationReapplyViewsGuard operator=( |
|
|
FunctionalizationReapplyViewsGuard&&) = delete; |
|
|
|
|
|
private: |
|
|
bool prev_; |
|
|
}; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// Boxed kernel body that _functionalize_aten_op::call wraps with
// c10::BoxedKernel::makeFromFunction. It receives the operator handle and the
// operator's argument stack.
// NOTE(review): presumably runs `op` with its internal mutations/views
// functionalized and leaves the outputs on `stack` - confirm in the .cpp.
TORCH_API void functionalize_op_helper(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack);
|
|
|
|
|
// Primary template, intentionally empty: only the partial specialization for
// a function-type third argument provides `call`.
template <class OpT, bool kKeepSymInt, class Ret, class... Params>
struct _functionalize_aten_op final {};
|
|
|
|
|
template <class Op, bool symint, class ReturnType, class... ParameterTypes> |
|
|
struct _functionalize_aten_op<Op, symint, ReturnType(ParameterTypes...)> final { |
|
|
static ReturnType call( |
|
|
typename c10::maybe_keep_symint<symint, ParameterTypes>::type... args) { |
|
|
using FuncType = ReturnType( |
|
|
typename c10::maybe_keep_symint<symint, ParameterTypes>::type...); |
|
|
auto op = c10::Dispatcher::singleton() |
|
|
.findSchemaOrThrow( |
|
|
(const char*)Op::name, (const char*)Op::overload_name) |
|
|
.typed<FuncType>(); |
|
|
|
|
|
return c10::impl::BoxedKernelWrapper<FuncType>::call( |
|
|
c10::BoxedKernel::makeFromFunction<functionalize_op_helper>(), |
|
|
op, |
|
|
|
|
|
|
|
|
c10::DispatchKeySet(), |
|
|
args...); |
|
|
} |
|
|
}; |
|
|
|
|
|
template <class Op> |
|
|
using functionalize_aten_op = |
|
|
_functionalize_aten_op<Op, false, typename Op::schema>; |
|
|
|
|
|
template <class Op> |
|
|
using functionalize_aten_op_symint = |
|
|
_functionalize_aten_op<Op, true, typename Op::schema>; |
|
|
|
|
|
} |
|
|
} |
|
|
|