|
|
#pragma once |
|
|
|
|
|
#include <cstdint>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include <c10/util/ArrayRef.h>
#include <ATen/core/List.h>
|
|
|
|
|
namespace at { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename F> |
|
|
struct IterArgs { |
|
|
template <typename... Args> |
|
|
inline F& apply() { |
|
|
return self(); |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
template <typename T, typename... Args> |
|
|
inline F& apply(T&& arg, Args&&... args) { |
|
|
self()(std::forward<T>(arg)); |
|
|
if (self().short_circuit()) { |
|
|
return self(); |
|
|
} else { |
|
|
return apply(std::forward<Args>(args)...); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T> |
|
|
void operator()(c10::IListRef<T> args) { |
|
|
for (const auto& arg : args) { |
|
|
self()(arg); |
|
|
if (self().short_circuit()) |
|
|
return; |
|
|
} |
|
|
} |
|
|
|
|
|
template <typename T> |
|
|
void operator()(at::ArrayRef<T> args) { |
|
|
for (const auto& arg : args) { |
|
|
self()(arg); |
|
|
if (self().short_circuit()) |
|
|
return; |
|
|
} |
|
|
} |
|
|
|
|
|
template <typename T> |
|
|
void operator()(const torch::List<T>& args) { |
|
|
for (const auto& arg : args) { |
|
|
self()(arg); |
|
|
if (self().short_circuit()) |
|
|
return; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
template <typename T> |
|
|
void operator()(const std::vector<T>& args) { |
|
|
self()(at::ArrayRef<T>{args}); |
|
|
} |
|
|
|
|
|
constexpr bool short_circuit() const { |
|
|
return false; |
|
|
} |
|
|
|
|
|
private: |
|
|
inline F& self() { |
|
|
return *static_cast<F*>(this); |
|
|
} |
|
|
}; |
|
|
|
|
|
} |
|
|
|