|
|
#pragma once
|
|
|
|
|
|
#include <ATen/functorch/Macros.h>
|
|
|
#include <ATen/core/dispatch/Dispatcher.h>
|
|
|
#include <c10/core/impl/LocalDispatchKeySet.h>
|
|
|
#include <optional>
|
|
|
#include <bitset>
|
|
|
#include <utility>
|
|
|
#include <variant>
|
|
|
|
|
|
#include <nlohmann/json.hpp>
|
|
|
|
|
|
namespace at::functorch {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Randomness policy carried by VmapInterpreterMeta.
//
// The enumerator values are spelled out because VmapInterpreterMeta's
// to_json writes this enum as a plain integer; reordering or inserting
// enumerators would silently change the meaning of serialized data.
// NOTE(review): presumably Error rejects random ops under vmap, Same shares
// random values across the batch, and Different draws them per batch
// element — confirm against the vmap implementation. END appears to be a
// count sentinel rather than a real policy.
enum class RandomnessType {
  Error = 0,
  Same = 1,
  Different = 2,
  END = 3,
};
|
|
|
|
|
|
// Identifies which transform an Interpreter implements (Interpreter::key()).
//
// The enumerator values are spelled out because Interpreter's to_json writes
// this enum as an int64_t; changing the order would break previously
// serialized interpreters.
enum class TransformType {
  Torch = 0,
  Vmap = 1,
  Grad = 2,
  Jvp = 3,
  Functionalize = 4,
};

// Stream insertion for TransformType; the definition lives out of view.
std::ostream& operator<<(std::ostream& out, const TransformType& type);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
struct VmapInterpreterMeta {
|
|
|
explicit VmapInterpreterMeta(c10::SymInt batchSize, RandomnessType randomness) :
|
|
|
batchSize_(std::move(batchSize)), randomness_(randomness) {}
|
|
|
|
|
|
c10::SymInt batchSize_;
|
|
|
RandomnessType randomness_;
|
|
|
|
|
|
VmapInterpreterMeta() = default;
|
|
|
VmapInterpreterMeta(const VmapInterpreterMeta&) = default;
|
|
|
VmapInterpreterMeta(VmapInterpreterMeta&&) = default;
|
|
|
VmapInterpreterMeta& operator=(const VmapInterpreterMeta&) = default;
|
|
|
VmapInterpreterMeta& operator=(VmapInterpreterMeta&&) = default;
|
|
|
~VmapInterpreterMeta() = default;
|
|
|
|
|
|
template <typename T>
|
|
|
friend void to_json(T& json_j, const VmapInterpreterMeta& json_t) {
|
|
|
if (json_t.batchSize_.is_heap_allocated()) {
|
|
|
throw std::runtime_error("Serialization for heap-allocated SymInt is not implemented yet");
|
|
|
}
|
|
|
json_j["batchSize"] = json_t.batchSize_.as_int_unchecked();
|
|
|
json_j["randomness"] = static_cast<int64_t>(json_t.randomness_);
|
|
|
}
|
|
|
|
|
|
template <typename T>
|
|
|
friend void from_json(const T& json_j, VmapInterpreterMeta& json_t) {
|
|
|
json_t.batchSize_ = c10::SymInt(SymInt::Unchecked::UNCHECKED, json_j["batchSize"]);
|
|
|
json_t.randomness_ = static_cast<RandomnessType>(json_j["randomness"]);
|
|
|
}
|
|
|
};
|
|
|
|
|
|
// Metadata stored in an Interpreter for TransformType::Grad: the grad-mode
// flag recorded when the interpreter was created.
struct GradInterpreterMeta {
  explicit GradInterpreterMeta(bool prevGradMode): prevGradMode_(prevGradMode) {}
  GradInterpreterMeta() = default;
  GradInterpreterMeta(const GradInterpreterMeta&) = default;
  GradInterpreterMeta(GradInterpreterMeta&&) = default;
  GradInterpreterMeta& operator=(const GradInterpreterMeta&) = default;
  GradInterpreterMeta& operator=(GradInterpreterMeta&&) = default;
  ~GradInterpreterMeta() = default;

  // Initialized so the defaulted constructor (used before from_json fills the
  // object in) never leaves an indeterminate bool, which is UB to read.
  bool prevGradMode_{false};

  // Writes {"prevGradMode": <bool>}.
  template <typename T>
  friend void to_json(T& json_j, const GradInterpreterMeta& json_t) {
    json_j["prevGradMode"] = json_t.prevGradMode_;
  }

