|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
#include <ATen/ATen.h>
|
|
|
#include <ATen/core/op_registration/op_registration.h>
|
|
|
#include <torch/library.h>
|
|
|
|
|
|
namespace at::functorch {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Boxed fallback kernel used by the slow_fallback helpers in this header.
// Those helpers push the operator's arguments onto `stack`, call this
// function, and then read the operator's results from the front of the same
// stack.
// NOTE(review): the name suggests it runs `op` once per batch entry in a loop
// and recombines the per-entry outputs. Confirm against the definition.
void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);

// Counterpart of batchedTensorForLoopFallback with the same boxed signature.
// It is not used in this header.
// NOTE(review): presumably this is the variant for batched *nested* tensors.
// Confirm against the definition.
void batchedNestedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
|
|
|
|
|
|
// Boxed kernel with the same (op, stack) signature as the for-loop fallbacks.
// NOTE(review): the name suggests it reports an error instead of computing a
// result, e.g. for operators vmap must reject. Confirm in the definition.
void vmapErrorFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
|
|
|
|
|
|
|
|
|
|
|
|
// Returns whether vmap-fallback warnings are currently enabled (exported via
// TORCH_API).
// NOTE(review): which warning this gates, and how the flag is stored (global,
// thread-local, atomic), is not visible here. Check the .cpp.
TORCH_API bool isVmapFallbackWarningEnabled();

// Enables or disables the warning queried by isVmapFallbackWarningEnabled().
TORCH_API void setVmapFallbackWarningEnabled(bool enabled);
|
|
|
|
|
|
|
|
|
|
|
|
// Returns whether the vmap fallback is currently enabled (exported via
// TORCH_API).
// NOTE(review): presumably, when this is disabled, operators without a
// batching rule do not use the for-loop fallback. Confirm the exact behavior
// and the flag's storage in the .cpp.
TORCH_API bool isVmapFallbackEnabled();

// Enables or disables the fallback queried by isVmapFallbackEnabled().
TORCH_API void setVmapFallbackEnabled(bool enabled);
|
|
|
|
|
|
template <typename A> A vector_to_result(const std::vector<IValue>& buffer) {
|
|
|
return buffer[0].to<A>();
|
|
|
}
|
|
|
template <typename A, typename B> std::tuple<A, B> vector_to_result(const std::vector<IValue>& buffer) {
|
|
|
return std::make_tuple(buffer[0].to<A>(), buffer[1].to<B>());
|
|
|
}
|
|
|
template <typename A, typename B, typename C> std::tuple<A, B, C> vector_to_result(const std::vector<IValue>& buffer) {
|
|
|
return std::make_tuple(buffer[0].to<A>(), buffer[1].to<B>(), buffer[2].to<B>());
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
template <typename Ret>
|
|
|
Ret slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) {
|
|
|
std::vector<IValue> stack(args.begin(), args.end());
|
|
|
batchedTensorForLoopFallback(op, &stack);
|
|
|
return vector_to_result<Ret>(stack);
|
|
|
}
|
|
|
|
|
|
template <typename A, typename B>
|
|
|
std::tuple<A, B> slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) {
|
|
|
std::vector<IValue> stack(args.begin(), args.end());
|
|
|
batchedTensorForLoopFallback(op, &stack);
|
|
|
return vector_to_result<A, B>(stack);
|
|
|
}
|
|
|
|
|
|
template <typename A, typename B, typename C>
|
|
|
std::tuple<A, B, C> slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) {
|
|
|
std::vector<IValue> stack(args.begin(), args.end());
|
|
|
batchedTensorForLoopFallback(op, &stack);
|
|
|
return vector_to_result<A, B, C>(stack);
|
|
|
}
|
|
|
|
|
|
|
|
|
}
|
|
|
|