#pragma once
#include <ATen/ArrayRef.h>
#include <ATen/FunctionalStorageImpl.h>
#include <ATen/core/IListRef.h>
#include <ATen/core/List.h>
#include <ATen/core/boxing/BoxedKernel.h>
#include <ATen/core/boxing/impl/boxing.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/core/DispatchKey.h>
namespace at {
// Note [Functionalization Pass In Core]
// The Functionalization pass is used to remove aliasing from a pytorch program.
//
// This is useful for backends that don't support aliasing, like XLA and Vulkan.
// It's also necessary in order to remove mutation from a program, which is
// needed in Functorch.
//
// Consider this program:
// a = torch.ones(...)
// b = a.view(...)
// b.add_(1)
//
// In this program, b is meant to alias with a due to the use of view(). At the
// end of the program, both a and b are full of 2's. However, backends that
// don't support aliasing aren't able to correctly implement the view()
// operator. Instead, they can opt into the Functionalization pass, which will
// sit between the user and the backend, and provide the necessary aliasing
// logic.
//
// The functionalization pass will turn the above program into a slightly
// different program that has the same semantics, transparently to the user,
// that backends like XLA/Vulkan are able to implement:
// a = torch.ones(...)
// b = a.view_copy(...) # view() replaced with view_copy().
//                      # Backends like XLA/Vulkan can implement this!
// b.add_(1)
// a.add_(1) # Our functionalization pass machinery knows that a and b are
//           # aliased - it applies b's mutation to a too.
//
// So, how does the functionalization pass keep track of which tensors are
// aliased? The pass works by wrapping EVERY tensor in the program inside of a
// FunctionalTensorWrapper, which knows about its alias'd tensors.
//
// See Note [Functionalization: Alias Removal] for details on the aliasing
// machinery. See Note [Functionalization: Mutation Removal] for details on
// mutation removal.
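//
// A rough sketch of the wrapping for the program above (illustrative
// pseudocode only, not the actual construction API):
//   a = FunctionalTensorWrapper(torch.ones(...))        # plain tensor, wrapped
//   b = FunctionalTensorWrapper(a_view_value, a, meta)  # view of a, wrapped together with its ViewMeta
//   b.add_(1)  # recorded so the mutation can later be applied to a's alias too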
struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
explicit FunctionalTensorWrapper(const Tensor& value);
// Additional constructor to create a FunctionalTensorWrapper directly from an
// underlying tensor that was created from a view. For example, the code b =
// a.view1() will generate a constructor call to FunctionalTensorWrapper(b, a,
// view1_meta)
explicit FunctionalTensorWrapper(
const Tensor& view_value,
const FunctionalTensorWrapper* base,
functionalization::ViewMeta meta);
// Get the underlying, actual tensor, that doesn't know anything about
// functionalization.
const Tensor& value() const {
return value_;
};
// The concept of "level" is only ever important to functorch; it's exposed
// here as more of a hook for functorch to use.
int64_t level() const {
return level_;
};
void set_level(int64_t level) {
level_ = level;
}
// Syncs the underlying tensor with its alias, if it's out of date. This
// involves two steps: 1) Apply any pending updates/mutations to the alias 2)
// Replay the views (if any) to regenerate the current tensor off of the
// updated alias.
void sync_();
// Performs step (2) of the sync. This is its own public API because it's
// needed by view_inplace ops like transpose_. See Note [Functionalization
// Pass - Inplace View Ops]
void regenerate_from_base();
// Performs step (1) of the sync. This is its own public API because it's
// needed by functorch. functorch wants to make sure that all input tensors to
// a functionalized program have been properly synced so it can properly
// propagate mutations to inputs. It can't just call sync_(), because the
// FunctionalTensorWrapper will look like it has no aliases and sync_ will be
// a noop. We use the reference count on storage_ to determine if the wrapper
// is aliased, and by the time functorch is ready to propagate updates to
// inputs, any intermediate views of the input created by the program will
// have been deallocated. This function also returns whether or not the base
// actually had any updates to apply.
bool apply_updates();
// Takes the current state of value_ and snapshots it, sending it as a pending
// update to the alias.
void commit_update();
// When any tensor is mutated, the tensor increments its alias's "generation".
// Separately, each tensor maintains its own "generation" counter, which is
// used to determine if it's up-to-date with its alias. The act of syncing a
// tensor will set a tensor's generation equal to its alias's generation.
bool is_up_to_date() const;
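// Illustrative sketch of the generation mechanism (pseudocode; the exact
// bookkeeping lives on the shared alias/storage):
//   b = a.view(...)  # a and b share an alias
//   b.add_(1)        # bumps the alias's generation; b's own counter follows
//   # a's counter now lags the alias's, so a is out of date until it's synced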
// Every FunctionalTensorWrapper contains a vector of ViewMeta objects
// describing the series of view ops that ran to generate the current tensor
// from the base tensor. This method is used by inplace-view ops like
// transpose_. It appends a ViewMeta to the existing stack, and refreshes the
// tensor by replaying the views off of the alias.
void mutate_view_meta(at::functionalization::ViewMeta meta);
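// Illustrative sketch (pseudocode):
//   a = base.view(...)   # a carries view_metas_ = [view_meta]
//   a.transpose_(0, 1)   # appends transpose_meta -> [view_meta, transpose_meta],
//                        # then a's value_ is refreshed by replaying both off the alias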
// The functionalization pass can be used to remove mutations.
// It does so by replacing any mutation op with its corresponding
// out-of-place op, followed by a call to replace_(). e.g:
//
// a.add_(1)
//
// will turn into:
//
// tmp = a.add(1)
// a.replace_(tmp)
//
// replace_() swaps out the wrapped tensor, value_, with tmp.
void replace_(const Tensor& other);
// See Note[resize_() in functionalization pass]
void maybe_replace_storage(const Tensor& other);
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
const c10::VariableVersion& version_counter,
bool allow_tensor_metadata_change) const override;
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
c10::VariableVersion&& version_counter,
bool allow_tensor_metadata_change) const override;
~FunctionalTensorWrapper() override = default;
// FunctionalTensorWrapper overrides all custom size/stride functions,
// so that if the inner tensor has a custom implementation
// we make sure to call that implementation.
at::IntArrayRef sizes_custom() const override;
at::IntArrayRef strides_custom() const override;
int64_t dim_custom() const override;
int64_t numel_custom() const override;
bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
c10::SymIntArrayRef sym_sizes_custom() const override;
c10::SymIntArrayRef sym_strides_custom() const override;
private:
const char* tensorimpl_type_name() const override;
void set_constructor_metadata();
functionalization::FunctionalStorageImpl* functional_storage_impl() const;
// This is used to re-implement shallow_copy_and_detach for
// FunctionalTensorWrapper. The implementation is identical, but we just need
// to return a subclass instead of a plain TensorImpl.
// TODO: maybe it's possible to arrange for that to happen automatically
// without an override here?
template <typename VariableVersion>
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
VariableVersion&& version_counter,
bool allow_tensor_metadata_change) const;
// Note that value is not taken by reference: internally, the wrapper will
// change the value tensor that it points to over time.
Tensor value_;
int64_t level_;
size_t generation_ = 0;
std::vector<at::functionalization::ViewMeta> view_metas_;
};
// Utility functions for the functionalization pass.
namespace functionalization {
namespace impl {
TORCH_API inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper(
const Tensor& tensor) {
auto functional_impl =
static_cast<FunctionalTensorWrapper*>(tensor.unsafeGetTensorImpl());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_impl != nullptr);
return functional_impl;
}
TORCH_API bool isFunctionalTensor(const at::Tensor& tensor);
TORCH_API bool isFunctionalTensor(const c10::optional<Tensor>& t);
TORCH_API bool isFunctionalTensor(
const c10::List<c10::optional<Tensor>>& t_list);
TORCH_API bool isFunctionalTensor(ITensorListRef list);
TORCH_API Tensor to_functional_tensor(const Tensor& tensor);
TORCH_API c10::optional<Tensor> to_functional_tensor(
const c10::optional<Tensor>& tensor);
TORCH_API c10::List<c10::optional<Tensor>> to_functional_tensor(
const c10::List<c10::optional<Tensor>>& t_list);
TORCH_API std::vector<Tensor> to_functional_tensor(ITensorListRef t_list);
TORCH_API Tensor
from_functional_tensor(const Tensor& tensor, bool assert_functional = true);
TORCH_API c10::optional<Tensor> from_functional_tensor(
const c10::optional<Tensor>& t,
bool assert_functional = true);
TORCH_API c10::List<c10::optional<Tensor>> from_functional_tensor(
const c10::List<c10::optional<Tensor>>& t_list);
TORCH_API std::vector<Tensor> from_functional_tensor(ITensorListRef t_list);
TORCH_API void sync(const at::Tensor& t);
TORCH_API void sync(const c10::optional<Tensor>& t);
TORCH_API void sync(const c10::List<c10::optional<Tensor>> t_list);
TORCH_API void sync(ITensorListRef t_list);
TORCH_API void replace_(const Tensor& functional_tensor, const Tensor& other);
TORCH_API void replace_(
const ITensorListRef functional_tensor,
ITensorListRef other);
TORCH_API void commit_update(const Tensor& functional_tensor);
TORCH_API void commit_update(ITensorListRef functional_tensor);
Tensor create_functional_tensor_with_view_meta(
const Tensor& view_to_wrap,
const Tensor& base,
functionalization::ViewMeta meta,
int64_t out_idx = 0);
std::vector<Tensor> create_functional_tensor_with_view_meta(
ITensorListRef view_to_wrap,
const Tensor& base,
functionalization::ViewMeta meta);
void mutate_view_meta(const Tensor& self, functionalization::ViewMeta meta);
void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out);
void set_sizes_strides_offset(
const std::vector<Tensor>& outs,
const std::vector<Tensor>& meta_outs);
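// A rough sketch of how the wrap/sync/unwrap helpers above compose
// (illustrative only; the real call sites live in the functionalization
// kernels):
//   at::Tensor wrapped = to_functional_tensor(t);         // wrap the input
//   /* ... run the functionalized program on `wrapped` ... */
//   sync(wrapped);                                         // apply any pending updates
//   at::Tensor result = from_functional_tensor(wrapped);  // unwrap to a plain tensor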
// ~~~~~ TLS used in functionalization ~~~~~
TORCH_API bool getFunctionalizationReapplyViewsTLS();
TORCH_API void setFunctionalizationReapplyViewsTLS(bool reapply_views);
class TORCH_API FunctionalizationReapplyViewsGuard {
public:
FunctionalizationReapplyViewsGuard(bool reapply_views) {
prev_ = getFunctionalizationReapplyViewsTLS();
setFunctionalizationReapplyViewsTLS(reapply_views);
}
~FunctionalizationReapplyViewsGuard() {
setFunctionalizationReapplyViewsTLS(prev_);
}
FunctionalizationReapplyViewsGuard(
const FunctionalizationReapplyViewsGuard&) = delete;
FunctionalizationReapplyViewsGuard operator=(
const FunctionalizationReapplyViewsGuard&) = delete;
FunctionalizationReapplyViewsGuard(FunctionalizationReapplyViewsGuard&&) =
delete;
FunctionalizationReapplyViewsGuard operator=(
FunctionalizationReapplyViewsGuard&&) = delete;
private:
bool prev_;
};
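// Typical RAII usage of the guard above (illustrative):
//   {
//     FunctionalizationReapplyViewsGuard guard(/*reapply_views=*/true);
//     // code in this scope observes getFunctionalizationReapplyViewsTLS() == true
//   } // previous TLS value restored here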
} // namespace impl
// Helper function to call an out-of-place composite aten kernel that may use
// mutations / views internally, and functionalize them.
TORCH_API void functionalize_op_helper(
const c10::OperatorHandle& op,
torch::jit::Stack* stack);
template <class Op, bool symint, class ReturnType, class... ParameterTypes>
struct _functionalize_aten_op final {};
template <class Op, bool symint, class ReturnType, class... ParameterTypes>
struct _functionalize_aten_op<Op, symint, ReturnType(ParameterTypes...)> final {
static ReturnType call(
typename c10::maybe_keep_symint<symint, ParameterTypes>::type... args) {
using FuncType = ReturnType(
typename c10::maybe_keep_symint<symint, ParameterTypes>::type...);
auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow(
(const char*)Op::name, (const char*)Op::overload_name)
.typed<FuncType>();
return c10::impl::BoxedKernelWrapper<FuncType>::call(
c10::BoxedKernel::makeFromFunction<functionalize_op_helper>(),
op,
// BoxedKernelWrapper knows to ignore this keyset argument,
// because functionalize_op_helper doesn't take in a DispatchKeySet
c10::DispatchKeySet(),
args...);
}
};
template <class Op>
using functionalize_aten_op =
_functionalize_aten_op<Op, false, typename Op::schema>;
template <class Op>
using functionalize_aten_op_symint =
_functionalize_aten_op<Op, true, typename Op::schema>;
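// Example usage (illustrative; assumes an Op struct like the ones generated
// under at::_ops, e.g. obtained via the ATEN_OP helper macro - both of which
// live outside this header):
//   auto out = functionalize_aten_op<ATEN_OP(some_op)>::call(self, args...);
// The call dispatches to functionalize_op_helper, which functionalizes any
// views/mutations that the composite kernel performs internally.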
} // namespace functionalization
} // namespace at