  // Inverse of to_json. .at() throws on a missing key rather than invoking
  // const operator[], which is undefined for absent keys in nlohmann::json.
  template <typename T>
  friend void from_json(const T& json_j, GradInterpreterMeta& json_t) {
    json_t.prevGradMode_ = json_j.at("prevGradMode");
  }
};
|
|
|
|
|
|
// Metadata stored in an Interpreter for TransformType::Jvp: the forward-grad
// mode flag recorded when the interpreter was created.
struct JvpInterpreterMeta {
  explicit JvpInterpreterMeta(bool prevFwdGradMode) : prevFwdGradMode_(prevFwdGradMode) {}
  JvpInterpreterMeta() = default;
  JvpInterpreterMeta(const JvpInterpreterMeta&) = default;
  JvpInterpreterMeta(JvpInterpreterMeta&&) = default;
  JvpInterpreterMeta& operator=(const JvpInterpreterMeta&) = default;
  JvpInterpreterMeta& operator=(JvpInterpreterMeta&&) = default;
  ~JvpInterpreterMeta() = default;

  // Initialized so the defaulted constructor (used before from_json fills the
  // object in) never leaves an indeterminate bool, which is UB to read.
  bool prevFwdGradMode_{false};

  // Writes {"prevFwdGradMode": <bool>}.
  template <typename T>
  friend void to_json(T& json_j, const JvpInterpreterMeta& json_t) {
    json_j["prevFwdGradMode"] = json_t.prevFwdGradMode_;
  }

  // Inverse of to_json. .at() throws on a missing key rather than invoking
  // const operator[], which is undefined for absent keys in nlohmann::json.
  template <typename T>
  friend void from_json(const T& json_j, JvpInterpreterMeta& json_t) {
    json_t.prevFwdGradMode_ = json_j.at("prevFwdGradMode");
  }
};
|
|
|
|
|
|
// Metadata stored in an Interpreter for TransformType::Functionalize: the
// functionalizeAddBackViews flag the interpreter was created with.
struct FunctionalizeInterpreterMeta {
  explicit FunctionalizeInterpreterMeta(bool functionalizeAddBackViews) :
    functionalizeAddBackViews_(functionalizeAddBackViews) {}
  FunctionalizeInterpreterMeta() = default;
  FunctionalizeInterpreterMeta(const FunctionalizeInterpreterMeta&) = default;
  FunctionalizeInterpreterMeta(FunctionalizeInterpreterMeta&&) = default;
  FunctionalizeInterpreterMeta& operator=(const FunctionalizeInterpreterMeta&) = default;
  FunctionalizeInterpreterMeta& operator=(FunctionalizeInterpreterMeta&&) = default;
  ~FunctionalizeInterpreterMeta() = default;

  // Initialized so the defaulted constructor (used before from_json fills the
  // object in) never leaves an indeterminate bool, which is UB to read.
  bool functionalizeAddBackViews_{false};

  // Writes {"functionalizeAddBackViews": <bool>}.
  template <typename T>
  friend void to_json(T& json_j, const FunctionalizeInterpreterMeta& json_t) {
    json_j["functionalizeAddBackViews"] = json_t.functionalizeAddBackViews_;
  }

