|
|
#pragma once |
|
|
|
|
|
#include <c10/core/SymIntNodeImpl.h> |
|
|
#include <c10/macros/Macros.h> |
|
|
#include <c10/util/Exception.h> |
|
|
#include <c10/util/intrusive_ptr.h> |
|
|
|
|
|
#include <memory> |
|
|
#include <numeric> |
|
|
|
|
|
namespace c10 { |
|
|
|
|
|
// Forward declaration: SymInt below declares a conversion operator to
// SymFloat, which only needs the type to be declared here.
class SymFloat;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// SKIP_IS_SYMBOLIC_ON_MOBILE(cond)
//
// Regular builds: asserts `cond` via TORCH_CHECK.
// C10_MOBILE builds: expands to an empty statement; `cond` is not evaluated.
// Helper for the SymInt class below; it is #undef'd later in this header.
#ifndef C10_MOBILE
#define SKIP_IS_SYMBOLIC_ON_MOBILE(cond) TORCH_CHECK(cond)
#else
#define SKIP_IS_SYMBOLIC_ON_MOBILE(cond) \
  do {                                   \
  } while (false)
#endif
|
|
|
|
|
class C10_API SymInt { |
|
|
public: |
|
|
enum Unchecked { |
|
|
UNCHECKED, |
|
|
}; |
|
|
|
|
|
SymInt(int64_t d) : data_(d) { |
|
|
SKIP_IS_SYMBOLIC_ON_MOBILE(!is_symbolic()); |
|
|
}; |
|
|
SymInt() : data_(0) {} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
SymInt(Unchecked, int64_t d) : data_(d) {} |
|
|
|
|
|
|
|
|
|
|
|
SymInt(const SymInt& s) : data_(0) { |
|
|
if (s.is_symbolic()) { |
|
|
*this = SymInt::toSymInt(s.toSymIntNodeImpl()); |
|
|
} else { |
|
|
data_ = s.data_; |
|
|
} |
|
|
} |
|
|
SymInt(SymInt&& s) : data_(s.data_) { |
|
|
s.data_ = 0; |
|
|
} |
|
|
|
|
|
SymInt& operator=(const SymInt& s) { |
|
|
if (this != &s) { |
|
|
if (s.is_symbolic()) { |
|
|
*this = SymInt::toSymInt(s.toSymIntNodeImpl()); |
|
|
} else { |
|
|
data_ = s.data_; |
|
|
} |
|
|
} |
|
|
return *this; |
|
|
} |
|
|
SymInt& operator=(SymInt&& s) { |
|
|
if (this != &s) { |
|
|
release_(); |
|
|
data_ = s.data_; |
|
|
if (s.is_symbolic()) |
|
|
s.data_ = 0; |
|
|
}; |
|
|
return *this; |
|
|
} |
|
|
|
|
|
SymInt clone() const { |
|
|
#ifndef C10_MOBILE |
|
|
if (is_symbolic()) { |
|
|
return toSymIntNodeImplUnowned()->clone()->toSymInt(); |
|
|
} |
|
|
#else |
|
|
TORCH_INTERNAL_ASSERT(!is_symbolic()); |
|
|
#endif |
|
|
return *this; |
|
|
} |
|
|
|
|
|
#ifndef C10_MOBILE |
|
|
SymIntNodeImpl* toSymIntNodeImplUnowned() const { |
|
|
uint64_t unextended_bits = static_cast<uint64_t>(data_) & ~MASK; |
|
|
uint64_t sign_bit_mask = 1ULL << (62 - 1); |
|
|
|
|
|
uint64_t extended_bits = (unextended_bits ^ sign_bit_mask) - sign_bit_mask; |
|
|
return static_cast<SymIntNodeImpl*>( |
|
|
reinterpret_cast<void*>(static_cast<uintptr_t>(extended_bits))); |
|
|
} |
|
|
|
|
|
void release_() { |
|
|
if (is_symbolic()) { |
|
|
SymIntNode::reclaim(toSymIntNodeImplUnowned()); |
|
|
} |
|
|
} |
|
|
|
|
|
SymIntNodeImpl* release() && { |
|
|
TORCH_INTERNAL_ASSERT(is_symbolic()); |
|
|
auto* r = toSymIntNodeImplUnowned(); |
|
|
data_ = 0; |
|
|
return r; |
|
|
} |
|
|
#else |
|
|
void release_() {} |
|
|
|
|
|
SymIntNodeImpl* release() && { |
|
|
TORCH_INTERNAL_ASSERT(false); |
|
|
} |
|
|
#endif |
|
|
|
|
|
SymIntNode toSymIntNodeImpl() const; |
|
|
static c10::SymInt toSymInt(SymIntNode sin); |
|
|
|
|
|
~SymInt() { |
|
|
release_(); |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int64_t expect_int() const { |
|
|
SKIP_IS_SYMBOLIC_ON_MOBILE(!is_symbolic()); |
|
|
return data_; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int64_t guard_int(const char* file, int64_t line) const; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
C10_ALWAYS_INLINE bool is_symbolic() const { |
|
|
#ifdef C10_MOBILE |
|
|
return false; |
|
|
#else |
|
|
return (MASK & static_cast<uint64_t>(this->data_)) == IS_SYM; |
|
|
#endif |
|
|
} |
|
|
|
|
|
SymInt operator+(SymInt sci) const; |
|
|
SymInt operator-(SymInt sci) const; |
|
|
SymInt operator*(SymInt sci) const; |
|
|
SymInt operator/(SymInt sci) const; |
|
|
SymInt operator%(SymInt sci) const; |
|
|
bool operator==(SymInt sci) const; |
|
|
bool operator!=(SymInt p2) const; |
|
|
bool operator<(SymInt sci) const; |
|
|
bool operator<=(SymInt sci) const; |
|
|
bool operator>(SymInt sci) const; |
|
|
bool operator>=(SymInt sci) const; |
|
|
void operator*=(SymInt sci); |
|
|
void operator+=(SymInt sci); |
|
|
|
|
|
SymInt operator*(int64_t sci) const; |
|
|
bool operator<(int64_t sci) const; |
|
|
bool operator==(int64_t sci) const; |
|
|
bool operator!=(int64_t sci) const; |
|
|
bool operator<=(int64_t sci) const; |
|
|
bool operator>(int64_t sci) const; |
|
|
bool operator>=(int64_t sci) const; |
|
|
|
|
|
operator SymFloat() const; |
|
|
|
|
|
int64_t as_int_unchecked() const { |
|
|
return data_; |
|
|
} |
|
|
|
|
|
|
|
|
static bool check_range(int64_t i) { |
|
|
return i > MIN_INT; |
|
|
} |
|
|
|
|
|
private: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
static constexpr uint64_t MASK = 1ULL << 63 | 1ULL << 62; |
|
|
static constexpr uint64_t IS_SYM = 1ULL << 63; |
|
|
|
|
|
|
|
|
|
|
|
static constexpr uint64_t MAX_SYM_IDX = 1ULL << 62; |
|
|
|
|
|
|
|
|
static constexpr int64_t MIN_INT = -1LL & static_cast<int64_t>(~(1ULL << 62)); |
|
|
int64_t data_; |
|
|
}; |
|
|
|
|
|
#undef SKIP_IS_SYMBOLIC_ON_MOBILE |
|
|
|
|
|
|
|
|
template < |
|
|
typename C, |
|
|
typename std::enable_if< |
|
|
std::is_same<typename C::value_type, c10::SymInt>::value, |
|
|
int>::type = 0> |
|
|
inline c10::SymInt multiply_integers(const C& container) { |
|
|
return std::accumulate( |
|
|
container.begin(), |
|
|
container.end(), |
|
|
c10::SymInt(1), |
|
|
[](c10::SymInt a, c10::SymInt b) { return a * b; }); |
|
|
} |
|
|
|
|
|
// Stream insertion for SymInt; declared here and defined out of line.
C10_API std::ostream& operator<<(std::ostream& os, SymInt s);
|
|
} |
|
|
|