|
|
#pragma once |
|
|
|
|
|
#include <c10/macros/Export.h> |
|
|
#include <ATen/core/Tensor.h> |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
namespace torch { namespace autograd {

// Forward declaration only. Within this header Node is named solely as the
// pointee of a std::shared_ptr returned by reference, which does not require
// the complete type, so the defining header is not included here.
struct Node;

}} // namespace torch::autograd
|
|
|
|
|
namespace at { |
|
|
namespace impl { |
|
|
|
|
|
struct TORCH_API VariableHooksInterface { |
|
|
virtual ~VariableHooksInterface() = default; |
|
|
virtual TensorBase tensor_data(const TensorBase&) const = 0; |
|
|
virtual TensorBase variable_data(const TensorBase&) const = 0; |
|
|
virtual const std::shared_ptr<torch::autograd::Node>& grad_fn(const TensorBase&) const = 0; |
|
|
virtual unsigned _register_hook( |
|
|
const TensorBase&, |
|
|
std::function<TensorBase(const TensorBase&)> hook) const = 0; |
|
|
virtual void remove_hook(const TensorBase&, unsigned pos) const = 0; |
|
|
virtual bool is_view(const TensorBase&) const = 0; |
|
|
virtual const TensorBase& base(const TensorBase&) const = 0; |
|
|
virtual const std::string& name(const TensorBase&) const = 0; |
|
|
virtual bool is_leaf(const TensorBase&) const = 0; |
|
|
virtual int64_t output_nr(const TensorBase&) const = 0; |
|
|
virtual void set_data(const TensorBase&, const TensorBase&) const = 0; |
|
|
virtual TensorBase data(const TensorBase&) const = 0; |
|
|
virtual int64_t _version(const TensorBase&) const = 0; |
|
|
virtual void retain_grad(const TensorBase&) const = 0; |
|
|
virtual bool retains_grad(const TensorBase&) const = 0; |
|
|
virtual void _backward(const Tensor&, TensorList, const c10::optional<Tensor>&, c10::optional<bool>, bool) const = 0; |
|
|
virtual void requires_grad_(const TensorBase&, bool) const = 0; |
|
|
}; |
|
|
|
|
|
// Accessors for a VariableHooksInterface pointer. Only the declarations live
// in this header; the definitions are elsewhere.
// NOTE(review): presumably SetVariableHooks() stores the pointer in
// process-wide state and GetVariableHooks() returns the most recently stored
// one. Ownership of `hooks` and the behaviour when nothing has been
// registered are not visible here -- confirm against the definitions.
TORCH_API void SetVariableHooks(VariableHooksInterface* hooks);
TORCH_API VariableHooksInterface* GetVariableHooks();
|
|
|
|
|
// Helper whose only effect is performed at construction: the constructor
// forwards `hooks` to SetVariableHooks(). The type has no data members and
// does nothing on destruction.
// NOTE(review): presumably intended to be instantiated as a static/global
// object so that registration happens during static initialization --
// confirm at the use site.
struct TORCH_API VariableHooksRegisterer {
  // `explicit` prevents a VariableHooksInterface* from silently converting
  // into a registerer (and thereby registering hooks as a side effect).
  explicit VariableHooksRegisterer(VariableHooksInterface* hooks) {
    SetVariableHooks(hooks);
  }
};
|
|
|
|
|
}} |
|
|
|