  // Inverse of to_json. .at() throws on a missing key rather than invoking
  // const operator[], which is undefined for absent keys in nlohmann::json.
  template <typename T>
  friend void from_json(const T& json_j, FunctionalizeInterpreterMeta& json_t) {
    json_t.functionalizeAddBackViews_ = json_j.at("functionalizeAddBackViews");
  }
};
|
|
|
|
|
|
typedef std::variant<
|
|
|
int64_t,
|
|
|
GradInterpreterMeta,
|
|
|
JvpInterpreterMeta,
|
|
|
VmapInterpreterMeta,
|
|
|
FunctionalizeInterpreterMeta
|
|
|
> InterpreterMeta;
|
|
|
|
|
|
|
|
|
struct Interpreter {
|
|
|
|
|
|
static Interpreter Vmap(int64_t level, c10::SymInt batchSize, RandomnessType randomness) {
|
|
|
return Interpreter(TransformType::Vmap, level, VmapInterpreterMeta(std::move(batchSize), randomness));
|
|
|
}
|
|
|
static Interpreter Grad(int64_t level, bool prevGradMode) {
|
|
|
return Interpreter(TransformType::Grad, level, GradInterpreterMeta(prevGradMode));
|
|
|
}
|
|
|
static Interpreter Jvp(int64_t level, bool prevFwdGradMode) {
|
|
|
return Interpreter(TransformType::Jvp, level, JvpInterpreterMeta(prevFwdGradMode));
|
|
|
}
|
|
|
static Interpreter Functionalize(int64_t level, bool functionalizeAddBackViews) {
|
|
|
return Interpreter(TransformType::Functionalize, level, FunctionalizeInterpreterMeta(functionalizeAddBackViews));
|
|
|
}
|
|
|
|
|
|
|
|
|
TransformType key() const { return type_; }
|
|
|
int64_t level() const { return level_; }
|
|
|
const InterpreterMeta& meta() const { return meta_; }
|
|
|
|
|
|
void process(const c10::OperatorHandle& op, torch::jit::Stack* stack);
|
|
|
void sendToNextInterpreter(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
|
|
|
|
|
|
void saveLocalDispatchKeySet(c10::impl::LocalDispatchKeySet keyset) {
|
|
|
TORCH_INTERNAL_ASSERT(!savedLocalDispatchKeySet_.has_value());
|
|
|
savedLocalDispatchKeySet_ = keyset;
|
|
|
}
|
|
|
void clearSavedLocalDispatchKeySet() {
|
|
|
TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value());
|
|
|
savedLocalDispatchKeySet_ = std::nullopt;
|
|
|
}
|
|
|
c10::impl::LocalDispatchKeySet getSavedLocalDispatchKeySet() const {
|
|
|
TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value());
|
|
|
return *savedLocalDispatchKeySet_;
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool is_alive() const {
|
|
|
return *is_alive_;
|
|
|
}
|
|
|
const std::shared_ptr<bool>& is_alive_ptr() const {
|
|
|
return is_alive_;
|
|
|
}
|
|
|
void set_is_alive(bool alive) {
|
|
|
*is_alive_ = alive;
|
|
|
}
|
|
|
|
|
|
|
|
|
explicit Interpreter() = default;
|
|
|
|
|
|
template <typename T>
|
|
|
friend void to_json(T& json_j, const Interpreter& json_t) {
|
|
|
json_j["type"] = static_cast<int64_t>(json_t.type_);
|
|
|
json_j["level"] = json_t.level_;
|
|
|
if (json_t.savedLocalDispatchKeySet_) {
|
|
|
json_j["savedLocalDispatchKeySet"] = {
|
|
|
{"included", json_t.savedLocalDispatchKeySet_->included_.raw_repr()},
|
|
|
{"excluded", json_t.savedLocalDispatchKeySet_->excluded_.raw_repr()}
|
|
|
};
|
|
|
} else {
|
|
|
json_j["savedLocalDispatchKeySet"] = nlohmann::json();
|
|
|
}
|
|
|
json_j["is_alive"] = *json_t.is_alive_;
|
|
|
std::visit([&](auto&& arg) {
|
|
|
using V = std::decay_t<decltype(arg)>;
|
|
|
if constexpr (std::is_same_v<V, int64_t>) {
|
|
|
json_j["meta"] = {{"Torch", arg}};
|
|
|
} else if constexpr (std::is_same_v<V, GradInterpreterMeta>) {
|
|
|
json_j["meta"] = {{"Grad", arg}};
|
|
|
} else if constexpr (std::is_same_v<V, JvpInterpreterMeta>) {
|
|
|
json_j["meta"] = {{"Jvp", arg}};
|
|
|
} else if constexpr (std::is_same_v<V, VmapInterpreterMeta>) {
|
|
|
json_j["meta"] = {{"Vmap", arg}};
|
|
|
} else if constexpr (std::is_same_v<V, FunctionalizeInterpreterMeta>) {
|
|
|
json_j["meta"] = {{"Functionalize", arg}};
|
|
|
} else {
|
|
|
static_assert(false && sizeof(V), "unknown variant case");
|
|
|
}
|
|
|
}, json_t.meta_);
|
|
|
}
|
|
|
|
|
|
template <typename T>
|
|
|
friend void from_json(const T& json_j, Interpreter& json_t) {
|
|
|
json_t.type_ = static_cast<TransformType>(json_j["type"]);
|
|
|
json_t.level_ = json_j["level"];
|
|
|
auto savedLocalDispatchKeySet = json_j["savedLocalDispatchKeySet"];
|
|
|
if (savedLocalDispatchKeySet.is_null()) {
|
|
|
json_t.savedLocalDispatchKeySet_ = std::nullopt;
|
|
|
} else {
|
|
|
c10::impl::PODLocalDispatchKeySet pod;
|
|
|
pod.set_included(DispatchKeySet::from_raw_repr(savedLocalDispatchKeySet["included"].template get<uint64_t>()));
|
|
|
pod.set_excluded(DispatchKeySet::from_raw_repr(savedLocalDispatchKeySet["excluded"].template get<uint64_t>()));
|
|
|
json_t.savedLocalDispatchKeySet_ = c10::impl::LocalDispatchKeySet(pod);
|
|
|
}
|
|
|
json_t.is_alive_ = std::make_shared<bool>(json_j["is_alive"]);
|
|
|
auto meta = json_j["meta"];
|
|
|
if (meta.contains("Torch")) {
|
|
|
json_t.meta_.emplace<int64_t>(meta["Torch"].template get<int64_t>());
|
|
|
} else if (meta.contains("Grad")) {
|
|
|
json_t.meta_.emplace<GradInterpreterMeta>(meta["Grad"].template get<GradInterpreterMeta>());
|
|
|
} else if (meta.contains("Jvp")) {
|
|
|
json_t.meta_.emplace<JvpInterpreterMeta>(meta["Jvp"].template get<JvpInterpreterMeta>());
|
|
|
} else if (meta.contains("Vmap")) {
|
|
|
json_t.meta_.emplace<VmapInterpreterMeta>(meta["Vmap"].template get<VmapInterpreterMeta>());
|
|
|
} else if (meta.contains("Functionalize")) {
|
|
|
json_t.meta_.emplace<FunctionalizeInterpreterMeta>(meta["Functionalize"].template get<FunctionalizeInterpreterMeta>());
|
|
|
} else {
|
|
|
throw std::runtime_error("unknown interpreter metadata type");
|
|
|
}
|
|
|
}
|
|
|
|
|
|
std::string serialize() const {
|
|
|
return nlohmann::json(*this).dump();
|
|
|
}
|
|
|
|
|
|
static Interpreter deserialize(const std::string& serialized) {
|
|
|
return nlohmann::json::parse(serialized).get<Interpreter>();
|
|
|
}
|
|
|
|
|
|
private:
|
|
|
explicit Interpreter(TransformType type, int64_t level, InterpreterMeta meta):
|
|
|
type_(type), level_(level), is_alive_(std::make_shared<bool>(false)), meta_(std::move(meta)) {}
|
|
|
|
|
|
|
|
|
TransformType type_{};
|
|
|
int64_t level_{};
|
|
|
std::optional<c10::impl::LocalDispatchKeySet> savedLocalDispatchKeySet_;
|
|
|
std::shared_ptr<bool> is_alive_;
|
|
|
InterpreterMeta meta_;
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Free helpers declared here and defined out of view. Their signatures must
// stay in sync with the out-of-view definitions.

// NOTE(review): from the name and signature this presumably replaces each
// Tensor found in args[begin, end) with func(tensor) in place — confirm
// against the definition. `func` is taken by value, which copies the
// std::function on every call.
void foreachTensorInplace(std::vector<IValue>& args, int64_t begin, int64_t end,
    std::function<Tensor(const Tensor&)> func);

// Like foreachTensorInplace, but `func` also receives a per-position flag.
// NOTE(review): `use_flag_relative` is a std::bitset<64>, so it can describe
// at most 64 positions; presumably bit i corresponds to args[begin + i]
// ("relative"), which would require end - begin <= 64 — confirm against the
// definition.
void foreachTensorInplaceWithFlag(std::vector<IValue>& args, int64_t begin, int64_t end,
    const std::bitset<64> use_flag_relative, const std::function<Tensor(const Tensor&, bool)>& func);

// NOTE(review): presumably returns the positions in args[begin, end) whose
// tensors are not functorch-wrapped — confirm against the definition,
// including whether positions are absolute or relative to `begin`.
std::vector<int64_t> findUnwrappedInputs(std::vector<IValue>& args, int64_t begin, int64_t end);

// Returns the DispatchKeySet associated with entering a dynamic layer for
// transform `key` (behaviour defined out of view).
DispatchKeySet keysToExcludeWhenEnteringDynamicLayer(TransformType key);

// NOTE(review): presumably configures thread-local dispatch-key state for
// transform `key`, adding `include` — confirm against the definition.
void setup_dispatch_key_tls(TransformType key, DispatchKeySet include);

// Consistency check on `stack` for operator `op`; the exact checks live in
// the out-of-view definition.
void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack);
|
|
|
|
|
|
}
|
|
|
|